# Owner(s): ["module: dynamo"]

import collections
import contextlib
import copy
import itertools
import os
import tempfile
import traceback
import types
import unittest
from copy import deepcopy
from functools import partial
from typing import Dict, NamedTuple, Tuple
from unittest.mock import patch

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch.nn.functional as F
from torch._dynamo.debug_utils import same_two_models
from torch._dynamo.eval_frame import unsupported
from torch._dynamo.mutation_guard import GenerationTracker
from torch._dynamo.testing import expectedFailureDynamic, same
from torch._dynamo.variables.torch_function import TensorWithTFOverrideVariable
from torch.nn.modules.lazy import LazyModuleMixin
from torch.nn.parameter import Parameter, UninitializedParameter


# Works both when run as part of the test package and as a standalone script.
try:
    from . import test_functions
except ImportError:
    import test_functions


# Module-level counters mutated by update_global(); used to test how Dynamo
# handles global-state mutation.
_variable = 0
_variable1 = 0


def update_global():
    """Increment both module-level counters in place."""
    global _variable, _variable1
    _variable += 1
    _variable1 += 1


class BasicModule(torch.nn.Module):
    """Linear + ReLU, scaled by a random (non-parameter) tensor attribute."""

    def __init__(self) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.scale = torch.randn(1, 10)

    def forward(self, x):
        return F.relu(self.linear1(x)) * self.scale


class FnMember(torch.nn.Module):
    """Stores a function (F.relu) as an instance attribute and truth-tests it."""

    def __init__(self) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.activation = F.relu

    def forward(self, x):
        x = self.linear1(x)
        if self.activation:
            x = self.activation(x)
        return x


class FnMemberCmp(torch.nn.Module):
    """Branches on whether the stored activation attribute is None."""

    def __init__(self, activation):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.activation = activation

    def forward(self, x):
        x = self.linear1(x)
        if self.activation is not None:
            x = self.activation(x)
        if self.activation is None:
            x = torch.sigmoid(x)
        return x


class SubmoduleExample(torch.nn.Module):
    """Chains two BasicModule submodules and applies a tensor scale."""

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = BasicModule()
        self.layer2 = BasicModule()
        self.scale = torch.randn(1, 10)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        return x * self.scale


class IsTrainingCheck(torch.nn.Module):
    """Selects a submodule based on self.training."""

    def __init__(self) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.linear2 = torch.nn.Linear(10, 10)
        self.train(True)

    def forward(self, x):
        if self.training:
            mod = self.linear1
        else:
            mod = self.linear2
        return F.relu(mod(x))


class IsEvalCheck(IsTrainingCheck):
    """Same as IsTrainingCheck, but constructed in eval mode."""

    def __init__(self) -> None:
        super().__init__()
        self.train(False)


class ModuleMethodCall(torch.nn.Module):
    """Calls a regular (bound) helper method from forward."""

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = BasicModule()
        self.layer2 = BasicModule()
        self.scale = torch.randn(1, 10)

    def call_and_scale(self, mod, x):
        x = mod(x)
        return x * self.scale

    def forward(self, x):
        x1 = self.call_and_scale(self.layer1, x)
        x2 = self.call_and_scale(self.layer2, x)
        return x1 + x2


class UnsupportedMethodCall(torch.nn.Module):
    """Helper method ends in `unsupported(...)`, forcing a graph break."""

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = BasicModule()
        self.scale = torch.randn(1, 10)

    def call_and_scale(self, mod, x):
        x = mod(x)
        x = x * self.scale
        return unsupported(x, x)

    def forward(self, x):
        x1 = self.call_and_scale(self.layer1, x)
        return x + x1


class UnsupportedModule(torch.nn.Module):
    """forward ends in `unsupported(...)`, forcing a graph break."""

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = BasicModule()
        self.scale = torch.randn(1, 10)

    def forward(self, x):
        x = self.layer1(x) * self.scale
        return unsupported(x, x)


class UnsupportedModuleCall(torch.nn.Module):
    """Wraps UnsupportedModule so the graph break happens in a submodule."""

    def __init__(self) -> None:
        super().__init__()
        self.mod = UnsupportedModule()

    def forward(self, x):
        return 1 + self.mod(x * 1.5)


class ModuleWithStaticForward(torch.nn.Module):
    """forward is a @staticmethod (no self argument)."""

    @staticmethod
    def forward(x):
        return x * torch.sigmoid(x)


class ModuleCallModuleWithStaticForward(torch.nn.Module):
    """Calls a submodule whose forward is a @staticmethod."""

    def __init__(self) -> None:
        super().__init__()
        self.mod = ModuleWithStaticForward()

    def forward(self, x):
        return self.mod(x)
class ModuleStaticMethodCall(torch.nn.Module):
    """Calls a @staticmethod helper from forward."""

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = BasicModule()
        self.layer2 = BasicModule()
        self.scale = torch.randn(1, 10)

    @staticmethod
    def call_and_scale(scale, mod, x):
        x = mod(x)
        return x * scale

    def forward(self, x):
        x1 = self.call_and_scale(self.scale, self.layer1, x)
        x2 = self.call_and_scale(self.scale, self.layer2, x)
        return x1 + x2


class ModuleClassMethodCall(torch.nn.Module):
    """Calls a @classmethod helper from forward."""

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = BasicModule()
        self.layer2 = BasicModule()
        self.scale = torch.randn(1, 10)

    @classmethod
    def call_and_scale(cls, scale, mod, x):
        x = mod(x)
        return x * scale

    def forward(self, x):
        x1 = self.call_and_scale(self.scale, self.layer1, x)
        x2 = self.call_and_scale(self.scale, self.layer2, x)
        return x1 + x2


class ModuleProperty(torch.nn.Module):
    """Reads a tensor attribute through a @property alias."""

    def __init__(self) -> None:
        super().__init__()
        self.scale = torch.randn(1, 10)

    @property
    def scale_alias(self):
        return self.scale

    def forward(self, x):
        return x * self.scale_alias


class NestedModuleList(torch.nn.Module):
    """A ModuleList of ModuleLists, unpacked pairwise in forward."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = torch.nn.ModuleList([])
        for _ in range(3):
            self.layers.append(
                torch.nn.ModuleList(
                    [
                        torch.nn.Linear(10, 10),
                        torch.nn.ReLU(),
                    ]
                )
            )

    def forward(self, x):
        for layer, act in self.layers:
            x = act(layer(x))
        return x


class ConstLoop(torch.nn.Module):
    """Loop bound in forward comes from a constant int attribute."""

    def __init__(self) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.count = 3

    def forward(self, x):
        for i in range(self.count):
            x = torch.sigmoid(self.linear1(x))
        return x


class ViaModuleCall(torch.nn.Module):
    """Calls a function from the sibling test_functions module inside forward."""

    def __init__(self) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)

    def forward(self, x):
        return test_functions.constant3(torch.sigmoid(self.linear1(x)), x)


class IsNoneLayer(torch.nn.Module):
    """One submodule attribute is None; forward branches on `is not None`."""

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = torch.nn.Linear(10, 10)
        self.layer2 = None
        self.train(True)

    def forward(self, x):
        if self.layer1 is not None:
            x = self.layer1(x)
        if self.layer2 is not None:
            x = self.layer2(x)
        return x


class LayerList(torch.nn.Module):
    """Submodules held in a plain Python list (not an nn.ModuleList)."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = [
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 10),
        ]

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class ModuleList(torch.nn.Module):
    """Iterates an nn.ModuleList in several different styles (index, direct,
    zip with tensors, zip with constants, enumerate, reversed slice)."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [
                torch.nn.Linear(10, 10),
                torch.nn.ReLU(),
                torch.nn.Linear(10, 10),
                torch.nn.ReLU(),
            ]
        )

    def forward(self, x):
        for i in range(len(self.layers)):
            x = self.layers[i](x)

        for layer in self.layers:
            x = layer(x)

        for layer, val in zip(self.layers, (x, x, x, x)):
            x = layer(x) + val

        for layer, val in zip(self.layers, (1, 2, 3, 4)):
            x = layer(x) + val

        for idx, layer in enumerate(self.layers):
            x = layer(x) * idx

        for idx, layer in enumerate(self.layers[::-1]):
            x = layer(x) * idx

        return x


class CustomGetItemModuleList(torch.nn.Module):
    """Indexes itself via user-defined __getitem__/__len__ over a ModuleList."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [
                torch.nn.Linear(10, 10),
                torch.nn.ReLU(),
                torch.nn.Linear(10, 10),
                torch.nn.ReLU(),
            ]
        )

    def __getitem__(self, idx: int):
        return self.layers[idx]

    def __len__(self) -> int:
        return len(self.layers)

    def forward(self, x):
        for i in range(len(self)):
            x = self[i](x)

        return x


class ModuleDict(torch.nn.Module):
    """Looks up a submodule in an nn.ModuleDict by string key."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = torch.nn.ModuleDict(
            {
                "0": torch.nn.Linear(10, 10),
            }
        )

    def forward(self, x):
        # TODO(future PR): handle more logic
        x = self.layers["0"](x)
        return x


class ParameterDict(torch.nn.Module):
    """Looks up a parameter in an nn.ParameterDict by string key."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = torch.nn.ParameterDict(
            {
                "0": torch.nn.Parameter(torch.randn(10, 10)),
            }
        )

    def forward(self, x):
        x = self.layers["0"].mm(x)
        return x


class CustomGetItemParameterDict(torch.nn.Module):
    """Indexes itself via user-defined __getitem__ over an nn.ParameterDict."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = torch.nn.ParameterDict(
            {
                "0": torch.nn.Parameter(torch.randn(10, 10)),
            }
        )

    def __getitem__(self, key: str) -> torch.nn.Module:
        return self.layers[key]

    def forward(self, x):
        x = self["0"].mm(x)
        return x


class CustomGetItemModuleDict(torch.nn.Module):
    """Indexes itself via user-defined __getitem__ over an nn.ModuleDict."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = torch.nn.ModuleDict(
            {
                "0": torch.nn.Linear(10, 10),
            }
        )

    def __getitem__(self, key: str) -> torch.nn.Module:
        return self.layers[key]

    def forward(self, x):
        x = self["0"](x)
        return x


class TensorList(torch.nn.Module):
    """Plain tuple of tensors as an attribute, iterated in forward."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = (
            torch.randn((1, 10)),
            torch.randn((10, 1)),
            torch.randn((1, 10)),
            torch.randn((10, 1)),
        )

    def forward(self, x):
        for layer in self.layers:
            x = x * layer
        return x


class Children(torch.nn.Module):
    """Iterates submodules via self.children()."""

    def __init__(self) -> None:
        super().__init__()
        self.l1 = torch.nn.Linear(10, 10)
        self.l2 = torch.nn.ReLU()
        self.l3 = torch.nn.Linear(10, 10)
        self.l4 = torch.nn.ReLU()

    def forward(self, x):
        for block in self.children():
            x = block(x)
        return x


class NamedChildren(torch.nn.Module):
    """Iterates submodules via self.named_children()."""

    def __init__(self) -> None:
        super().__init__()
        self.l1 = torch.nn.Linear(10, 10)
        self.l2 = torch.nn.ReLU()
        self.l3 = torch.nn.Linear(10, 10)
        self.l4 = torch.nn.ReLU()

    def forward(self, x):
        for _, block in self.named_children():
            x = block(x)
        return x


class IntArg(torch.nn.Module):
    """forward takes an extra int argument with a default value."""

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = torch.nn.Linear(10, 10)

    def forward(self, x, offset=1):
        x = F.relu(self.layer1(x)) + offset
        return x


class Seq(torch.nn.Module):
    """Wraps an nn.Sequential."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
        )

    def forward(self, x):
        return self.layers(x)


class Cfg:
    """Plain (non-Module) config object read from a module's forward."""

    def __init__(self) -> None:
        self.val = 0.5
        self.count = 3
Coastguard Workerclass CfgModule(torch.nn.Module): 503*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 504*da0073e9SAndroid Build Coastguard Worker super().__init__() 505*da0073e9SAndroid Build Coastguard Worker self.cfg = Cfg() 506*da0073e9SAndroid Build Coastguard Worker self.layer = torch.nn.Linear(10, 10) 507*da0073e9SAndroid Build Coastguard Worker 508*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 509*da0073e9SAndroid Build Coastguard Worker for i in range(self.cfg.count): 510*da0073e9SAndroid Build Coastguard Worker x = self.layer(x + self.cfg.val) 511*da0073e9SAndroid Build Coastguard Worker return x 512*da0073e9SAndroid Build Coastguard Worker 513*da0073e9SAndroid Build Coastguard Worker 514*da0073e9SAndroid Build Coastguard Workerclass StringMember(torch.nn.Module): 515*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 516*da0073e9SAndroid Build Coastguard Worker super().__init__() 517*da0073e9SAndroid Build Coastguard Worker self.linear1 = torch.nn.Linear(10, 10) 518*da0073e9SAndroid Build Coastguard Worker self.mode = "some_string" 519*da0073e9SAndroid Build Coastguard Worker 520*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 521*da0073e9SAndroid Build Coastguard Worker if self.mode == "some_string": 522*da0073e9SAndroid Build Coastguard Worker return F.relu(self.linear1(x)) 523*da0073e9SAndroid Build Coastguard Worker 524*da0073e9SAndroid Build Coastguard Worker 525*da0073e9SAndroid Build Coastguard Workerclass _Block(torch.nn.Module): 526*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 527*da0073e9SAndroid Build Coastguard Worker return 1.5 * torch.cat(x, 1) 528*da0073e9SAndroid Build Coastguard Worker 529*da0073e9SAndroid Build Coastguard Worker 530*da0073e9SAndroid Build Coastguard Workerclass _DenseBlock(torch.nn.ModuleDict): 531*da0073e9SAndroid Build Coastguard Worker _version = 2 532*da0073e9SAndroid Build Coastguard Worker 533*da0073e9SAndroid Build 
Coastguard Worker def __init__( 534*da0073e9SAndroid Build Coastguard Worker self, 535*da0073e9SAndroid Build Coastguard Worker num_layers: int = 3, 536*da0073e9SAndroid Build Coastguard Worker ) -> None: 537*da0073e9SAndroid Build Coastguard Worker super().__init__() 538*da0073e9SAndroid Build Coastguard Worker for i in range(num_layers): 539*da0073e9SAndroid Build Coastguard Worker self.add_module("denselayer%d" % (i + 1), _Block()) 540*da0073e9SAndroid Build Coastguard Worker 541*da0073e9SAndroid Build Coastguard Worker def forward(self, init_features): 542*da0073e9SAndroid Build Coastguard Worker features = [init_features] 543*da0073e9SAndroid Build Coastguard Worker for layer in self.values(): 544*da0073e9SAndroid Build Coastguard Worker new_features = layer(features) 545*da0073e9SAndroid Build Coastguard Worker features.append(new_features) 546*da0073e9SAndroid Build Coastguard Worker return torch.cat(features, 1) 547*da0073e9SAndroid Build Coastguard Worker 548*da0073e9SAndroid Build Coastguard Worker 549*da0073e9SAndroid Build Coastguard Workerclass DenseNetBlocks(torch.nn.Module): 550*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 551*da0073e9SAndroid Build Coastguard Worker super().__init__() 552*da0073e9SAndroid Build Coastguard Worker self.layers = _DenseBlock() 553*da0073e9SAndroid Build Coastguard Worker 554*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 555*da0073e9SAndroid Build Coastguard Worker return self.layers(x) 556*da0073e9SAndroid Build Coastguard Worker 557*da0073e9SAndroid Build Coastguard Worker 558*da0073e9SAndroid Build Coastguard Workerclass MaterializedModule(torch.nn.Module): 559*da0073e9SAndroid Build Coastguard Worker """Once the below lazy module is initialized with its first input, 560*da0073e9SAndroid Build Coastguard Worker it is transformed into this module.""" 561*da0073e9SAndroid Build Coastguard Worker 562*da0073e9SAndroid Build Coastguard Worker param: Parameter 563*da0073e9SAndroid 
Build Coastguard Worker 564*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 565*da0073e9SAndroid Build Coastguard Worker super().__init__() 566*da0073e9SAndroid Build Coastguard Worker self.register_parameter("param", None) 567*da0073e9SAndroid Build Coastguard Worker 568*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 569*da0073e9SAndroid Build Coastguard Worker return x 570*da0073e9SAndroid Build Coastguard Worker 571*da0073e9SAndroid Build Coastguard Worker 572*da0073e9SAndroid Build Coastguard Workerclass LazyModule(LazyModuleMixin, MaterializedModule): 573*da0073e9SAndroid Build Coastguard Worker param: UninitializedParameter 574*da0073e9SAndroid Build Coastguard Worker cls_to_become = MaterializedModule 575*da0073e9SAndroid Build Coastguard Worker 576*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 577*da0073e9SAndroid Build Coastguard Worker super().__init__() 578*da0073e9SAndroid Build Coastguard Worker self.param = UninitializedParameter() 579*da0073e9SAndroid Build Coastguard Worker 580*da0073e9SAndroid Build Coastguard Worker def initialize_parameters(self, x): 581*da0073e9SAndroid Build Coastguard Worker # force graph break to ensure this was not inlined 582*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 583*da0073e9SAndroid Build Coastguard Worker self.param.materialize(x.shape) 584*da0073e9SAndroid Build Coastguard Worker 585*da0073e9SAndroid Build Coastguard Worker 586*da0073e9SAndroid Build Coastguard Workerclass LazyMLP(torch.nn.Module): 587*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 588*da0073e9SAndroid Build Coastguard Worker super().__init__() 589*da0073e9SAndroid Build Coastguard Worker self.fc1 = torch.nn.LazyLinear(10) 590*da0073e9SAndroid Build Coastguard Worker self.relu1 = torch.nn.ReLU() 591*da0073e9SAndroid Build Coastguard Worker self.fc2 = torch.nn.LazyLinear(1) 592*da0073e9SAndroid Build Coastguard Worker self.relu2 = 
torch.nn.ReLU() 593*da0073e9SAndroid Build Coastguard Worker 594*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 595*da0073e9SAndroid Build Coastguard Worker x = self.relu1(self.fc1(input)) 596*da0073e9SAndroid Build Coastguard Worker y = self.relu2(self.fc2(x)) 597*da0073e9SAndroid Build Coastguard Worker return y 598*da0073e9SAndroid Build Coastguard Worker 599*da0073e9SAndroid Build Coastguard Worker 600*da0073e9SAndroid Build Coastguard Workerclass MyInput(NamedTuple): 601*da0073e9SAndroid Build Coastguard Worker x: Dict[str, Dict[str, torch.Tensor]] 602*da0073e9SAndroid Build Coastguard Worker y: torch.Tensor 603*da0073e9SAndroid Build Coastguard Worker 604*da0073e9SAndroid Build Coastguard Worker 605*da0073e9SAndroid Build Coastguard Workerclass LazyLayerWithNamedTupleInput(LazyModuleMixin, torch.nn.Module): 606*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 607*da0073e9SAndroid Build Coastguard Worker super().__init__() 608*da0073e9SAndroid Build Coastguard Worker 609*da0073e9SAndroid Build Coastguard Worker def initialize_parameters(self, input): 610*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 611*da0073e9SAndroid Build Coastguard Worker self._param = torch.nn.Parameter( 612*da0073e9SAndroid Build Coastguard Worker torch.empty(input.x["a"][0].shape).fill_(0.5) 613*da0073e9SAndroid Build Coastguard Worker ) 614*da0073e9SAndroid Build Coastguard Worker 615*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 616*da0073e9SAndroid Build Coastguard Worker input = input.x["a"] 617*da0073e9SAndroid Build Coastguard Worker x = 0 618*da0073e9SAndroid Build Coastguard Worker for i in range(len(input)): 619*da0073e9SAndroid Build Coastguard Worker x = x + input[i] 620*da0073e9SAndroid Build Coastguard Worker return x 621*da0073e9SAndroid Build Coastguard Worker 622*da0073e9SAndroid Build Coastguard Worker 623*da0073e9SAndroid Build Coastguard Workerclass 
LazyModuleWithNamedTupleInput(torch.nn.Module): 624*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 625*da0073e9SAndroid Build Coastguard Worker super().__init__() 626*da0073e9SAndroid Build Coastguard Worker self.layer = LazyLayerWithNamedTupleInput() 627*da0073e9SAndroid Build Coastguard Worker 628*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 629*da0073e9SAndroid Build Coastguard Worker return self.layer(input) 630*da0073e9SAndroid Build Coastguard Worker 631*da0073e9SAndroid Build Coastguard Worker 632*da0073e9SAndroid Build Coastguard Workerclass LazyLayerWithListInput(LazyModuleMixin, torch.nn.Module): 633*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 634*da0073e9SAndroid Build Coastguard Worker super().__init__() 635*da0073e9SAndroid Build Coastguard Worker 636*da0073e9SAndroid Build Coastguard Worker def initialize_parameters(self, input): 637*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 638*da0073e9SAndroid Build Coastguard Worker self._param = torch.nn.Parameter(torch.empty(input[0].shape).fill_(0.5)) 639*da0073e9SAndroid Build Coastguard Worker 640*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 641*da0073e9SAndroid Build Coastguard Worker x = 0 642*da0073e9SAndroid Build Coastguard Worker for i in range(len(input)): 643*da0073e9SAndroid Build Coastguard Worker x = x + input[i] 644*da0073e9SAndroid Build Coastguard Worker return x 645*da0073e9SAndroid Build Coastguard Worker 646*da0073e9SAndroid Build Coastguard Worker 647*da0073e9SAndroid Build Coastguard Workerclass LazyModuleWithListInput(torch.nn.Module): 648*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 649*da0073e9SAndroid Build Coastguard Worker super().__init__() 650*da0073e9SAndroid Build Coastguard Worker self.layer = LazyLayerWithListInput() 651*da0073e9SAndroid Build Coastguard Worker 652*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 
653*da0073e9SAndroid Build Coastguard Worker return self.layer(input[:-1]) 654*da0073e9SAndroid Build Coastguard Worker 655*da0073e9SAndroid Build Coastguard Worker 656*da0073e9SAndroid Build Coastguard Workerclass LazyModuleWithLazySubmodule(LazyModuleMixin, torch.nn.Module): 657*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 658*da0073e9SAndroid Build Coastguard Worker super().__init__() 659*da0073e9SAndroid Build Coastguard Worker 660*da0073e9SAndroid Build Coastguard Worker def initialize_parameters(self, input): 661*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 662*da0073e9SAndroid Build Coastguard Worker self.layer = LazyLayerWithListInput() 663*da0073e9SAndroid Build Coastguard Worker 664*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 665*da0073e9SAndroid Build Coastguard Worker return self.layer(x) 666*da0073e9SAndroid Build Coastguard Worker 667*da0073e9SAndroid Build Coastguard Worker 668*da0073e9SAndroid Build Coastguard Workerclass LazyLayerWithInputs(LazyModuleMixin, torch.nn.Module): 669*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 670*da0073e9SAndroid Build Coastguard Worker super().__init__() 671*da0073e9SAndroid Build Coastguard Worker 672*da0073e9SAndroid Build Coastguard Worker def initialize_parameters(self, x, y): 673*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 674*da0073e9SAndroid Build Coastguard Worker self._param_x = torch.nn.Parameter(torch.empty(x[0].shape).fill_(0.5)) 675*da0073e9SAndroid Build Coastguard Worker self._param_y = torch.nn.Parameter(torch.empty(y[0].shape).fill_(0.5)) 676*da0073e9SAndroid Build Coastguard Worker 677*da0073e9SAndroid Build Coastguard Worker def forward(self, x, y): 678*da0073e9SAndroid Build Coastguard Worker res_x = 0 679*da0073e9SAndroid Build Coastguard Worker for i in range(len(x)): 680*da0073e9SAndroid Build Coastguard Worker res_x = res_x + x[i] 681*da0073e9SAndroid Build Coastguard Worker res_y = 0 
682*da0073e9SAndroid Build Coastguard Worker for i in range(len(y)): 683*da0073e9SAndroid Build Coastguard Worker res_y = res_y + y[i] 684*da0073e9SAndroid Build Coastguard Worker return res_x + res_y 685*da0073e9SAndroid Build Coastguard Worker 686*da0073e9SAndroid Build Coastguard Worker 687*da0073e9SAndroid Build Coastguard Workerclass LazyModuleKwArgs(LazyModuleMixin, torch.nn.Module): 688*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 689*da0073e9SAndroid Build Coastguard Worker super().__init__() 690*da0073e9SAndroid Build Coastguard Worker 691*da0073e9SAndroid Build Coastguard Worker def initialize_parameters(self, *args, **kwargs): 692*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 693*da0073e9SAndroid Build Coastguard Worker self.layer = LazyLayerWithInputs() 694*da0073e9SAndroid Build Coastguard Worker 695*da0073e9SAndroid Build Coastguard Worker def forward(self, x, y): 696*da0073e9SAndroid Build Coastguard Worker return self.layer(x, y=y) 697*da0073e9SAndroid Build Coastguard Worker 698*da0073e9SAndroid Build Coastguard Worker 699*da0073e9SAndroid Build Coastguard Workerclass LazyParentModule(LazyModuleMixin, torch.nn.Module): 700*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 701*da0073e9SAndroid Build Coastguard Worker super().__init__() 702*da0073e9SAndroid Build Coastguard Worker 703*da0073e9SAndroid Build Coastguard Worker def impl(self, x): 704*da0073e9SAndroid Build Coastguard Worker return x.cos() + self._val 705*da0073e9SAndroid Build Coastguard Worker 706*da0073e9SAndroid Build Coastguard Worker 707*da0073e9SAndroid Build Coastguard Workerclass LazyChildModuleNoClsToBecome(LazyParentModule): 708*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 709*da0073e9SAndroid Build Coastguard Worker super().__init__() 710*da0073e9SAndroid Build Coastguard Worker 711*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 712*da0073e9SAndroid Build Coastguard Worker 
return super().impl(x.sin()) 713*da0073e9SAndroid Build Coastguard Worker 714*da0073e9SAndroid Build Coastguard Worker def initialize_parameters(self, input): 715*da0073e9SAndroid Build Coastguard Worker self._val = torch.nn.Parameter(torch.ones(2, 2)) 716*da0073e9SAndroid Build Coastguard Worker 717*da0073e9SAndroid Build Coastguard Worker 718*da0073e9SAndroid Build Coastguard Workerdef requires_grad1(module: torch.nn.Module, recurse: bool = False) -> bool: 719*da0073e9SAndroid Build Coastguard Worker requires_grad = any(p.requires_grad for p in module.parameters(recurse)) 720*da0073e9SAndroid Build Coastguard Worker return requires_grad 721*da0073e9SAndroid Build Coastguard Worker 722*da0073e9SAndroid Build Coastguard Worker 723*da0073e9SAndroid Build Coastguard Workerdef requires_grad2(module: torch.nn.Module, recurse: bool = False) -> bool: 724*da0073e9SAndroid Build Coastguard Worker requires_grad = any(p.requires_grad for p in module.parameters(recurse)) 725*da0073e9SAndroid Build Coastguard Worker return requires_grad 726*da0073e9SAndroid Build Coastguard Worker 727*da0073e9SAndroid Build Coastguard Worker 728*da0073e9SAndroid Build Coastguard Workerclass ParametersModule1(torch.nn.Module): 729*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 730*da0073e9SAndroid Build Coastguard Worker super().__init__() 731*da0073e9SAndroid Build Coastguard Worker self.linear1 = torch.nn.Linear(10, 10) 732*da0073e9SAndroid Build Coastguard Worker self.scale = torch.nn.Parameter(torch.randn(1, 10)) 733*da0073e9SAndroid Build Coastguard Worker 734*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 735*da0073e9SAndroid Build Coastguard Worker if not requires_grad1(self): 736*da0073e9SAndroid Build Coastguard Worker return F.relu(self.linear1(x)) * self.scale 737*da0073e9SAndroid Build Coastguard Worker else: 738*da0073e9SAndroid Build Coastguard Worker return x + 1 739*da0073e9SAndroid Build Coastguard Worker 740*da0073e9SAndroid Build 
Coastguard Worker 741*da0073e9SAndroid Build Coastguard Workerclass ParametersModule2(ParametersModule1): 742*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 743*da0073e9SAndroid Build Coastguard Worker if not requires_grad2(self): 744*da0073e9SAndroid Build Coastguard Worker return F.relu(self.linear1(x)) * self.scale 745*da0073e9SAndroid Build Coastguard Worker else: 746*da0073e9SAndroid Build Coastguard Worker return x + 1 747*da0073e9SAndroid Build Coastguard Worker 748*da0073e9SAndroid Build Coastguard Worker 749*da0073e9SAndroid Build Coastguard Workerclass ParametersModule3(ParametersModule1): 750*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 751*da0073e9SAndroid Build Coastguard Worker ones = torch.ones(10, dtype=next(self.parameters()).dtype) 752*da0073e9SAndroid Build Coastguard Worker return F.relu(self.linear1(x)) * self.scale + ones 753*da0073e9SAndroid Build Coastguard Worker 754*da0073e9SAndroid Build Coastguard Worker 755*da0073e9SAndroid Build Coastguard Workerclass ParametersModule4(ParametersModule1): 756*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 757*da0073e9SAndroid Build Coastguard Worker ones = torch.ones(10, dtype=next(self.parameters(recurse=False)).dtype) 758*da0073e9SAndroid Build Coastguard Worker return F.relu(self.linear1(x)) * self.scale + ones 759*da0073e9SAndroid Build Coastguard Worker 760*da0073e9SAndroid Build Coastguard Worker 761*da0073e9SAndroid Build Coastguard Workerclass ParametersModule5(torch.nn.Module): 762*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 763*da0073e9SAndroid Build Coastguard Worker super().__init__() 764*da0073e9SAndroid Build Coastguard Worker self.linear1 = torch.nn.Linear(10, 10) 765*da0073e9SAndroid Build Coastguard Worker self.scale = torch.nn.Parameter(torch.randn(10, 10)) 766*da0073e9SAndroid Build Coastguard Worker self.scale_dup = self.scale 767*da0073e9SAndroid Build Coastguard Worker 768*da0073e9SAndroid Build Coastguard 
Worker def forward(self, x): 769*da0073e9SAndroid Build Coastguard Worker counter = 0 770*da0073e9SAndroid Build Coastguard Worker for param in self.parameters(): 771*da0073e9SAndroid Build Coastguard Worker counter += 1 772*da0073e9SAndroid Build Coastguard Worker 773*da0073e9SAndroid Build Coastguard Worker return x * self.scale * counter 774*da0073e9SAndroid Build Coastguard Worker 775*da0073e9SAndroid Build Coastguard Worker 776*da0073e9SAndroid Build Coastguard Workerclass SuperModule(BasicModule): 777*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 778*da0073e9SAndroid Build Coastguard Worker x = super().forward(x) 779*da0073e9SAndroid Build Coastguard Worker return x + 10.0 780*da0073e9SAndroid Build Coastguard Worker 781*da0073e9SAndroid Build Coastguard Worker 782*da0073e9SAndroid Build Coastguard Workerclass SuperModule2(BasicModule): 783*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 784*da0073e9SAndroid Build Coastguard Worker return BasicModule.forward(self, x) 785*da0073e9SAndroid Build Coastguard Worker 786*da0073e9SAndroid Build Coastguard Worker 787*da0073e9SAndroid Build Coastguard Workerclass ComplicatedSuperParent(torch.nn.Module): 788*da0073e9SAndroid Build Coastguard Worker @classmethod 789*da0073e9SAndroid Build Coastguard Worker def custom_add(cls, x): 790*da0073e9SAndroid Build Coastguard Worker x = x + x 791*da0073e9SAndroid Build Coastguard Worker return x 792*da0073e9SAndroid Build Coastguard Worker 793*da0073e9SAndroid Build Coastguard Worker 794*da0073e9SAndroid Build Coastguard Workerclass SuperChildCallsClassMethod(ComplicatedSuperParent): 795*da0073e9SAndroid Build Coastguard Worker @classmethod 796*da0073e9SAndroid Build Coastguard Worker def child_func(cls, x): 797*da0073e9SAndroid Build Coastguard Worker x = super().custom_add(x) 798*da0073e9SAndroid Build Coastguard Worker return x 799*da0073e9SAndroid Build Coastguard Worker 800*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 
801*da0073e9SAndroid Build Coastguard Worker x = self.child_func(x) 802*da0073e9SAndroid Build Coastguard Worker return x 803*da0073e9SAndroid Build Coastguard Worker 804*da0073e9SAndroid Build Coastguard Worker 805*da0073e9SAndroid Build Coastguard Workerclass HasAttrModule(torch.nn.Module): 806*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 807*da0073e9SAndroid Build Coastguard Worker super().__init__() 808*da0073e9SAndroid Build Coastguard Worker self.scale = torch.nn.Parameter(torch.randn(1, 10)) 809*da0073e9SAndroid Build Coastguard Worker 810*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 811*da0073e9SAndroid Build Coastguard Worker x = F.relu(x) 812*da0073e9SAndroid Build Coastguard Worker if hasattr(self, "scale"): 813*da0073e9SAndroid Build Coastguard Worker x *= self.scale 814*da0073e9SAndroid Build Coastguard Worker if hasattr(self, "scale2"): 815*da0073e9SAndroid Build Coastguard Worker x *= self.scale2 816*da0073e9SAndroid Build Coastguard Worker return x 817*da0073e9SAndroid Build Coastguard Worker 818*da0073e9SAndroid Build Coastguard Worker 819*da0073e9SAndroid Build Coastguard Workerclass EnumValues(torch.nn.ModuleDict): 820*da0073e9SAndroid Build Coastguard Worker def __init__( 821*da0073e9SAndroid Build Coastguard Worker self, 822*da0073e9SAndroid Build Coastguard Worker num_layers: int = 3, 823*da0073e9SAndroid Build Coastguard Worker ) -> None: 824*da0073e9SAndroid Build Coastguard Worker super().__init__() 825*da0073e9SAndroid Build Coastguard Worker for i in range(num_layers): 826*da0073e9SAndroid Build Coastguard Worker self.add_module("denselayer%d" % (i + 1), _Block()) 827*da0073e9SAndroid Build Coastguard Worker 828*da0073e9SAndroid Build Coastguard Worker def forward(self, init_features): 829*da0073e9SAndroid Build Coastguard Worker features = [init_features] 830*da0073e9SAndroid Build Coastguard Worker for idx, layer in enumerate(self.values()): 831*da0073e9SAndroid Build Coastguard Worker 
new_features = layer(features) 832*da0073e9SAndroid Build Coastguard Worker features.append(new_features) 833*da0073e9SAndroid Build Coastguard Worker return torch.cat(features, 1) 834*da0073e9SAndroid Build Coastguard Worker 835*da0073e9SAndroid Build Coastguard Worker 836*da0073e9SAndroid Build Coastguard Workerclass AccessByKeys(torch.nn.ModuleDict): 837*da0073e9SAndroid Build Coastguard Worker def __init__( 838*da0073e9SAndroid Build Coastguard Worker self, 839*da0073e9SAndroid Build Coastguard Worker num_layers: int = 3, 840*da0073e9SAndroid Build Coastguard Worker ) -> None: 841*da0073e9SAndroid Build Coastguard Worker super().__init__() 842*da0073e9SAndroid Build Coastguard Worker for i in range(num_layers): 843*da0073e9SAndroid Build Coastguard Worker self.add_module("denselayer%d" % (i + 1), _Block()) 844*da0073e9SAndroid Build Coastguard Worker 845*da0073e9SAndroid Build Coastguard Worker def forward(self, init_features): 846*da0073e9SAndroid Build Coastguard Worker features = [init_features] 847*da0073e9SAndroid Build Coastguard Worker for k in self.keys(): 848*da0073e9SAndroid Build Coastguard Worker new_features = self[k](features) 849*da0073e9SAndroid Build Coastguard Worker features.append(new_features) 850*da0073e9SAndroid Build Coastguard Worker return torch.cat(features, 1) 851*da0073e9SAndroid Build Coastguard Worker 852*da0073e9SAndroid Build Coastguard Worker 853*da0073e9SAndroid Build Coastguard Workerclass CallForwardDirectly(torch.nn.Module): 854*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 855*da0073e9SAndroid Build Coastguard Worker super().__init__() 856*da0073e9SAndroid Build Coastguard Worker self.layer1 = BasicModule() 857*da0073e9SAndroid Build Coastguard Worker self.layer2 = torch.nn.Linear(10, 10) 858*da0073e9SAndroid Build Coastguard Worker 859*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 860*da0073e9SAndroid Build Coastguard Worker x = self.layer1.forward(x) 861*da0073e9SAndroid Build 
Coastguard Worker x = self.layer2.forward(x) 862*da0073e9SAndroid Build Coastguard Worker return x 863*da0073e9SAndroid Build Coastguard Worker 864*da0073e9SAndroid Build Coastguard Worker 865*da0073e9SAndroid Build Coastguard Workerclass ConvCallForwardDirectly(torch.nn.Module): 866*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 867*da0073e9SAndroid Build Coastguard Worker super().__init__() 868*da0073e9SAndroid Build Coastguard Worker self.layer = torch.nn.Conv2d(3, 64, 3, 1, 1, bias=False) 869*da0073e9SAndroid Build Coastguard Worker 870*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 871*da0073e9SAndroid Build Coastguard Worker return self.layer.forward(x) 872*da0073e9SAndroid Build Coastguard Worker 873*da0073e9SAndroid Build Coastguard Worker 874*da0073e9SAndroid Build Coastguard Workerclass ConvTransposeCallForwardDirectly(torch.nn.Module): 875*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 876*da0073e9SAndroid Build Coastguard Worker super().__init__() 877*da0073e9SAndroid Build Coastguard Worker self.layer = torch.nn.ConvTranspose2d(4, 4, 4) 878*da0073e9SAndroid Build Coastguard Worker 879*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 880*da0073e9SAndroid Build Coastguard Worker return self.layer.forward(x) 881*da0073e9SAndroid Build Coastguard Worker 882*da0073e9SAndroid Build Coastguard Worker 883*da0073e9SAndroid Build Coastguard Workerclass ConvCallSuperForwardDirectly(torch.nn.Conv1d): 884*da0073e9SAndroid Build Coastguard Worker def __init__(self, in_channels, out_channels, kernel_size, **kwargs): 885*da0073e9SAndroid Build Coastguard Worker super().__init__( 886*da0073e9SAndroid Build Coastguard Worker in_channels=in_channels, 887*da0073e9SAndroid Build Coastguard Worker out_channels=out_channels, 888*da0073e9SAndroid Build Coastguard Worker kernel_size=kernel_size, 889*da0073e9SAndroid Build Coastguard Worker **kwargs, 890*da0073e9SAndroid Build Coastguard Worker ) 
891*da0073e9SAndroid Build Coastguard Worker 892*da0073e9SAndroid Build Coastguard Worker def forward(self, inputs, mask=None): 893*da0073e9SAndroid Build Coastguard Worker outputs = super().forward(inputs) 894*da0073e9SAndroid Build Coastguard Worker return outputs 895*da0073e9SAndroid Build Coastguard Worker 896*da0073e9SAndroid Build Coastguard Worker 897*da0073e9SAndroid Build Coastguard Workerclass ConvTransposeCallSuperForwardDirectly(torch.nn.ConvTranspose2d): 898*da0073e9SAndroid Build Coastguard Worker def __init__(self, in_channels, out_channels, kernel_size, **kwargs): 899*da0073e9SAndroid Build Coastguard Worker super().__init__( 900*da0073e9SAndroid Build Coastguard Worker in_channels=in_channels, 901*da0073e9SAndroid Build Coastguard Worker out_channels=out_channels, 902*da0073e9SAndroid Build Coastguard Worker kernel_size=kernel_size, 903*da0073e9SAndroid Build Coastguard Worker **kwargs, 904*da0073e9SAndroid Build Coastguard Worker ) 905*da0073e9SAndroid Build Coastguard Worker 906*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 907*da0073e9SAndroid Build Coastguard Worker if x.numel() > 0: 908*da0073e9SAndroid Build Coastguard Worker return super().forward(x) 909*da0073e9SAndroid Build Coastguard Worker output_shape = [ 910*da0073e9SAndroid Build Coastguard Worker ((i - 1) * d - 2 * p + (di * (k - 1) + 1) + op) 911*da0073e9SAndroid Build Coastguard Worker for i, p, di, k, d, op in zip( 912*da0073e9SAndroid Build Coastguard Worker x.shape[-2:], 913*da0073e9SAndroid Build Coastguard Worker self.padding, 914*da0073e9SAndroid Build Coastguard Worker self.dilation, 915*da0073e9SAndroid Build Coastguard Worker self.kernel_size, 916*da0073e9SAndroid Build Coastguard Worker self.stride, 917*da0073e9SAndroid Build Coastguard Worker self.output_padding, 918*da0073e9SAndroid Build Coastguard Worker ) 919*da0073e9SAndroid Build Coastguard Worker ] 920*da0073e9SAndroid Build Coastguard Worker output_shape = [x.shape[0], self.bias.shape[0]] + 
output_shape 921*da0073e9SAndroid Build Coastguard Worker return _NewEmptyTensorOp.apply(x, output_shape) # noqa: F821 922*da0073e9SAndroid Build Coastguard Worker 923*da0073e9SAndroid Build Coastguard Worker 924*da0073e9SAndroid Build Coastguard Workerclass ModuleNameString(torch.nn.Module): 925*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 926*da0073e9SAndroid Build Coastguard Worker super().__init__() 927*da0073e9SAndroid Build Coastguard Worker self.linear1 = torch.nn.Linear(10, 10) 928*da0073e9SAndroid Build Coastguard Worker 929*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 930*da0073e9SAndroid Build Coastguard Worker if self.__class__.__name__ == "ABC": 931*da0073e9SAndroid Build Coastguard Worker return 10 932*da0073e9SAndroid Build Coastguard Worker if self.linear1.__class__.__name__ == "Linear": 933*da0073e9SAndroid Build Coastguard Worker return F.relu(self.linear1(x) + 10) 934*da0073e9SAndroid Build Coastguard Worker return 11 935*da0073e9SAndroid Build Coastguard Worker 936*da0073e9SAndroid Build Coastguard Worker 937*da0073e9SAndroid Build Coastguard Workerclass SelfMutatingModule(torch.nn.Module): 938*da0073e9SAndroid Build Coastguard Worker def __init__(self, layer): 939*da0073e9SAndroid Build Coastguard Worker super().__init__() 940*da0073e9SAndroid Build Coastguard Worker self.layer = layer 941*da0073e9SAndroid Build Coastguard Worker self.counter = 0 942*da0073e9SAndroid Build Coastguard Worker 943*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 944*da0073e9SAndroid Build Coastguard Worker result = self.layer(x) + self.counter 945*da0073e9SAndroid Build Coastguard Worker self.counter += 1 946*da0073e9SAndroid Build Coastguard Worker return F.relu(result) 947*da0073e9SAndroid Build Coastguard Worker 948*da0073e9SAndroid Build Coastguard Worker 949*da0073e9SAndroid Build Coastguard Workerclass ModuleAttributePrecedenceBase(torch.nn.Module): 950*da0073e9SAndroid Build Coastguard Worker def 
linear(self, x, flag=None): 951*da0073e9SAndroid Build Coastguard Worker if flag: 952*da0073e9SAndroid Build Coastguard Worker return x * 2.0 953*da0073e9SAndroid Build Coastguard Worker return x * 3.0 954*da0073e9SAndroid Build Coastguard Worker 955*da0073e9SAndroid Build Coastguard Worker 956*da0073e9SAndroid Build Coastguard Workerclass ModuleAttributePrecedence(ModuleAttributePrecedenceBase): 957*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 958*da0073e9SAndroid Build Coastguard Worker super().__init__() 959*da0073e9SAndroid Build Coastguard Worker self.activation = torch.nn.ReLU() 960*da0073e9SAndroid Build Coastguard Worker self.linear = torch.nn.Linear(10, 10) 961*da0073e9SAndroid Build Coastguard Worker self.initializer = torch.ones([10, 10]) 962*da0073e9SAndroid Build Coastguard Worker self.scale = 0.5 963*da0073e9SAndroid Build Coastguard Worker 964*da0073e9SAndroid Build Coastguard Worker def activation(self, x): 965*da0073e9SAndroid Build Coastguard Worker return x * 1.2 966*da0073e9SAndroid Build Coastguard Worker 967*da0073e9SAndroid Build Coastguard Worker def initializer(self): 968*da0073e9SAndroid Build Coastguard Worker return torch.zeros([10, 10]) 969*da0073e9SAndroid Build Coastguard Worker 970*da0073e9SAndroid Build Coastguard Worker def scale(self): 971*da0073e9SAndroid Build Coastguard Worker return 2.0 972*da0073e9SAndroid Build Coastguard Worker 973*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 974*da0073e9SAndroid Build Coastguard Worker # object attribute takes precedence unless it's a nn.Module 975*da0073e9SAndroid Build Coastguard Worker return self.activation(self.linear(self.initializer + x)) * self.scale 976*da0073e9SAndroid Build Coastguard Worker 977*da0073e9SAndroid Build Coastguard Worker 978*da0073e9SAndroid Build Coastguard Workerclass ModuleForwardHasGraphBreak(torch.nn.Module): 979*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 980*da0073e9SAndroid Build Coastguard 
Worker super().__init__() 981*da0073e9SAndroid Build Coastguard Worker self.layer1 = BasicModule() 982*da0073e9SAndroid Build Coastguard Worker self.layer2 = BasicModule() 983*da0073e9SAndroid Build Coastguard Worker self.layer3 = torch.nn.Sequential(BasicModule(), BasicModule()) 984*da0073e9SAndroid Build Coastguard Worker self.layer4 = torch.nn.ModuleList( 985*da0073e9SAndroid Build Coastguard Worker [ 986*da0073e9SAndroid Build Coastguard Worker torch.nn.Linear(10, 10), 987*da0073e9SAndroid Build Coastguard Worker torch.nn.ReLU(), 988*da0073e9SAndroid Build Coastguard Worker torch.nn.Linear(10, 10), 989*da0073e9SAndroid Build Coastguard Worker torch.nn.ReLU(), 990*da0073e9SAndroid Build Coastguard Worker ] 991*da0073e9SAndroid Build Coastguard Worker ) 992*da0073e9SAndroid Build Coastguard Worker self.layer5 = torch.nn.ModuleDict( 993*da0073e9SAndroid Build Coastguard Worker { 994*da0073e9SAndroid Build Coastguard Worker "0": torch.nn.Linear(10, 10), 995*da0073e9SAndroid Build Coastguard Worker } 996*da0073e9SAndroid Build Coastguard Worker ) 997*da0073e9SAndroid Build Coastguard Worker self.scale = torch.randn(1, 10) 998*da0073e9SAndroid Build Coastguard Worker 999*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1000*da0073e9SAndroid Build Coastguard Worker """ 1001*da0073e9SAndroid Build Coastguard Worker This is used to test if the results of functions like `named_parameters` 1002*da0073e9SAndroid Build Coastguard Worker can be reconstructed correctly after graph break. 
1003*da0073e9SAndroid Build Coastguard Worker 1004*da0073e9SAndroid Build Coastguard Worker https://github.com/pytorch/torchdynamo/issues/1931 1005*da0073e9SAndroid Build Coastguard Worker """ 1006*da0073e9SAndroid Build Coastguard Worker x = self.layer1(x) 1007*da0073e9SAndroid Build Coastguard Worker params1 = dict(self.named_parameters()) 1008*da0073e9SAndroid Build Coastguard Worker params2 = list(self.parameters()) 1009*da0073e9SAndroid Build Coastguard Worker buffers1 = dict(self.named_buffers()) 1010*da0073e9SAndroid Build Coastguard Worker buffers2 = list(self.buffers()) 1011*da0073e9SAndroid Build Coastguard Worker modules1 = dict(self.named_modules()) 1012*da0073e9SAndroid Build Coastguard Worker modules2 = list(self.modules()) 1013*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 1014*da0073e9SAndroid Build Coastguard Worker y = modules2 1015*da0073e9SAndroid Build Coastguard Worker y = modules1 1016*da0073e9SAndroid Build Coastguard Worker y = buffers2 1017*da0073e9SAndroid Build Coastguard Worker y = buffers1 1018*da0073e9SAndroid Build Coastguard Worker y = params2 1019*da0073e9SAndroid Build Coastguard Worker y = params1 1020*da0073e9SAndroid Build Coastguard Worker x = ( 1021*da0073e9SAndroid Build Coastguard Worker self.layer2(x) 1022*da0073e9SAndroid Build Coastguard Worker + y["layer3.1.linear1.weight"] 1023*da0073e9SAndroid Build Coastguard Worker + y["layer4.2.weight"] 1024*da0073e9SAndroid Build Coastguard Worker + y["layer5.0.weight"] 1025*da0073e9SAndroid Build Coastguard Worker ) 1026*da0073e9SAndroid Build Coastguard Worker return x * self.scale 1027*da0073e9SAndroid Build Coastguard Worker 1028*da0073e9SAndroid Build Coastguard Worker 1029*da0073e9SAndroid Build Coastguard Workerclass ModuleGuardNameIsValid(torch.nn.ModuleDict): 1030*da0073e9SAndroid Build Coastguard Worker # Guard names should be valid python identifier as we use eval() to get 1031*da0073e9SAndroid Build Coastguard Worker # corresponding guard value. 
Some guard names come from source(module path) 1032*da0073e9SAndroid Build Coastguard Worker # where special symbols are valid. But they are not valid python identifier, 1033*da0073e9SAndroid Build Coastguard Worker # we should identify these pattern and rewrite them with getattr. 1034*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1035*da0073e9SAndroid Build Coastguard Worker super().__init__() 1036*da0073e9SAndroid Build Coastguard Worker for i in range(2): 1037*da0073e9SAndroid Build Coastguard Worker self.add_module("l@yer-%d" % (i + 1), BasicModule()) 1038*da0073e9SAndroid Build Coastguard Worker 1039*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1040*da0073e9SAndroid Build Coastguard Worker for layer in self.values(): 1041*da0073e9SAndroid Build Coastguard Worker x = layer(x) 1042*da0073e9SAndroid Build Coastguard Worker return x 1043*da0073e9SAndroid Build Coastguard Worker 1044*da0073e9SAndroid Build Coastguard Worker 1045*da0073e9SAndroid Build Coastguard Workerclass SequentialWithDuplicatedModule(torch.nn.Module): 1046*da0073e9SAndroid Build Coastguard Worker # Sequential module(self.layer) contains three duplicated ReLU module. 
1047*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1048*da0073e9SAndroid Build Coastguard Worker super().__init__() 1049*da0073e9SAndroid Build Coastguard Worker self.relu = torch.nn.ReLU() 1050*da0073e9SAndroid Build Coastguard Worker self.layer = torch.nn.Sequential( 1051*da0073e9SAndroid Build Coastguard Worker torch.nn.Linear(10, 20), 1052*da0073e9SAndroid Build Coastguard Worker self.relu, 1053*da0073e9SAndroid Build Coastguard Worker torch.nn.Linear(20, 20), 1054*da0073e9SAndroid Build Coastguard Worker self.relu, 1055*da0073e9SAndroid Build Coastguard Worker torch.nn.Linear(20, 10), 1056*da0073e9SAndroid Build Coastguard Worker self.relu, 1057*da0073e9SAndroid Build Coastguard Worker ) 1058*da0073e9SAndroid Build Coastguard Worker 1059*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1060*da0073e9SAndroid Build Coastguard Worker return self.layer(x) 1061*da0073e9SAndroid Build Coastguard Worker 1062*da0073e9SAndroid Build Coastguard Worker 1063*da0073e9SAndroid Build Coastguard Workerclass SequentialWithDuplicatedModule2(torch.nn.Module): 1064*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1065*da0073e9SAndroid Build Coastguard Worker super().__init__() 1066*da0073e9SAndroid Build Coastguard Worker self.relu = torch.nn.ReLU() 1067*da0073e9SAndroid Build Coastguard Worker self.layer = torch.nn.Sequential( 1068*da0073e9SAndroid Build Coastguard Worker collections.OrderedDict( 1069*da0073e9SAndroid Build Coastguard Worker [ 1070*da0073e9SAndroid Build Coastguard Worker ("linear1", torch.nn.Linear(10, 20)), 1071*da0073e9SAndroid Build Coastguard Worker ("relu1", self.relu), 1072*da0073e9SAndroid Build Coastguard Worker ("linear2", torch.nn.Linear(20, 20)), 1073*da0073e9SAndroid Build Coastguard Worker ("relu2", self.relu), 1074*da0073e9SAndroid Build Coastguard Worker ("linear3", torch.nn.Linear(20, 10)), 1075*da0073e9SAndroid Build Coastguard Worker ("relu3", self.relu), 1076*da0073e9SAndroid Build 
Coastguard Worker ] 1077*da0073e9SAndroid Build Coastguard Worker ) 1078*da0073e9SAndroid Build Coastguard Worker ) 1079*da0073e9SAndroid Build Coastguard Worker 1080*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1081*da0073e9SAndroid Build Coastguard Worker return self.layer(x) 1082*da0073e9SAndroid Build Coastguard Worker 1083*da0073e9SAndroid Build Coastguard Worker 1084*da0073e9SAndroid Build Coastguard Workerclass ModuleComparison(torch.nn.Module): 1085*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1086*da0073e9SAndroid Build Coastguard Worker super().__init__() 1087*da0073e9SAndroid Build Coastguard Worker self.layer0 = torch.nn.Linear(10, 10) 1088*da0073e9SAndroid Build Coastguard Worker self.layer1 = torch.nn.Linear(10, 10) 1089*da0073e9SAndroid Build Coastguard Worker self.layer2 = torch.nn.Linear(10, 10) 1090*da0073e9SAndroid Build Coastguard Worker 1091*da0073e9SAndroid Build Coastguard Worker @property 1092*da0073e9SAndroid Build Coastguard Worker def encoder_layers(self): 1093*da0073e9SAndroid Build Coastguard Worker return [self.layer0, self.layer1, self.layer2] 1094*da0073e9SAndroid Build Coastguard Worker 1095*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1096*da0073e9SAndroid Build Coastguard Worker for layer in self.encoder_layers: 1097*da0073e9SAndroid Build Coastguard Worker output = layer(x) 1098*da0073e9SAndroid Build Coastguard Worker if layer is None or layer == self.layer0: 1099*da0073e9SAndroid Build Coastguard Worker output = F.relu6(output) 1100*da0073e9SAndroid Build Coastguard Worker else: 1101*da0073e9SAndroid Build Coastguard Worker output = F.relu(output) 1102*da0073e9SAndroid Build Coastguard Worker return output 1103*da0073e9SAndroid Build Coastguard Worker 1104*da0073e9SAndroid Build Coastguard Worker 1105*da0073e9SAndroid Build Coastguard Workerclass ModulePatch1(torch.nn.Module): 1106*da0073e9SAndroid Build Coastguard Worker pass 1107*da0073e9SAndroid Build Coastguard 
Worker 1108*da0073e9SAndroid Build Coastguard Worker 1109*da0073e9SAndroid Build Coastguard Workerclass ModulePatch2(torch.nn.Module): 1110*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1111*da0073e9SAndroid Build Coastguard Worker return x - 1 1112*da0073e9SAndroid Build Coastguard Worker 1113*da0073e9SAndroid Build Coastguard Worker 1114*da0073e9SAndroid Build Coastguard Workerclass UnspecNonInlinableModule(torch.nn.Module): 1115*da0073e9SAndroid Build Coastguard Worker torchdynamo_force_dynamic = True # forced to be a UnspecializedNNModule 1116*da0073e9SAndroid Build Coastguard Worker 1117*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1118*da0073e9SAndroid Build Coastguard Worker if x.sum() > 0: 1119*da0073e9SAndroid Build Coastguard Worker return x + 1 1120*da0073e9SAndroid Build Coastguard Worker else: 1121*da0073e9SAndroid Build Coastguard Worker return x - 1 1122*da0073e9SAndroid Build Coastguard Worker 1123*da0073e9SAndroid Build Coastguard Worker 1124*da0073e9SAndroid Build Coastguard Workerclass UnspecNonInlinableToplevelModule(torch.nn.Module): 1125*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1126*da0073e9SAndroid Build Coastguard Worker super().__init__() 1127*da0073e9SAndroid Build Coastguard Worker self.m = UnspecNonInlinableModule() 1128*da0073e9SAndroid Build Coastguard Worker 1129*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1130*da0073e9SAndroid Build Coastguard Worker return self.m(x) 1131*da0073e9SAndroid Build Coastguard Worker 1132*da0073e9SAndroid Build Coastguard Worker 1133*da0073e9SAndroid Build Coastguard Workerdef make_test(fn, expected_ops=None): 1134*da0073e9SAndroid Build Coastguard Worker def test_fn(self): 1135*da0073e9SAndroid Build Coastguard Worker return torch._dynamo.testing.standard_test( 1136*da0073e9SAndroid Build Coastguard Worker self, fn=fn, nargs=1, expected_ops=expected_ops 1137*da0073e9SAndroid Build Coastguard Worker ) 1138*da0073e9SAndroid 
Build Coastguard Worker 1139*da0073e9SAndroid Build Coastguard Worker fn.eval() 1140*da0073e9SAndroid Build Coastguard Worker return test_fn 1141*da0073e9SAndroid Build Coastguard Worker 1142*da0073e9SAndroid Build Coastguard Worker 1143*da0073e9SAndroid Build Coastguard Worker@contextlib.contextmanager 1144*da0073e9SAndroid Build Coastguard Workerdef temporary_tensor_subclass(torch_function=None): 1145*da0073e9SAndroid Build Coastguard Worker class TensorProxy(torch.Tensor): 1146*da0073e9SAndroid Build Coastguard Worker @classmethod 1147*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, func, types, args=(), kwargs=None): 1148*da0073e9SAndroid Build Coastguard Worker if torch_function is not None: 1149*da0073e9SAndroid Build Coastguard Worker torch_function() 1150*da0073e9SAndroid Build Coastguard Worker return super().__torch_function__(func, types, args, kwargs) 1151*da0073e9SAndroid Build Coastguard Worker 1152*da0073e9SAndroid Build Coastguard Worker torch._dynamo.config.traceable_tensor_subclasses.add(TensorProxy) 1153*da0073e9SAndroid Build Coastguard Worker try: 1154*da0073e9SAndroid Build Coastguard Worker yield TensorProxy 1155*da0073e9SAndroid Build Coastguard Worker finally: 1156*da0073e9SAndroid Build Coastguard Worker torch._dynamo.config.traceable_tensor_subclasses.remove(TensorProxy) 1157*da0073e9SAndroid Build Coastguard Worker 1158*da0073e9SAndroid Build Coastguard Worker 1159*da0073e9SAndroid Build Coastguard Workerclass NNModuleTests(torch._dynamo.test_case.TestCase): 1160*da0073e9SAndroid Build Coastguard Worker test_seq = make_test(Seq()) 1161*da0073e9SAndroid Build Coastguard Worker test_basicmodule1 = make_test(BasicModule()) 1162*da0073e9SAndroid Build Coastguard Worker test_basicmodule2 = make_test(BasicModule()) 1163*da0073e9SAndroid Build Coastguard Worker test_submodules1 = make_test(SubmoduleExample()) 1164*da0073e9SAndroid Build Coastguard Worker test_submodules2 = make_test(SubmoduleExample()) 1165*da0073e9SAndroid 
Build Coastguard Worker test_modulemethod1 = make_test(ModuleMethodCall()) 1166*da0073e9SAndroid Build Coastguard Worker test_modulemethod2 = make_test(ModuleMethodCall()) 1167*da0073e9SAndroid Build Coastguard Worker test_module_call_module_with_static_forward = make_test( 1168*da0073e9SAndroid Build Coastguard Worker ModuleCallModuleWithStaticForward() 1169*da0073e9SAndroid Build Coastguard Worker ) 1170*da0073e9SAndroid Build Coastguard Worker test_module_static_method = make_test(ModuleStaticMethodCall()) 1171*da0073e9SAndroid Build Coastguard Worker test_fnmember = make_test(FnMember()) 1172*da0073e9SAndroid Build Coastguard Worker test_fnmembercmp1 = make_test(FnMemberCmp(F.relu)) 1173*da0073e9SAndroid Build Coastguard Worker test_fnmembercmp2 = make_test(FnMemberCmp(None)) 1174*da0073e9SAndroid Build Coastguard Worker test_constloop = make_test(ConstLoop()) 1175*da0073e9SAndroid Build Coastguard Worker test_istraining1 = make_test(IsTrainingCheck()) 1176*da0073e9SAndroid Build Coastguard Worker test_istraining2 = make_test(IsTrainingCheck()) 1177*da0073e9SAndroid Build Coastguard Worker test_iseval1 = make_test(IsEvalCheck()) 1178*da0073e9SAndroid Build Coastguard Worker test_iseval2 = make_test(IsEvalCheck()) 1179*da0073e9SAndroid Build Coastguard Worker test_viamodulecall = make_test(ViaModuleCall()) 1180*da0073e9SAndroid Build Coastguard Worker test_isnonelayer = make_test(IsNoneLayer()) 1181*da0073e9SAndroid Build Coastguard Worker test_layerlist = make_test(LayerList()) 1182*da0073e9SAndroid Build Coastguard Worker test_tensorlist = make_test(TensorList()) 1183*da0073e9SAndroid Build Coastguard Worker test_intarg = make_test(IntArg()) 1184*da0073e9SAndroid Build Coastguard Worker test_cfgmod = make_test(CfgModule()) 1185*da0073e9SAndroid Build Coastguard Worker test_stringmember = make_test(StringMember()) 1186*da0073e9SAndroid Build Coastguard Worker test_modulelist = make_test(ModuleList()) 1187*da0073e9SAndroid Build Coastguard Worker 
test_modulelist_nested = make_test(NestedModuleList()) 1188*da0073e9SAndroid Build Coastguard Worker test_modulelist_custom = make_test(CustomGetItemModuleList()) 1189*da0073e9SAndroid Build Coastguard Worker test_moduledict = make_test(ModuleDict()) 1190*da0073e9SAndroid Build Coastguard Worker test_moduledict_custom = make_test(CustomGetItemModuleDict()) 1191*da0073e9SAndroid Build Coastguard Worker test_parameterdict = make_test(ParameterDict()) 1192*da0073e9SAndroid Build Coastguard Worker test_parameterdict_custom = make_test(CustomGetItemParameterDict()) 1193*da0073e9SAndroid Build Coastguard Worker test_super1 = make_test(SuperModule()) 1194*da0073e9SAndroid Build Coastguard Worker test_super2 = make_test(SuperModule2()) 1195*da0073e9SAndroid Build Coastguard Worker test_super_class_method = make_test(SuperChildCallsClassMethod()) 1196*da0073e9SAndroid Build Coastguard Worker test_children = make_test(Children()) 1197*da0073e9SAndroid Build Coastguard Worker test_named_children = make_test(NamedChildren()) 1198*da0073e9SAndroid Build Coastguard Worker test_densenet = make_test(DenseNetBlocks()) 1199*da0073e9SAndroid Build Coastguard Worker test_parameters1 = make_test(ParametersModule1()) 1200*da0073e9SAndroid Build Coastguard Worker test_parameters2 = make_test(ParametersModule2()) 1201*da0073e9SAndroid Build Coastguard Worker test_parameters3 = make_test(ParametersModule3(), expected_ops=5) 1202*da0073e9SAndroid Build Coastguard Worker test_parameters4 = make_test(ParametersModule4()) 1203*da0073e9SAndroid Build Coastguard Worker test_parameters5 = make_test(ParametersModule5()) 1204*da0073e9SAndroid Build Coastguard Worker test_hasattr = make_test(HasAttrModule()) 1205*da0073e9SAndroid Build Coastguard Worker test_enumvalues = make_test(EnumValues()) 1206*da0073e9SAndroid Build Coastguard Worker test_access_by_keys = make_test(AccessByKeys()) 1207*da0073e9SAndroid Build Coastguard Worker test_module_class_method = make_test(ModuleClassMethodCall()) 
1208*da0073e9SAndroid Build Coastguard Worker test_module_property = make_test(ModuleProperty()) 1209*da0073e9SAndroid Build Coastguard Worker test_forward_directly = make_test(CallForwardDirectly()) 1210*da0073e9SAndroid Build Coastguard Worker test_module_name_string = make_test(ModuleNameString()) 1211*da0073e9SAndroid Build Coastguard Worker test_module_attribute_precedence = make_test(ModuleAttributePrecedence()) 1212*da0073e9SAndroid Build Coastguard Worker test_module_guard_name_is_valid = make_test(ModuleGuardNameIsValid()) 1213*da0073e9SAndroid Build Coastguard Worker test_sequential_with_duplicated_module = make_test(SequentialWithDuplicatedModule()) 1214*da0073e9SAndroid Build Coastguard Worker test_sequential_with_duplicated_module2 = make_test( 1215*da0073e9SAndroid Build Coastguard Worker SequentialWithDuplicatedModule2() 1216*da0073e9SAndroid Build Coastguard Worker ) 1217*da0073e9SAndroid Build Coastguard Worker test_module_comparison = make_test(ModuleComparison()) 1218*da0073e9SAndroid Build Coastguard Worker 1219*da0073e9SAndroid Build Coastguard Worker def test_module_forward_has_graph_break(self): 1220*da0073e9SAndroid Build Coastguard Worker m = ModuleForwardHasGraphBreak() 1221*da0073e9SAndroid Build Coastguard Worker x = torch.rand([10, 10]) 1222*da0073e9SAndroid Build Coastguard Worker ref = m(x) 1223*da0073e9SAndroid Build Coastguard Worker opt_m = torch._dynamo.optimize("eager")(m) 1224*da0073e9SAndroid Build Coastguard Worker res = opt_m(x) 1225*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(ref, res)) 1226*da0073e9SAndroid Build Coastguard Worker 1227*da0073e9SAndroid Build Coastguard Worker def test_unsupportedmethod(self): 1228*da0073e9SAndroid Build Coastguard Worker m = UnsupportedMethodCall() 1229*da0073e9SAndroid Build Coastguard Worker i = torch.randn(10) 1230*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 1231*da0073e9SAndroid Build Coastguard Worker opt_m = 
torch._dynamo.optimize(cnt)(m) 1232*da0073e9SAndroid Build Coastguard Worker r = opt_m(i) 1233*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch._dynamo.testing.same(r, m(i))) 1234*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.op_count, 5) 1235*da0073e9SAndroid Build Coastguard Worker 1236*da0073e9SAndroid Build Coastguard Worker def test_unsupportedmodule(self): 1237*da0073e9SAndroid Build Coastguard Worker m = UnsupportedModuleCall() 1238*da0073e9SAndroid Build Coastguard Worker i = torch.randn(10) 1239*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 1240*da0073e9SAndroid Build Coastguard Worker opt_m = torch._dynamo.optimize(cnt)(m) 1241*da0073e9SAndroid Build Coastguard Worker r = opt_m(i) 1242*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch._dynamo.testing.same(r, m(i))) 1243*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.op_count, 6) 1244*da0073e9SAndroid Build Coastguard Worker 1245*da0073e9SAndroid Build Coastguard Worker def test_self_mutating1(self): 1246*da0073e9SAndroid Build Coastguard Worker m1 = torch.nn.Linear(10, 10) 1247*da0073e9SAndroid Build Coastguard Worker m2 = SelfMutatingModule(m1) 1248*da0073e9SAndroid Build Coastguard Worker m3 = SelfMutatingModule(m1) 1249*da0073e9SAndroid Build Coastguard Worker m4 = SelfMutatingModule(m1) 1250*da0073e9SAndroid Build Coastguard Worker i = torch.randn(10) 1251*da0073e9SAndroid Build Coastguard Worker out2 = [m2(i), m2(i), m2(i)] 1252*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 1253*da0073e9SAndroid Build Coastguard Worker opt_m3 = torch._dynamo.optimize_assert(cnt)(m3) 1254*da0073e9SAndroid Build Coastguard Worker opt_m4 = torch._dynamo.optimize_assert(cnt)(m4) 1255*da0073e9SAndroid Build Coastguard Worker out3 = [opt_m3(i), opt_m3(i), opt_m3(i)] 1256*da0073e9SAndroid Build Coastguard Worker out4 = [opt_m4(i), opt_m4(i), opt_m4(i)] 1257*da0073e9SAndroid Build 
Coastguard Worker self.assertTrue(torch._dynamo.testing.same(out2, out3)) 1258*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch._dynamo.testing.same(out2, out4)) 1259*da0073e9SAndroid Build Coastguard Worker if torch._dynamo.config.assume_static_by_default: 1260*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(cnt.frame_count, """2""") 1261*da0073e9SAndroid Build Coastguard Worker else: 1262*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(cnt.frame_count, """1""") 1263*da0073e9SAndroid Build Coastguard Worker 1264*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False) 1265*da0073e9SAndroid Build Coastguard Worker def test_generation_tag(self): 1266*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 1267*da0073e9SAndroid Build Coastguard Worker 1268*da0073e9SAndroid Build Coastguard Worker # guarantee that we have installed 1269*da0073e9SAndroid Build Coastguard Worker # the generation tagging function 1270*da0073e9SAndroid Build Coastguard Worker with torch._dynamo.optimize_assert(cnt): 1271*da0073e9SAndroid Build Coastguard Worker pass 1272*da0073e9SAndroid Build Coastguard Worker 1273*da0073e9SAndroid Build Coastguard Worker m1 = torch.nn.Linear(10, 10) 1274*da0073e9SAndroid Build Coastguard Worker prev_generation = GenerationTracker.get_generation_value(m1) 1275*da0073e9SAndroid Build Coastguard Worker cur_generation = prev_generation + 1 1276*da0073e9SAndroid Build Coastguard Worker 1277*da0073e9SAndroid Build Coastguard Worker with torch._dynamo.optimize_assert(cnt): 1278*da0073e9SAndroid Build Coastguard Worker m2 = torch.nn.Linear(10, 10) 1279*da0073e9SAndroid Build Coastguard Worker 1280*da0073e9SAndroid Build Coastguard Worker self.assertEqual(GenerationTracker.get_generation_value(m1), prev_generation) 1281*da0073e9SAndroid Build Coastguard Worker self.assertEqual(GenerationTracker.get_generation_value(m2), 
cur_generation) 1282*da0073e9SAndroid Build Coastguard Worker # check that newly constructed instances 1283*da0073e9SAndroid Build Coastguard Worker # also have the same generation (even if copied from an old instance) 1284*da0073e9SAndroid Build Coastguard Worker m3 = deepcopy(m1) 1285*da0073e9SAndroid Build Coastguard Worker self.assertEqual(GenerationTracker.get_generation_value(m3), cur_generation) 1286*da0073e9SAndroid Build Coastguard Worker 1287*da0073e9SAndroid Build Coastguard Worker def test_simple_torch_function(self): 1288*da0073e9SAndroid Build Coastguard Worker def foo(x): 1289*da0073e9SAndroid Build Coastguard Worker # function call, twice to test wrapping 1290*da0073e9SAndroid Build Coastguard Worker x = F.sigmoid(x) 1291*da0073e9SAndroid Build Coastguard Worker x = F.sigmoid(x) 1292*da0073e9SAndroid Build Coastguard Worker # method call, twice to test wrapping 1293*da0073e9SAndroid Build Coastguard Worker x = x.sigmoid() 1294*da0073e9SAndroid Build Coastguard Worker x = x.sigmoid() 1295*da0073e9SAndroid Build Coastguard Worker return x 1296*da0073e9SAndroid Build Coastguard Worker 1297*da0073e9SAndroid Build Coastguard Worker with temporary_tensor_subclass() as TensorProxy: 1298*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1).as_subclass(TensorProxy) 1299*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 1300*da0073e9SAndroid Build Coastguard Worker out1 = foo(x) 1301*da0073e9SAndroid Build Coastguard Worker opt_foo = torch._dynamo.optimize(cnt, nopython=True)(foo) 1302*da0073e9SAndroid Build Coastguard Worker out2 = opt_foo(x) 1303*da0073e9SAndroid Build Coastguard Worker 1304*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.op_count, 4) 1305*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch._dynamo.testing.same(out1, out2)) 1306*da0073e9SAndroid Build Coastguard Worker 1307*da0073e9SAndroid Build Coastguard Worker def test_torch_function_with_closure(self): 
1308*da0073e9SAndroid Build Coastguard Worker def run(): 1309*da0073e9SAndroid Build Coastguard Worker def foo(x): 1310*da0073e9SAndroid Build Coastguard Worker # function call, twice to test wrapping 1311*da0073e9SAndroid Build Coastguard Worker x = F.sigmoid(x) 1312*da0073e9SAndroid Build Coastguard Worker x = F.sigmoid(x) 1313*da0073e9SAndroid Build Coastguard Worker # method call, twice to test wrapping 1314*da0073e9SAndroid Build Coastguard Worker x = x.sigmoid() 1315*da0073e9SAndroid Build Coastguard Worker x = x.sigmoid() 1316*da0073e9SAndroid Build Coastguard Worker return x 1317*da0073e9SAndroid Build Coastguard Worker 1318*da0073e9SAndroid Build Coastguard Worker counter = 0 1319*da0073e9SAndroid Build Coastguard Worker 1320*da0073e9SAndroid Build Coastguard Worker def function(): 1321*da0073e9SAndroid Build Coastguard Worker nonlocal counter 1322*da0073e9SAndroid Build Coastguard Worker # for now, only support reads from closure cells 1323*da0073e9SAndroid Build Coastguard Worker # TODO(future PR): support writes as well 1324*da0073e9SAndroid Build Coastguard Worker counter + 1 1325*da0073e9SAndroid Build Coastguard Worker 1326*da0073e9SAndroid Build Coastguard Worker with temporary_tensor_subclass(function) as TensorProxy: 1327*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1).as_subclass(TensorProxy) 1328*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1) 1329*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 1330*da0073e9SAndroid Build Coastguard Worker out1 = foo(x) 1331*da0073e9SAndroid Build Coastguard Worker opt_foo = torch._dynamo.optimize(cnt, nopython=True)(foo) 1332*da0073e9SAndroid Build Coastguard Worker out2 = opt_foo(x) 1333*da0073e9SAndroid Build Coastguard Worker 1334*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.op_count, 4) 1335*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch._dynamo.testing.same(out1, out2)) 1336*da0073e9SAndroid Build Coastguard 
Worker 1337*da0073e9SAndroid Build Coastguard Worker run() 1338*da0073e9SAndroid Build Coastguard Worker 1339*da0073e9SAndroid Build Coastguard Worker def test_torch_mangled_class_name(self): 1340*da0073e9SAndroid Build Coastguard Worker original = TensorWithTFOverrideVariable.global_mangled_class_name 1341*da0073e9SAndroid Build Coastguard Worker results = [] 1342*da0073e9SAndroid Build Coastguard Worker 1343*da0073e9SAndroid Build Coastguard Worker def instrumented(self, tx): 1344*da0073e9SAndroid Build Coastguard Worker result = original(self, tx) 1345*da0073e9SAndroid Build Coastguard Worker results.append(result) 1346*da0073e9SAndroid Build Coastguard Worker return result 1347*da0073e9SAndroid Build Coastguard Worker 1348*da0073e9SAndroid Build Coastguard Worker TensorWithTFOverrideVariable.global_mangled_class_name = instrumented 1349*da0073e9SAndroid Build Coastguard Worker 1350*da0073e9SAndroid Build Coastguard Worker def one_break(x): 1351*da0073e9SAndroid Build Coastguard Worker x = F.sigmoid(x) 1352*da0073e9SAndroid Build Coastguard Worker print() # force break 1353*da0073e9SAndroid Build Coastguard Worker x = x.sigmoid() 1354*da0073e9SAndroid Build Coastguard Worker return x 1355*da0073e9SAndroid Build Coastguard Worker 1356*da0073e9SAndroid Build Coastguard Worker try: 1357*da0073e9SAndroid Build Coastguard Worker with temporary_tensor_subclass() as TensorProxy: 1358*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1).as_subclass(TensorProxy) 1359*da0073e9SAndroid Build Coastguard Worker x1 = one_break(x) 1360*da0073e9SAndroid Build Coastguard Worker 1361*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 1362*da0073e9SAndroid Build Coastguard Worker opt_one_break = torch._dynamo.optimize(cnt)(one_break) 1363*da0073e9SAndroid Build Coastguard Worker x2 = opt_one_break(x) 1364*da0073e9SAndroid Build Coastguard Worker 1365*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch._dynamo.testing.same(x1, 
x2)) 1366*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 2) 1367*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.op_count, 2) 1368*da0073e9SAndroid Build Coastguard Worker 1369*da0073e9SAndroid Build Coastguard Worker compile_ids = set() 1370*da0073e9SAndroid Build Coastguard Worker for r in results: 1371*da0073e9SAndroid Build Coastguard Worker # A mangled classname looks like __subclass_TensorProxy_94524181138240_c0 1372*da0073e9SAndroid Build Coastguard Worker # where the last segment contains the compile_id. 1373*da0073e9SAndroid Build Coastguard Worker prefix = "__subclass_TensorProxy_" 1374*da0073e9SAndroid Build Coastguard Worker before, sep, after = r.partition(prefix) 1375*da0073e9SAndroid Build Coastguard Worker self.assertEqual(before, "") 1376*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sep, prefix) 1377*da0073e9SAndroid Build Coastguard Worker 1378*da0073e9SAndroid Build Coastguard Worker class_type_id, compile_id = after.split("_") 1379*da0073e9SAndroid Build Coastguard Worker self.assertTrue(class_type_id.isnumeric()) 1380*da0073e9SAndroid Build Coastguard Worker self.assertTrue(compile_id.startswith("c")) 1381*da0073e9SAndroid Build Coastguard Worker 1382*da0073e9SAndroid Build Coastguard Worker cid = compile_id[1:] 1383*da0073e9SAndroid Build Coastguard Worker self.assertTrue(cid.isnumeric()) 1384*da0073e9SAndroid Build Coastguard Worker compile_ids.add(cid) 1385*da0073e9SAndroid Build Coastguard Worker 1386*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(compile_ids), 3) 1387*da0073e9SAndroid Build Coastguard Worker 1388*da0073e9SAndroid Build Coastguard Worker finally: 1389*da0073e9SAndroid Build Coastguard Worker TensorWithTFOverrideVariable.global_mangled_class_name = original 1390*da0073e9SAndroid Build Coastguard Worker 1391*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False) 1392*da0073e9SAndroid Build 
Coastguard Worker def test_nn_moduledict_contains(self): 1393*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 1394*da0073e9SAndroid Build Coastguard Worker def __init__(self, module_dict): 1395*da0073e9SAndroid Build Coastguard Worker super().__init__() 1396*da0073e9SAndroid Build Coastguard Worker self.module_dict = module_dict 1397*da0073e9SAndroid Build Coastguard Worker 1398*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1399*da0073e9SAndroid Build Coastguard Worker if "foo" in self.module_dict: 1400*da0073e9SAndroid Build Coastguard Worker x = torch.mul(x, 1.0) 1401*da0073e9SAndroid Build Coastguard Worker x = torch.add(x, 1.0) 1402*da0073e9SAndroid Build Coastguard Worker return x 1403*da0073e9SAndroid Build Coastguard Worker 1404*da0073e9SAndroid Build Coastguard Worker module_dict = torch.nn.ModuleDict({"foo": torch.nn.Conv2d(1, 1, 1)}) 1405*da0073e9SAndroid Build Coastguard Worker m = M(module_dict) 1406*da0073e9SAndroid Build Coastguard Worker data = torch.randn(1) 1407*da0073e9SAndroid Build Coastguard Worker out1 = m(data) 1408*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 1409*da0073e9SAndroid Build Coastguard Worker opt_m = torch._dynamo.optimize(cnt, nopython=True)(m) 1410*da0073e9SAndroid Build Coastguard Worker out2 = opt_m(data) 1411*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.op_count, 2) 1412*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch._dynamo.testing.same(out1, out2)) 1413*da0073e9SAndroid Build Coastguard Worker 1414*da0073e9SAndroid Build Coastguard Worker module_dict = torch.nn.ModuleDict({"bar": torch.nn.Conv2d(1, 1, 1)}) 1415*da0073e9SAndroid Build Coastguard Worker m = M(module_dict) 1416*da0073e9SAndroid Build Coastguard Worker data = torch.randn(1) 1417*da0073e9SAndroid Build Coastguard Worker out1 = m(data) 1418*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 1419*da0073e9SAndroid Build 
        torch._dynamo.reset()
        opt_m = torch._dynamo.optimize(cnt, nopython=True)(m)
        out2 = opt_m(data)

        self.assertEqual(cnt.op_count, 1)
        self.assertTrue(torch._dynamo.testing.same(out1, out2))

        module_dict = torch.nn.ModuleDict({"cat": torch.nn.Conv2d(1, 1, 1)})
        pre = m(data)
        cnt.clear()

        # NOTE(review): indentation below was reconstructed from a mangled
        # source — confirm against upstream which statements live inside the
        # optimize() context manager.
        with torch._dynamo.optimize(cnt, nopython=False):
            opt_pre = m(data)
            m = M(module_dict)
            data = torch.randn(1)
            out1 = m(data)

        out_post = m(data)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)
        self.assertTrue(torch._dynamo.testing.same(pre, opt_pre))
        self.assertTrue(torch._dynamo.testing.same(out1, out_post))

    # RuntimeError: SymIntArrayRef expected to contain only concrete integers
    @expectedFailureDynamic
    def test_lazy_module1(self):
        # Lazy modules materialize their parameters on first call; check that
        # Dynamo handles materialization without graph breaks in three setups:
        # a plain LazyModule, an unspecialized-NN-module mapping, and a
        # torch-provided lazy module (LazyBatchNorm3d).
        input_shape = (16, 3, 6, 7, 8)

        cnt = torch._dynamo.testing.CompileCounter()
        module = LazyModule()

        def test_static_module():
            input = torch.ones(*input_shape)
            module(input)

        # test no graph break
        opt_test_static_module = torch._dynamo.optimize(cnt, nopython=True)(
            test_static_module
        )
        opt_test_static_module()

        self.assertTrue(
            isinstance(module, MaterializedModule),
            "Module should be transformed to an instance of MaterializedModule.",
        )
        self.assertEqual(module.param.shape, input_shape)

        # test when mapped to UnspecializedNNModule
        module = LazyModule()

        def test_unspecialized():
            # Reassigning the closed-over module forces Dynamo to treat it as
            # an unspecialized NN module rather than a static one.
            nonlocal module
            module = LazyModule()
            input = torch.ones(*input_shape)
            module(input)

        opt_test_unspecialized = torch._dynamo.optimize(cnt)(test_unspecialized)
        opt_test_unspecialized()

        self.assertTrue(
            isinstance(module, MaterializedModule),
            "Module should be transformed to an instance of MaterializedModule.",
        )
        self.assertEqual(module.param.shape, input_shape)

        # test with a static module in torch.*
        module = torch.nn.modules.LazyBatchNorm3d(
            affine=False, track_running_stats=False
        )

        cnt = torch._dynamo.testing.CompileCounter()

        torch._dynamo.reset()

        def test_torch_static():
            input = torch.ones(*input_shape)
            return module(input)  # fully materialized

        # test no graph break
        opt_test_torch_static = torch._dynamo.optimize(cnt, nopython=True)(
            test_torch_static
        )
        # Call twice: the second call should hit the already-materialized path.
        opt_test_torch_static()
        out = opt_test_torch_static()

        self.assertTrue(same(out, module(torch.ones(*input_shape))))

        self.assertTrue(
            isinstance(module, torch.nn.modules.batchnorm.BatchNorm3d),
            "Module should be transformed to an instance of BatchNorm3d.",
        )
        self.assertEqual(cnt.frame_count, 1, "No guards should have triggered.")

    # RuntimeError: SymIntArrayRef expected to contain only concrete integers
    @expectedFailureDynamic
    def test_lazy_module2(self):
        # Test FX graph 'call_module' works well if argument is lazy module
        m = LazyMLP()
        x = torch.rand([10, 10])
        opt_m = torch._dynamo.optimize("eager", nopython=True)(m)
        # We should run compile mode firstly, otherwise the module
        # would be initialized when running eager mode.
        res = opt_m(x)
        ref = m(x)
        self.assertTrue(torch.allclose(ref, res))

    # RuntimeError: SymIntArrayRef expected to contain only concrete integers
    @expectedFailureDynamic
    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_lazy_module3(self):
        # Moving a lazy (now materialized) module to CUDA should trigger
        # exactly one recompilation (frame_count goes from 1 to 2).
        m = LazyMLP()
        x = torch.rand([10, 10])
        cnt = torch._dynamo.testing.CompileCounter()
        opt_m = torch._dynamo.optimize(cnt, nopython=True)(m)
        # first iteration
        res = opt_m(x)
        ref = m(x)
        self.assertTrue(torch.allclose(ref, res))
        # move to cuda and second iteration
        m = m.to("cuda")
        x = x.to("cuda")
        res = opt_m(x)
        ref = m(x)
        self.assertTrue(torch.allclose(ref, res))
        self.assertEqual(cnt.frame_count, 2)

    # RuntimeError: SymIntArrayRef expected to contain only concrete integers
    @expectedFailureDynamic
def test_lazy_module4(self): 1548*da0073e9SAndroid Build Coastguard Worker m = LazyMLP() 1549*da0073e9SAndroid Build Coastguard Worker x = torch.rand([10, 10]) 1550*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 1551*da0073e9SAndroid Build Coastguard Worker opt_m = torch._dynamo.optimize(cnt, nopython=True)(m) 1552*da0073e9SAndroid Build Coastguard Worker # first iteration 1553*da0073e9SAndroid Build Coastguard Worker res = opt_m(x) 1554*da0073e9SAndroid Build Coastguard Worker ref = m(x) 1555*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(ref, res)) 1556*da0073e9SAndroid Build Coastguard Worker # input shape changed and second iteration 1557*da0073e9SAndroid Build Coastguard Worker x = torch.rand([20, 20]) 1558*da0073e9SAndroid Build Coastguard Worker try: 1559*da0073e9SAndroid Build Coastguard Worker opt_m(x) 1560*da0073e9SAndroid Build Coastguard Worker except RuntimeError: 1561*da0073e9SAndroid Build Coastguard Worker self.assertIn("must have same reduction dim", traceback.format_exc()) 1562*da0073e9SAndroid Build Coastguard Worker 1563*da0073e9SAndroid Build Coastguard Worker # RuntimeError: SymIntArrayRef expected to contain only concrete integers 1564*da0073e9SAndroid Build Coastguard Worker @expectedFailureDynamic 1565*da0073e9SAndroid Build Coastguard Worker def test_lazy_module5(self): 1566*da0073e9SAndroid Build Coastguard Worker # Test lazy module works well with list/tuple input 1567*da0073e9SAndroid Build Coastguard Worker m = LazyModuleWithListInput() 1568*da0073e9SAndroid Build Coastguard Worker x = [torch.rand([5, 5])] * 3 + [None] 1569*da0073e9SAndroid Build Coastguard Worker opt_m = torch._dynamo.optimize("eager", nopython=True)(m) 1570*da0073e9SAndroid Build Coastguard Worker res = opt_m(x) 1571*da0073e9SAndroid Build Coastguard Worker ref = m(x) 1572*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(ref, res)) 1573*da0073e9SAndroid Build Coastguard Worker 
1574*da0073e9SAndroid Build Coastguard Worker # RuntimeError: SymIntArrayRef expected to contain only concrete integers 1575*da0073e9SAndroid Build Coastguard Worker @expectedFailureDynamic 1576*da0073e9SAndroid Build Coastguard Worker def test_lazy_module6(self): 1577*da0073e9SAndroid Build Coastguard Worker # Test new lazy submodule in lazy module's initialize_parameters 1578*da0073e9SAndroid Build Coastguard Worker m = LazyModuleWithLazySubmodule() 1579*da0073e9SAndroid Build Coastguard Worker x = [torch.rand([5, 5])] * 3 1580*da0073e9SAndroid Build Coastguard Worker opt_m = torch._dynamo.optimize("eager", nopython=True)(m) 1581*da0073e9SAndroid Build Coastguard Worker res = opt_m(x) 1582*da0073e9SAndroid Build Coastguard Worker ref = m(x) 1583*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(ref, res)) 1584*da0073e9SAndroid Build Coastguard Worker 1585*da0073e9SAndroid Build Coastguard Worker # RuntimeError: SymIntArrayRef expected to contain only concrete integers 1586*da0073e9SAndroid Build Coastguard Worker @expectedFailureDynamic 1587*da0073e9SAndroid Build Coastguard Worker def test_lazy_module7(self): 1588*da0073e9SAndroid Build Coastguard Worker # Test lazy module works well with namedtuple/dict input 1589*da0073e9SAndroid Build Coastguard Worker m = LazyModuleWithNamedTupleInput() 1590*da0073e9SAndroid Build Coastguard Worker x = MyInput( 1591*da0073e9SAndroid Build Coastguard Worker x={"a": [torch.rand([5, 5])] * 3, "b": torch.rand([5, 5])}, 1592*da0073e9SAndroid Build Coastguard Worker y=torch.rand([5, 5]), 1593*da0073e9SAndroid Build Coastguard Worker ) 1594*da0073e9SAndroid Build Coastguard Worker opt_m = torch.compile(backend="eager", fullgraph=True)(m) 1595*da0073e9SAndroid Build Coastguard Worker res = opt_m(x) 1596*da0073e9SAndroid Build Coastguard Worker ref = m(x) 1597*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(ref, res)) 1598*da0073e9SAndroid Build Coastguard Worker 1599*da0073e9SAndroid 
    def test_lazy_module_no_cls_to_become(self):
        # make sure super() works in the case where cls_to_become is None
        m = LazyChildModuleNoClsToBecome()
        x = torch.rand(2, 2)
        opt_m = torch._dynamo.optimize("eager", nopython=True)(m)
        res = opt_m(x)
        ref = m(x)
        self.assertTrue(torch.allclose(ref, res))

    def test_lazy_module_kwargs(self):
        # Lazy module called with a keyword argument; compiled result must
        # match eager execution.
        m = LazyModuleKwArgs()
        x = [torch.rand([5, 5])] * 3
        y = [torch.rand([5, 5])] * 2
        opt_m = torch.compile(backend="eager", fullgraph=True)(m)
        exp_res = m(x, y)
        self.assertTrue(torch.allclose(exp_res, opt_m(x, y)))

    def test_call_fn_with_non_const_inputs_safe(self):
        # A module whose forward delegates to a private helper that in turn
        # calls Conv2d's _conv_forward with tensor (non-constant) arguments;
        # export should still produce a graph matching eager output.
        class ModuleSpecialFwd(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=3, out_channels=20, kernel_size=(5, 5)
                )

            def _conv_forward(self, x):
                # Calls into the underlying Conv2d's private forward helper.
                return self.conv._conv_forward(x, self.conv.weight, self.conv.bias)

            def forward(self, x):
                return self._conv_forward(x)

        mod = ModuleSpecialFwd()
        rx = torch.randn([3, 10, 10])
        real = mod(rx)
        graph, _ = torch._dynamo.export(mod)(rx)
        self.assertTrue(torch._dynamo.testing.same(real, graph(rx)))

    def test_conv_call_forward_directly(self):
        # Module that invokes Conv2d.forward directly (see ConvCallForwardDirectly
        # defined elsewhere in this file); compiled output must match eager.
        m = ConvCallForwardDirectly()
        x = torch.rand([4, 3, 9, 9])
        ref = m(x)
        opt_m = torch.compile(backend="eager", fullgraph=True)(m)
        res = opt_m(x)
        self.assertTrue(torch.allclose(ref, res))

    def test_conv_transpose_call_forward_directly(self):
        # Same as above, but for ConvTranspose2d's forward.
        m = ConvTransposeCallForwardDirectly()
        x = torch.rand([4, 4, 4, 4])
        ref = m(x)
        opt_m = torch.compile(backend="eager", fullgraph=True)(m)
        res = opt_m(x)
        self.assertTrue(torch.allclose(ref, res))
1652*da0073e9SAndroid Build Coastguard Worker def test_conv_call_super_forward_directly(self): 1653*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4, 4) 1654*da0073e9SAndroid Build Coastguard Worker m = ConvCallSuperForwardDirectly(4, 4, 4) 1655*da0073e9SAndroid Build Coastguard Worker ref = m(x) 1656*da0073e9SAndroid Build Coastguard Worker opt_m = torch.compile(backend="eager", fullgraph=True)(m) 1657*da0073e9SAndroid Build Coastguard Worker res = opt_m(x) 1658*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(ref, res)) 1659*da0073e9SAndroid Build Coastguard Worker 1660*da0073e9SAndroid Build Coastguard Worker def test_conv_transpose_call_super_forward_directly(self): 1661*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4, 4, 4) 1662*da0073e9SAndroid Build Coastguard Worker m = ConvTransposeCallSuperForwardDirectly(4, 4, 4) 1663*da0073e9SAndroid Build Coastguard Worker ref = m(x) 1664*da0073e9SAndroid Build Coastguard Worker opt_m = torch.compile(backend="eager", fullgraph=True)(m) 1665*da0073e9SAndroid Build Coastguard Worker res = opt_m(x) 1666*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(ref, res)) 1667*da0073e9SAndroid Build Coastguard Worker 1668*da0073e9SAndroid Build Coastguard Worker 1669*da0073e9SAndroid Build Coastguard Workerclass MockModule(torch.nn.Module): 1670*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1671*da0073e9SAndroid Build Coastguard Worker super().__init__() 1672*da0073e9SAndroid Build Coastguard Worker self.relu = torch.nn.ReLU() 1673*da0073e9SAndroid Build Coastguard Worker self.linear = torch.nn.Linear(10, 10) 1674*da0073e9SAndroid Build Coastguard Worker self.buf0 = torch.nn.Buffer(torch.randn(10, 10)) 1675*da0073e9SAndroid Build Coastguard Worker 1676*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1677*da0073e9SAndroid Build Coastguard Worker return self.relu(self.linear(x) + self.buf0) 1678*da0073e9SAndroid Build Coastguard 
class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
    """Tests for torch._dynamo.OptimizedModule wrapping of nn.Modules."""

    def test_nn_module(self):
        # Optimizing an nn.Module yields an OptimizedModule whose output
        # matches eager and which compiles exactly one frame.
        mod = MockModule()
        cnt = torch._dynamo.testing.CompileCounter()
        opt_mod = torch._dynamo.optimize(cnt)(mod)
        self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule)

        x = torch.randn(10, 10)
        self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x)))
        self.assertEqual(cnt.frame_count, 1)

    @torch._dynamo.config.patch(guard_nn_modules=True)
    def test_attr_precedence(self):
        # Instance attributes must shadow same-named methods of the class and
        # its base when Dynamo resolves attribute access on the module.
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 3

            def forward(self, x, c=4):
                return x * c

            def linear(self, x):
                return x

            def b(self, x):
                raise RuntimeError("Should not be called")

        class MyMod(Mod):
            def __init__(self) -> None:
                super().__init__()
                # These instance attributes shadow the base-class methods
                # `linear` and `b`, and the `scale` method defined below.
                self.linear = torch.nn.Linear(11, 11)
                self.a = 2
                self.b = 2
                self.scale = 1

            def scale(self, x):
                # Should not be called because it is shadowed by the instance
                # attribute
                raise RuntimeError("Should not be called")

            def forward(self, x, c=None):
                return self.linear(x) * self.a * self.b * self.scale

        mod = MyMod()
        x = torch.ones(3, 3)
        ref = mod(x)

        cnts = torch._dynamo.testing.CompileCounter()
        opt_mod = torch.compile(mod, backend=cnts)
        # Second call must hit the cache: still a single compiled frame.
        opt_mod(torch.ones(3, 3))
        res = opt_mod(torch.ones(3, 3))

        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(ref, res)

    def test_to(self):
        # .to() on an OptimizedModule keeps the wrapper type and triggers a
        # recompile only when the dtype/device guard actually changes.
        mod = MockModule()
        cnt = torch._dynamo.testing.CompileCounter()
        opt_mod = torch._dynamo.optimize(cnt)(mod)
        x = torch.randn(10, 10)
        self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x)))
        self.assertEqual(cnt.frame_count, 1)

        # Ensure that there is no recompilation
        opt_mod(x)
        self.assertEqual(cnt.frame_count, 1)

        opt_mod = opt_mod.to(device="cpu").to(dtype=torch.float64)
        self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule)
        x = torch.randn(10, 10).to(dtype=torch.float64)
        opt_mod(x)
        # Ensure that there is a recompilation
        self.assertEqual(cnt.frame_count, 2)

        # Ensure that there is no recompilation
        opt_mod(x)
        self.assertEqual(cnt.frame_count, 2)

        # After a reset the cached code is gone, so one more compile happens.
        torch._dynamo.reset()
        opt_mod(x)
        self.assertEqual(cnt.frame_count, 3)

    @torch._dynamo.config.patch(guard_nn_modules=True)
    def test_param_order(self):
        # Reordering _parameters (pop + reinsert moves the key to the end of
        # the ordered dict) must be guarded on: a recompile is expected.
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param1 = torch.nn.Parameter(torch.ones([1]))
                self.param2 = torch.nn.Parameter(torch.ones([2]))

            def forward(self, x):
                return x

        mod = MyModule()
        coeffs = [2, 3]

        def fn(x):
            # Results depend on iteration order via the per-index coeffs.
            for idx, p in enumerate(mod.parameters()):
                x += p.sum() * coeffs[idx]

            for idx, p in enumerate(mod.named_parameters()):
                x += p[1].sum() * coeffs[idx]

            return x

        ref = fn(torch.ones(1))
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res = opt_fn(torch.ones(1))

        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 1)

        # Move param1 to the end of the parameter ordering.
        mod._parameters["param1"] = mod._parameters.pop("param1")
        ref = fn(torch.ones(1))
        res = opt_fn(torch.ones(1))

        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)

    @torch._dynamo.config.patch(guard_nn_modules=True)
    def test_buffer_order(self):
        # Same as test_param_order, but for registered buffers.
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.b1 = torch.nn.Buffer(torch.ones([1]))
                self.b2 = torch.nn.Buffer(torch.ones([2]))

            def forward(self, x):
                return x

        mod = MyModule()
        coeffs = [2, 3]

        def fn(x):
            for idx, p in enumerate(mod.buffers()):
                x += p.sum() * coeffs[idx]

            for idx, p in enumerate(mod.named_buffers()):
                x += p[1].sum() * coeffs[idx]

            return x

        ref = fn(torch.ones(1))
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res = opt_fn(torch.ones(1))

        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 1)

        # Move b1 to the end of the buffer ordering.
        mod._buffers["b1"] = mod._buffers.pop("b1")
        ref = fn(torch.ones(1))
        res = opt_fn(torch.ones(1))

        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)

    @torch._dynamo.config.patch(guard_nn_modules=True)
    def test_module_order(self):
        # Same idea for submodules: modules()/named_modules()/children()/
        # named_children() iteration order must be guarded on.
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(3, 3)
                self.linear2 = torch.nn.Linear(10, 10)

            def forward(self, x):
                return x

        mod = MyModule()
        coeffs = [2, 3, 4]

        coeffs_for_mod = {mod: 10, mod.linear1: 20, mod.linear2: 30}

        # Check order of _modules
        def fn(x):
            for idx, p in enumerate(mod.modules()):
                # Something silly to force dependency on the order
                x += coeffs_for_mod[p] * coeffs[idx]
            for idx, p in enumerate(mod.named_modules()):
                x += coeffs_for_mod[p[1]] * coeffs[idx]
            for idx, p in enumerate(mod.children()):
                x += coeffs_for_mod[p] * coeffs[idx]
            for idx, p in enumerate(mod.named_children()):
                x += coeffs_for_mod[p[1]] * coeffs[idx]
            return x

        ref = fn(torch.ones(1))
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res = opt_fn(torch.ones(1))

        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 1)

        # Move linear1 to the end of the submodule ordering.
        mod._modules["linear1"] = mod._modules.pop("linear1")
        ref = fn(torch.ones(1))
        res = opt_fn(torch.ones(1))

        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)

    def test_attr(self):
        # Parameters/buffers of the optimized wrapper must be the very same
        # objects as the original module's (identity, not just equality).
        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)
                self.buf0 = torch.nn.Buffer(torch.randn(10, 10))

            def forward(self, x):
                # NOTE(review): `self.r` is not defined on this class; forward
                # is never invoked by this test, so it does not matter here.
                return self.r(torch.sin(x)) + self.buf0

        mod = MockModule()
        opt_mod = torch._dynamo.optimize("eager")(mod)

        # Check parameters and buffers
        for p1, p2 in zip(mod.parameters(), opt_mod.parameters()):
            self.assertTrue(id(p1) == id(p2))
        for b1, b2 in zip(mod.buffers(), opt_mod.buffers()):
            self.assertTrue(id(b1) == id(b2))

        def get_parameter_dtype(mod: torch.nn.Module):
            parameters_and_buffers = itertools.chain(mod.parameters(), mod.buffers())
            return next(parameters_and_buffers).dtype

        opt_mod = torch._dynamo.optimize("eager")(get_parameter_dtype)
        out_dtype = opt_mod(mod)
        self.assertEqual(out_dtype, torch.float32)

    def test_dir(self):
        # dir() on the optimized wrapper must expose the wrapped module's
        # user-defined attributes, parameters and buffers.
        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)
                self.buf0 = torch.nn.Buffer(torch.nn.Buffer(torch.randn(10, 10)))
                self.register_parameter(
                    name="param0", param=torch.nn.Parameter(torch.randn(10, 10))
                )

            def forward(self, x):
                # NOTE(review): `self.r` is undefined; forward is never called
                # by this test.
                return self.r(torch.sin(x)) + self.buf0

        mod = MockModule()
        mod_keys = dir(mod)
        opt_mod = torch._dynamo.optimize("eager")(mod)
        opt_mod_keys = dir(opt_mod)

        # Check user-defined attributes, parameters and buffers
        self.assertIn("linear", opt_mod_keys)
        self.assertIn("buf0", opt_mod_keys)
        self.assertIn("param0", opt_mod_keys)

        # Check all attributes, parameters and buffers
        self.assertTrue(len(set(mod_keys).difference(opt_mod_keys)) == 0)

    def test_no_recompile_on_nn_guarded_modules(self):
        size = (10, 10)
        cache_size_limit = 1
        num_submodules = 4
        cnts = torch._dynamo.testing.CompileCounterWithBackend("eager")

        class SubModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(*size)

            def forward(self, x):
                a = torch.sin(torch.cos(x))
                return self.linear(a)

        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # Each submodule is individually compiled with the shared
                # counting backend.
                self.mods = [SubModule() for _ in range(num_submodules)]
                self.mods = [torch.compile(mod, backend=cnts) for mod in self.mods]

            def forward(self, x):
Worker for mod in self.mods: 1958*da0073e9SAndroid Build Coastguard Worker x = mod(x) 1959*da0073e9SAndroid Build Coastguard Worker return x 1960*da0073e9SAndroid Build Coastguard Worker 1961*da0073e9SAndroid Build Coastguard Worker mod = MockModule() 1962*da0073e9SAndroid Build Coastguard Worker # Each submod is compiled separately and has a different nn module 1963*da0073e9SAndroid Build Coastguard Worker # guard. Ensure that recompilation logic is handle correctly. 1964*da0073e9SAndroid Build Coastguard Worker with unittest.mock.patch( 1965*da0073e9SAndroid Build Coastguard Worker "torch._dynamo.config.error_on_recompile", True 1966*da0073e9SAndroid Build Coastguard Worker ), unittest.mock.patch( 1967*da0073e9SAndroid Build Coastguard Worker "torch._dynamo.config.cache_size_limit", 1968*da0073e9SAndroid Build Coastguard Worker cache_size_limit, 1969*da0073e9SAndroid Build Coastguard Worker ): 1970*da0073e9SAndroid Build Coastguard Worker x = torch.randn(*size, requires_grad=True) 1971*da0073e9SAndroid Build Coastguard Worker mod(x) 1972*da0073e9SAndroid Build Coastguard Worker if torch._dynamo.config.inline_inbuilt_nn_modules: 1973*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 1974*da0073e9SAndroid Build Coastguard Worker else: 1975*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, num_submodules) 1976*da0073e9SAndroid Build Coastguard Worker 1977*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "accumulated_cache_size_limit", 2) 1978*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "inline_inbuilt_nn_modules", False) 1979*da0073e9SAndroid Build Coastguard Worker def test_recompile_limit_on_freed_module(self): 1980*da0073e9SAndroid Build Coastguard Worker class Mod(torch.nn.Module): 1981*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1982*da0073e9SAndroid Build Coastguard Worker super().__init__() 1983*da0073e9SAndroid Build 
Coastguard Worker self.lin = torch.nn.Linear(5, 5) 1984*da0073e9SAndroid Build Coastguard Worker 1985*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1986*da0073e9SAndroid Build Coastguard Worker return self.lin(x) 1987*da0073e9SAndroid Build Coastguard Worker 1988*da0073e9SAndroid Build Coastguard Worker def fn(x, mod): 1989*da0073e9SAndroid Build Coastguard Worker return mod(x) 1990*da0073e9SAndroid Build Coastguard Worker 1991*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounterWithBackend("eager") 1992*da0073e9SAndroid Build Coastguard Worker opt_mod = torch.compile(fn, backend=cnts) 1993*da0073e9SAndroid Build Coastguard Worker for i in range(8): 1994*da0073e9SAndroid Build Coastguard Worker mod = Mod() 1995*da0073e9SAndroid Build Coastguard Worker opt_mod(torch.randn(5, 5), mod) 1996*da0073e9SAndroid Build Coastguard Worker 1997*da0073e9SAndroid Build Coastguard Worker # fn compiles twice 1998*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 2) 1999*da0073e9SAndroid Build Coastguard Worker 2000*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "inline_inbuilt_nn_modules", True) 2001*da0073e9SAndroid Build Coastguard Worker def test_inline_inbuilt_nn_modules(self): 2002*da0073e9SAndroid Build Coastguard Worker size = (10, 10) 2003*da0073e9SAndroid Build Coastguard Worker cache_size_limit = 1 2004*da0073e9SAndroid Build Coastguard Worker num_submodules = 4 2005*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounterWithBackend("eager") 2006*da0073e9SAndroid Build Coastguard Worker 2007*da0073e9SAndroid Build Coastguard Worker class SubModule(torch.nn.Module): 2008*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2009*da0073e9SAndroid Build Coastguard Worker super().__init__() 2010*da0073e9SAndroid Build Coastguard Worker self.linear = torch.nn.Linear(*size) 2011*da0073e9SAndroid Build Coastguard Worker 
2012*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2013*da0073e9SAndroid Build Coastguard Worker a = torch.sin(torch.cos(x)) 2014*da0073e9SAndroid Build Coastguard Worker return self.linear(a) 2015*da0073e9SAndroid Build Coastguard Worker 2016*da0073e9SAndroid Build Coastguard Worker class MockModule(torch.nn.Module): 2017*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2018*da0073e9SAndroid Build Coastguard Worker super().__init__() 2019*da0073e9SAndroid Build Coastguard Worker self.mods = [SubModule() for _ in range(num_submodules)] 2020*da0073e9SAndroid Build Coastguard Worker self.mods = [torch.compile(mod, backend=cnts) for mod in self.mods] 2021*da0073e9SAndroid Build Coastguard Worker 2022*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2023*da0073e9SAndroid Build Coastguard Worker for mod in self.mods: 2024*da0073e9SAndroid Build Coastguard Worker x = mod(x) 2025*da0073e9SAndroid Build Coastguard Worker return x 2026*da0073e9SAndroid Build Coastguard Worker 2027*da0073e9SAndroid Build Coastguard Worker mod = MockModule() 2028*da0073e9SAndroid Build Coastguard Worker # Each submod is compiled separately and has a different nn module 2029*da0073e9SAndroid Build Coastguard Worker # guard. Ensure that recompilation logic is handle correctly. 
2030*da0073e9SAndroid Build Coastguard Worker with unittest.mock.patch( 2031*da0073e9SAndroid Build Coastguard Worker "torch._dynamo.config.error_on_recompile", True 2032*da0073e9SAndroid Build Coastguard Worker ), unittest.mock.patch( 2033*da0073e9SAndroid Build Coastguard Worker "torch._dynamo.config.cache_size_limit", 2034*da0073e9SAndroid Build Coastguard Worker cache_size_limit, 2035*da0073e9SAndroid Build Coastguard Worker ): 2036*da0073e9SAndroid Build Coastguard Worker x = torch.randn(*size, requires_grad=True) 2037*da0073e9SAndroid Build Coastguard Worker mod(x) 2038*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2039*da0073e9SAndroid Build Coastguard Worker 2040*da0073e9SAndroid Build Coastguard Worker def test_cache_size_limit_on_guarded_nn_modules(self): 2041*da0073e9SAndroid Build Coastguard Worker cache_size_limit = 2 2042*da0073e9SAndroid Build Coastguard Worker num_submodules = 4 2043*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounterWithBackend("eager") 2044*da0073e9SAndroid Build Coastguard Worker 2045*da0073e9SAndroid Build Coastguard Worker class SubModule(torch.nn.Module): 2046*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2047*da0073e9SAndroid Build Coastguard Worker super().__init__() 2048*da0073e9SAndroid Build Coastguard Worker self.relu = torch.nn.ReLU() 2049*da0073e9SAndroid Build Coastguard Worker 2050*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2051*da0073e9SAndroid Build Coastguard Worker a = torch.sin(torch.cos(x)) 2052*da0073e9SAndroid Build Coastguard Worker return self.relu(a) 2053*da0073e9SAndroid Build Coastguard Worker 2054*da0073e9SAndroid Build Coastguard Worker class MockModule(torch.nn.Module): 2055*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2056*da0073e9SAndroid Build Coastguard Worker super().__init__() 2057*da0073e9SAndroid Build Coastguard Worker self.mods = [SubModule() for _ in 
range(num_submodules)] 2058*da0073e9SAndroid Build Coastguard Worker self.mods = [torch.compile(mod, backend=cnts) for mod in self.mods] 2059*da0073e9SAndroid Build Coastguard Worker 2060*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2061*da0073e9SAndroid Build Coastguard Worker for mod in self.mods: 2062*da0073e9SAndroid Build Coastguard Worker x = mod(x) 2063*da0073e9SAndroid Build Coastguard Worker return x 2064*da0073e9SAndroid Build Coastguard Worker 2065*da0073e9SAndroid Build Coastguard Worker mod = MockModule() 2066*da0073e9SAndroid Build Coastguard Worker # For the third iteration, we would reach the cache size limit, and 2067*da0073e9SAndroid Build Coastguard Worker # therefore the total number of expected frame count is 2 * 2068*da0073e9SAndroid Build Coastguard Worker # num_submodules. 2069*da0073e9SAndroid Build Coastguard Worker with unittest.mock.patch( 2070*da0073e9SAndroid Build Coastguard Worker "torch._dynamo.config.cache_size_limit", 2071*da0073e9SAndroid Build Coastguard Worker cache_size_limit, 2072*da0073e9SAndroid Build Coastguard Worker ): 2073*da0073e9SAndroid Build Coastguard Worker for size in [ 2074*da0073e9SAndroid Build Coastguard Worker (4,), 2075*da0073e9SAndroid Build Coastguard Worker (4, 4), 2076*da0073e9SAndroid Build Coastguard Worker (4, 4, 4), 2077*da0073e9SAndroid Build Coastguard Worker ]: 2078*da0073e9SAndroid Build Coastguard Worker x = torch.randn(size) 2079*da0073e9SAndroid Build Coastguard Worker mod(x) 2080*da0073e9SAndroid Build Coastguard Worker if torch._dynamo.config.inline_inbuilt_nn_modules: 2081*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 2) 2082*da0073e9SAndroid Build Coastguard Worker else: 2083*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 2 * num_submodules) 2084*da0073e9SAndroid Build Coastguard Worker 2085*da0073e9SAndroid Build Coastguard Worker def test_recursion(self): 2086*da0073e9SAndroid Build Coastguard Worker mod = 
MockModule() 2087*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 2088*da0073e9SAndroid Build Coastguard Worker opt_mod = torch._dynamo.optimize(cnt)(mod) 2089*da0073e9SAndroid Build Coastguard Worker 2090*da0073e9SAndroid Build Coastguard Worker for _ in range(5): 2091*da0073e9SAndroid Build Coastguard Worker opt_mod = torch._dynamo.optimize(cnt)(opt_mod) 2092*da0073e9SAndroid Build Coastguard Worker opt_mod(torch.randn(10, 10)) 2093*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 1) 2094*da0073e9SAndroid Build Coastguard Worker 2095*da0073e9SAndroid Build Coastguard Worker def test_composition(self): 2096*da0073e9SAndroid Build Coastguard Worker class InnerModule(torch.nn.Module): 2097*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2098*da0073e9SAndroid Build Coastguard Worker super().__init__() 2099*da0073e9SAndroid Build Coastguard Worker self.relu = torch.nn.ReLU() 2100*da0073e9SAndroid Build Coastguard Worker 2101*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2102*da0073e9SAndroid Build Coastguard Worker return self.relu(torch.sin(x)) 2103*da0073e9SAndroid Build Coastguard Worker 2104*da0073e9SAndroid Build Coastguard Worker opt_inner_mod = InnerModule() 2105*da0073e9SAndroid Build Coastguard Worker 2106*da0073e9SAndroid Build Coastguard Worker class OuterModule(torch.nn.Module): 2107*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2108*da0073e9SAndroid Build Coastguard Worker super().__init__() 2109*da0073e9SAndroid Build Coastguard Worker self.mod = opt_inner_mod 2110*da0073e9SAndroid Build Coastguard Worker 2111*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2112*da0073e9SAndroid Build Coastguard Worker return self.mod(torch.cos(x)) 2113*da0073e9SAndroid Build Coastguard Worker 2114*da0073e9SAndroid Build Coastguard Worker outer_mod = OuterModule() 2115*da0073e9SAndroid Build Coastguard Worker cnt = 
torch._dynamo.testing.CompileCounter() 2116*da0073e9SAndroid Build Coastguard Worker opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod) 2117*da0073e9SAndroid Build Coastguard Worker 2118*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4) 2119*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule) 2120*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x))) 2121*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 1) 2122*da0073e9SAndroid Build Coastguard Worker 2123*da0073e9SAndroid Build Coastguard Worker def test_composition_with_opt_mod(self): 2124*da0073e9SAndroid Build Coastguard Worker class InnerModule(torch.nn.Module): 2125*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2126*da0073e9SAndroid Build Coastguard Worker super().__init__() 2127*da0073e9SAndroid Build Coastguard Worker self.relu = torch.nn.ReLU() 2128*da0073e9SAndroid Build Coastguard Worker 2129*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2130*da0073e9SAndroid Build Coastguard Worker return self.relu(torch.sin(x)) 2131*da0073e9SAndroid Build Coastguard Worker 2132*da0073e9SAndroid Build Coastguard Worker inner_mod = InnerModule() 2133*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 2134*da0073e9SAndroid Build Coastguard Worker opt_inner_mod = torch._dynamo.optimize(cnt)(inner_mod) 2135*da0073e9SAndroid Build Coastguard Worker 2136*da0073e9SAndroid Build Coastguard Worker class OuterModule(torch.nn.Module): 2137*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2138*da0073e9SAndroid Build Coastguard Worker super().__init__() 2139*da0073e9SAndroid Build Coastguard Worker self.mod = opt_inner_mod 2140*da0073e9SAndroid Build Coastguard Worker 2141*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2142*da0073e9SAndroid Build Coastguard Worker 
return self.mod(torch.cos(x)) 2143*da0073e9SAndroid Build Coastguard Worker 2144*da0073e9SAndroid Build Coastguard Worker outer_mod = OuterModule() 2145*da0073e9SAndroid Build Coastguard Worker opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod) 2146*da0073e9SAndroid Build Coastguard Worker 2147*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4) 2148*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule) 2149*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x))) 2150*da0073e9SAndroid Build Coastguard Worker # There will be a graph break for the inner mod being OptimizedModule 2151*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 2) 2152*da0073e9SAndroid Build Coastguard Worker 2153*da0073e9SAndroid Build Coastguard Worker def test_module_patch(self): 2154*da0073e9SAndroid Build Coastguard Worker mod = ModulePatch1() 2155*da0073e9SAndroid Build Coastguard Worker mod.forward = types.MethodType(ModulePatch2.forward, mod) 2156*da0073e9SAndroid Build Coastguard Worker 2157*da0073e9SAndroid Build Coastguard Worker def fn(x): 2158*da0073e9SAndroid Build Coastguard Worker return mod(x) 2159*da0073e9SAndroid Build Coastguard Worker 2160*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 2161*da0073e9SAndroid Build Coastguard Worker torch.allclose( 2162*da0073e9SAndroid Build Coastguard Worker torch._dynamo.optimize("eager", nopython=True)(fn)(torch.ones(10)), 2163*da0073e9SAndroid Build Coastguard Worker torch.zeros(1), 2164*da0073e9SAndroid Build Coastguard Worker ) 2165*da0073e9SAndroid Build Coastguard Worker ) 2166*da0073e9SAndroid Build Coastguard Worker 2167*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", False) 2168*da0073e9SAndroid Build Coastguard Worker def test_hooks_outer(self): 2169*da0073e9SAndroid Build Coastguard Worker class 
TestModule(torch.nn.Module): 2170*da0073e9SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor) -> torch.Tensor: 2171*da0073e9SAndroid Build Coastguard Worker return 2 * x + 1 2172*da0073e9SAndroid Build Coastguard Worker 2173*da0073e9SAndroid Build Coastguard Worker m = TestModule() 2174*da0073e9SAndroid Build Coastguard Worker 2175*da0073e9SAndroid Build Coastguard Worker def forward_hook( 2176*da0073e9SAndroid Build Coastguard Worker module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor 2177*da0073e9SAndroid Build Coastguard Worker ) -> torch.Tensor: 2178*da0073e9SAndroid Build Coastguard Worker return 2 * output + 1 2179*da0073e9SAndroid Build Coastguard Worker 2180*da0073e9SAndroid Build Coastguard Worker handle = m.register_forward_hook(forward_hook) 2181*da0073e9SAndroid Build Coastguard Worker inp = torch.tensor(1.0, requires_grad=True) 2182*da0073e9SAndroid Build Coastguard Worker 2183*da0073e9SAndroid Build Coastguard Worker failure_reason = None 2184*da0073e9SAndroid Build Coastguard Worker 2185*da0073e9SAndroid Build Coastguard Worker def guard_fail_fn(failure): 2186*da0073e9SAndroid Build Coastguard Worker nonlocal failure_reason 2187*da0073e9SAndroid Build Coastguard Worker failure_reason = failure[0] 2188*da0073e9SAndroid Build Coastguard Worker 2189*da0073e9SAndroid Build Coastguard Worker compiled_m = torch._dynamo.optimize( 2190*da0073e9SAndroid Build Coastguard Worker guard_fail_fn=guard_fail_fn, backend="eager" 2191*da0073e9SAndroid Build Coastguard Worker )(m) 2192*da0073e9SAndroid Build Coastguard Worker 2193*da0073e9SAndroid Build Coastguard Worker self.assertEqual(compiled_m(inp), m(inp)) 2194*da0073e9SAndroid Build Coastguard Worker self.assertEqual(compiled_m(inp).item(), 7) 2195*da0073e9SAndroid Build Coastguard Worker self.assertTrue(failure_reason is None) 2196*da0073e9SAndroid Build Coastguard Worker 2197*da0073e9SAndroid Build Coastguard Worker # what if we remove our hook? we should recompile? 
2198*da0073e9SAndroid Build Coastguard Worker handle.remove() 2199*da0073e9SAndroid Build Coastguard Worker self.assertEqual(compiled_m(inp), m(inp)) 2200*da0073e9SAndroid Build Coastguard Worker self.assertEqual(compiled_m(inp).item(), 3) 2201*da0073e9SAndroid Build Coastguard Worker # self.assertTrue(failure_reason == "hook") 2202*da0073e9SAndroid Build Coastguard Worker 2203*da0073e9SAndroid Build Coastguard Worker """ 2204*da0073e9SAndroid Build Coastguard Worker Summary: 2205*da0073e9SAndroid Build Coastguard Worker - removing a hook doesn't fail a guard, because we weren't compiling the hook 2206*da0073e9SAndroid Build Coastguard Worker (at least into the same graph) as forward in the first place! We do correctly 2207*da0073e9SAndroid Build Coastguard Worker omit calling the removed hook, but since this hook is a post forward hook, 2208*da0073e9SAndroid Build Coastguard Worker the 'RETURN' from forward is breaking the graph. 2209*da0073e9SAndroid Build Coastguard Worker 2210*da0073e9SAndroid Build Coastguard Worker Why is 'forward' the entrypoint to an InstructionTranslator, after I changed 2211*da0073e9SAndroid Build Coastguard Worker the eval_frame entrypoint to Module.__call__? 
2212*da0073e9SAndroid Build Coastguard Worker """ 2213*da0073e9SAndroid Build Coastguard Worker 2214*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", False) 2215*da0073e9SAndroid Build Coastguard Worker def test_hooks_inner(self): 2216*da0073e9SAndroid Build Coastguard Worker class TestModule(torch.nn.Module): 2217*da0073e9SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor) -> torch.Tensor: 2218*da0073e9SAndroid Build Coastguard Worker return 2 * x + 1 2219*da0073e9SAndroid Build Coastguard Worker 2220*da0073e9SAndroid Build Coastguard Worker m = TestModule() 2221*da0073e9SAndroid Build Coastguard Worker 2222*da0073e9SAndroid Build Coastguard Worker def forward_hook( 2223*da0073e9SAndroid Build Coastguard Worker module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor 2224*da0073e9SAndroid Build Coastguard Worker ) -> torch.Tensor: 2225*da0073e9SAndroid Build Coastguard Worker return 2 * output + 1 2226*da0073e9SAndroid Build Coastguard Worker 2227*da0073e9SAndroid Build Coastguard Worker handle = m.register_forward_hook(forward_hook) 2228*da0073e9SAndroid Build Coastguard Worker 2229*da0073e9SAndroid Build Coastguard Worker def outer_func(tensor): 2230*da0073e9SAndroid Build Coastguard Worker x = tensor * 2 + 1 2231*da0073e9SAndroid Build Coastguard Worker y = m(x) 2232*da0073e9SAndroid Build Coastguard Worker return y 2233*da0073e9SAndroid Build Coastguard Worker 2234*da0073e9SAndroid Build Coastguard Worker inp = torch.tensor(1.0, requires_grad=True) 2235*da0073e9SAndroid Build Coastguard Worker 2236*da0073e9SAndroid Build Coastguard Worker failure_reason = None 2237*da0073e9SAndroid Build Coastguard Worker 2238*da0073e9SAndroid Build Coastguard Worker def guard_fail_fn(failure): 2239*da0073e9SAndroid Build Coastguard Worker nonlocal failure_reason 2240*da0073e9SAndroid Build Coastguard Worker failure_reason = failure[0] 2241*da0073e9SAndroid Build Coastguard Worker 
2242*da0073e9SAndroid Build Coastguard Worker cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager") 2243*da0073e9SAndroid Build Coastguard Worker compiled_func = torch._dynamo.optimize( 2244*da0073e9SAndroid Build Coastguard Worker guard_fail_fn=guard_fail_fn, 2245*da0073e9SAndroid Build Coastguard Worker backend=cc, 2246*da0073e9SAndroid Build Coastguard Worker )(outer_func) 2247*da0073e9SAndroid Build Coastguard Worker 2248*da0073e9SAndroid Build Coastguard Worker self.assertEqual(compiled_func(inp), outer_func(inp)) 2249*da0073e9SAndroid Build Coastguard Worker self.assertEqual(compiled_func(inp).item(), 15) 2250*da0073e9SAndroid Build Coastguard Worker 2251*da0073e9SAndroid Build Coastguard Worker # We are compiling 1 big graph for all 3 functions including the hook. 2252*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cc.frame_count, 1) 2253*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cc.op_count, 6) 2254*da0073e9SAndroid Build Coastguard Worker 2255*da0073e9SAndroid Build Coastguard Worker # If we remove the hook, we should recompile 2256*da0073e9SAndroid Build Coastguard Worker handle.remove() 2257*da0073e9SAndroid Build Coastguard Worker self.assertEqual(compiled_func(inp), outer_func(inp)) 2258*da0073e9SAndroid Build Coastguard Worker self.assertEqual(compiled_func(inp).item(), 7) 2259*da0073e9SAndroid Build Coastguard Worker self.assertTrue("forward_hooks" in failure_reason) 2260*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cc.frame_count, 1 + 1) 2261*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cc.op_count, 6 + 4) 2262*da0073e9SAndroid Build Coastguard Worker 2263*da0073e9SAndroid Build Coastguard Worker # what if instead of removing, we alter our hook? 
2264*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 2265*da0073e9SAndroid Build Coastguard Worker m = TestModule() 2266*da0073e9SAndroid Build Coastguard Worker handle = m.register_forward_hook(forward_hook) 2267*da0073e9SAndroid Build Coastguard Worker failure_reason = None 2268*da0073e9SAndroid Build Coastguard Worker self.assertEqual(compiled_func(inp), outer_func(inp)) 2269*da0073e9SAndroid Build Coastguard Worker self.assertEqual(compiled_func(inp).item(), 15) 2270*da0073e9SAndroid Build Coastguard Worker 2271*da0073e9SAndroid Build Coastguard Worker def new_forward_hook( 2272*da0073e9SAndroid Build Coastguard Worker module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor 2273*da0073e9SAndroid Build Coastguard Worker ) -> torch.Tensor: 2274*da0073e9SAndroid Build Coastguard Worker return 2 * output + 2 2275*da0073e9SAndroid Build Coastguard Worker 2276*da0073e9SAndroid Build Coastguard Worker m._forward_hooks[handle.id] = new_forward_hook 2277*da0073e9SAndroid Build Coastguard Worker self.assertEqual(compiled_func(inp), outer_func(inp)) 2278*da0073e9SAndroid Build Coastguard Worker self.assertEqual(compiled_func(inp).item(), 16) 2279*da0073e9SAndroid Build Coastguard Worker self.assertRegex(failure_reason, r"___check_obj_id\(L\['m'\]._forward_hooks") 2280*da0073e9SAndroid Build Coastguard Worker 2281*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "guard_nn_modules", False) 2282*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", True) 2283*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "inline_inbuilt_nn_modules", False) 2284*da0073e9SAndroid Build Coastguard Worker def test_hooks_skip_guards(self): 2285*da0073e9SAndroid Build Coastguard Worker class TestModule(torch.nn.Module): 2286*da0073e9SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor) -> torch.Tensor: 2287*da0073e9SAndroid Build Coastguard 
Worker return 2 * x + 1 2288*da0073e9SAndroid Build Coastguard Worker 2289*da0073e9SAndroid Build Coastguard Worker m = TestModule() 2290*da0073e9SAndroid Build Coastguard Worker 2291*da0073e9SAndroid Build Coastguard Worker def forward_hook( 2292*da0073e9SAndroid Build Coastguard Worker module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor 2293*da0073e9SAndroid Build Coastguard Worker ) -> torch.Tensor: 2294*da0073e9SAndroid Build Coastguard Worker return 2 * output + 1 2295*da0073e9SAndroid Build Coastguard Worker 2296*da0073e9SAndroid Build Coastguard Worker handle = m.register_forward_hook(forward_hook) 2297*da0073e9SAndroid Build Coastguard Worker 2298*da0073e9SAndroid Build Coastguard Worker def outer_func(tensor): 2299*da0073e9SAndroid Build Coastguard Worker x = tensor * 2 + 1 2300*da0073e9SAndroid Build Coastguard Worker y = m(x) 2301*da0073e9SAndroid Build Coastguard Worker return y 2302*da0073e9SAndroid Build Coastguard Worker 2303*da0073e9SAndroid Build Coastguard Worker inp = torch.tensor(1.0, requires_grad=True) 2304*da0073e9SAndroid Build Coastguard Worker 2305*da0073e9SAndroid Build Coastguard Worker failure_reason = None 2306*da0073e9SAndroid Build Coastguard Worker 2307*da0073e9SAndroid Build Coastguard Worker def guard_fail_fn(failure): 2308*da0073e9SAndroid Build Coastguard Worker nonlocal failure_reason 2309*da0073e9SAndroid Build Coastguard Worker failure_reason = failure[0] 2310*da0073e9SAndroid Build Coastguard Worker 2311*da0073e9SAndroid Build Coastguard Worker cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager") 2312*da0073e9SAndroid Build Coastguard Worker compiled_func = torch._dynamo.optimize( 2313*da0073e9SAndroid Build Coastguard Worker guard_fail_fn=guard_fail_fn, 2314*da0073e9SAndroid Build Coastguard Worker backend=cc, 2315*da0073e9SAndroid Build Coastguard Worker )(outer_func) 2316*da0073e9SAndroid Build Coastguard Worker 2317*da0073e9SAndroid Build Coastguard Worker m = TestModule() 
2318*da0073e9SAndroid Build Coastguard Worker handle = m.register_forward_hook(forward_hook) 2319*da0073e9SAndroid Build Coastguard Worker failure_reason = None 2320*da0073e9SAndroid Build Coastguard Worker self.assertEqual(compiled_func(inp), outer_func(inp)) 2321*da0073e9SAndroid Build Coastguard Worker self.assertEqual(compiled_func(inp).item(), 15) 2322*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cc.frame_count, 1) 2323*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cc.op_count, 6) 2324*da0073e9SAndroid Build Coastguard Worker 2325*da0073e9SAndroid Build Coastguard Worker # if we remove the hook, dynamo shouldn't notice 2326*da0073e9SAndroid Build Coastguard Worker handle.remove() 2327*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(compiled_func(inp), outer_func(inp)) 2328*da0073e9SAndroid Build Coastguard Worker self.assertEqual(compiled_func(inp).item(), 15) 2329*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cc.frame_count, 1) 2330*da0073e9SAndroid Build Coastguard Worker 2331*da0073e9SAndroid Build Coastguard Worker def _forward_hook_test_helper(self, model): 2332*da0073e9SAndroid Build Coastguard Worker forward_handles = {} 2333*da0073e9SAndroid Build Coastguard Worker compiled_activations = {} 2334*da0073e9SAndroid Build Coastguard Worker eager_activations = {} 2335*da0073e9SAndroid Build Coastguard Worker activations = None 2336*da0073e9SAndroid Build Coastguard Worker 2337*da0073e9SAndroid Build Coastguard Worker def save_activations(name, mod, inp, out): 2338*da0073e9SAndroid Build Coastguard Worker activations[name] = inp 2339*da0073e9SAndroid Build Coastguard Worker 2340*da0073e9SAndroid Build Coastguard Worker for name, module in model.named_modules(): 2341*da0073e9SAndroid Build Coastguard Worker forward_handles[name] = module.register_forward_hook( 2342*da0073e9SAndroid Build Coastguard Worker partial(save_activations, name) 2343*da0073e9SAndroid Build Coastguard Worker ) 2344*da0073e9SAndroid 
        # (continuation of the forward-hook helper: run the compiled model, then the
        # eager model, recording which layers fired hooks, and compare the key sets)
        compiled_model = torch.compile(model, backend="aot_eager")

        activations = compiled_activations
        for i in range(2):
            # second iteration is key, hooks would have fired during aot trace
            # on first iter
            compiled_activations.clear()
            x = torch.randn((20, 10))
            pred = compiled_model(x)
            loss = pred.sum()
            loss.backward()

        activations = eager_activations
        for i in range(2):
            # second iteration is key, hooks would have fired during aot trace
            # on first iter
            eager_activations.clear()
            x = torch.randn((20, 10))
            pred = model(x)
            loss = pred.sum()
            loss.backward()

        print(f"Recorded Layers: {compiled_activations.keys()}\n\n")
        print(f"Expected Layers: {eager_activations.keys()}")

        # The compiled run must fire hooks on exactly the same set of layers as eager.
        self.assertTrue(compiled_activations.keys() == eager_activations.keys())
        self.assertTrue(activations.keys() == forward_handles.keys())

    def test_hooks_allowed_modules(self):
        """Forward hooks on "allowed" (builtin) submodules still fire under compile."""
        # this test shouldn't care whether hook guards are enabled or not
        class ToyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.net = torch.nn.Sequential(
                    *[torch.nn.Linear(10, 10000), torch.nn.ReLU()]
                    + [torch.nn.Linear(10000, 5), torch.nn.ReLU()]
                )

            def forward(self, x):
                return self.net(x)

        model = ToyModel()
        self._forward_hook_test_helper(model)

    def test_hooks_allowed_modules_compiles(self):
        """Hooks on every named submodule fire and the whole model compiles as one frame."""
        class ToyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.net = torch.nn.Sequential(
                    *[torch.nn.Linear(10, 10000), torch.nn.ReLU()]
                    + [torch.nn.Linear(10000, 5), torch.nn.ReLU()]
                )

            def forward(self, x):
                return self.net(x)

        model = ToyModel()
        activations = []

        def save_activations(mod, inp, out):
            # Records hook *inputs*; only the call count is asserted below.
            activations.append(inp)

        for name, module in model.named_modules():
            module.register_forward_hook(save_activations)

        cnt = torch._dynamo.testing.CompileCounter()
        model = torch._dynamo.optimize(cnt, nopython=True)(model)
        for i in range(2):
            # second iteration is key, hooks would have fired during aot trace
            # on first iter
            activations.clear()
            x = torch.randn((20, 10))
            pred = model(x)
            loss = pred.sum()
            loss.backward()
            # 6 hook firings per forward: ToyModel, net, and the 4 layers inside net.
            self.assertEqual(len(activations), 6)
        # nopython=True + a single frame proves no graph breaks from the hooks.
        self.assertEqual(cnt.frame_count, 1)

    def test_hooks_allowed_modules_compiles_self_contained(self):
        """Output-modifying hooks compose correctly when the model is called twice in forward."""
        class ToyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.net = torch.nn.Sequential(
                    *[torch.nn.Linear(10, 10000), torch.nn.ReLU()]
                    + [torch.nn.Linear(10000, 5), torch.nn.ReLU()]
                )

            def forward(self, x):
                return self.net(x) * self.net(x)

        model = ToyModel()
        forward_handles = {}

        def output_modifying_hook(mod, inp, out):
            # Hooks may replace the output; the wrapped module must return this value.
            return 2 * out + 1

        for name, module in model.named_modules():
            forward_handles[name] = module.register_forward_hook(output_modifying_hook)

        cnt = torch._dynamo.testing.CompileCounter()

        x = torch.randn((20, 10))
        pred_eager = model(x)
        loss_eager = pred_eager.sum()
        # NOTE(review): Tensor.backward() returns None, so eager_loss_bwd/loss_bwd
        # below compare None to None — this only checks backward() doesn't raise.
        eager_loss_bwd = loss_eager.backward()

        model = torch._dynamo.optimize(cnt, nopython=True)(model)
        pred = model(x)

        loss = pred.sum()
        loss_bwd = loss.backward()

        self.assertEqual(eager_loss_bwd, loss_bwd)
        # Two frames: forward and the compiled backward region.
        self.assertEqual(cnt.frame_count, 2)

        # Ndim change, recompile
        pred = model(torch.randn([10, 10, 10]))
        self.assertEqual(cnt.frame_count, 4)

        # Stable
        pred = model(torch.randn([10, 10, 10]))
        self.assertEqual(cnt.frame_count, 4)

    def test_dunder_call_explicitly(self):
        """Hooks must fire even when a submodule is invoked via explicit __call__."""
        # hooks should be triggered if explicit calling `__call__`
        class ToyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10000)

            def forward(self, x):
                return self.linear.__call__(x)

        model = ToyModel()
        self._forward_hook_test_helper(model)

    def test_backward_hooks(self):
        # this test shouldn't care whether hook guards are enabled or not
        class CustomLinear(torch.nn.Module):
            # not an 'allowed module', so should not graph-break
            def __init__(self, a, b):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.randn(a, b))

            def forward(self, x):
                return torch.mm(x, self.weight)

        class ToyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.net = torch.nn.Sequential(
                    *[CustomLinear(10, 10)]
                    + [CustomLinear(10, 10000)]
                    + [CustomLinear(10000, 5)]
                )

            def forward(self, x):
                return self.net(x)

        model = ToyModel()
        backward_hook_handles = {}
        pre_backward_hook_handles = {}

        grad_sizes = {}

        def backward_hook(name, mod, grad_inp, grad_out):
            # NOTE(review): these are generator expressions, never consumed — only
            # the dict *keys* are asserted below, the shape values are unused.
            grad_sizes[name] = (
                (gi.shape for gi in grad_inp),
                (go.shape for go in grad_out),
            )
            return None

        pre_grad_sizes = {}

        def backward_pre_hook(name, mod, grad_out):
            pre_grad_sizes[name] = (go.shape for go in grad_out)
            return None

        # Register both full-backward and full-backward-pre hooks on every submodule,
        # binding the module name via partial so firing can be tracked per-module.
        for name, module in model.named_modules():
            backward_hook_handles[name] = module.register_full_backward_hook(
                partial(backward_hook, name)
            )

            pre_backward_hook_handles[name] = module.register_full_backward_pre_hook(
                partial(backward_pre_hook, name)
            )

        model = torch.compile(model, backend="aot_eager")

        for i in range(2):
            # second iteration is key, hooks would have fired during aot trace
            # on first iter
            x = torch.randn((20, 10))
            pred = model(x)
            loss = pred.sum()
            loss.backward()

        # Every module that had a hook registered must have fired it.
        self.assertTrue(grad_sizes.keys() == backward_hook_handles.keys())
        self.assertTrue(pre_grad_sizes.keys() == pre_backward_hook_handles.keys())

    def test_udo_instance_method_as_hook(self):
        """A bound method of a user-defined object works as a forward-pre hook under fullgraph compile."""
        class CustomClass:
            def __init__(self, module):
                self.module = module
                # Bound instance method registered as a pre-hook; with_kwargs means
                # it receives (module, args, kwargs) and may rewrite both.
                self.handle = self.module.register_forward_pre_hook(
                    self.func1, prepend=True, with_kwargs=True
                )

            def func1(self, module, args, kwargs):
                # Shifts the first positional arg by 1 before the module runs.
                return (args[0] + 1,), kwargs

            def __call__(self, x):
                return self.module(x)

        class ToyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x * x

        model = ToyModel()
        x = torch.zeros((3, 4))
        obj = CustomClass(model)
        out = torch.compile(obj, fullgraph=True)(x)
        # Pre-hook shifted x by 1, so the squared result is (x + 1) ** 2.
        self.assertEqual(out, (x + 1) * (x + 1))

    def test_module_dict_iter_name(self):
        """Iterating a ModuleDict directly (keys) inside forward compiles to one frame."""
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.activations = torch.nn.ModuleDict(
                    [["lrelu", torch.nn.LeakyReLU()], ["prelu", torch.nn.PReLU()]]
                )

            def forward(self, x):
                for activation_name in self.activations:
                    x = self.activations[activation_name](x)
                return x

        cnt = torch._dynamo.testing.CompileCounter()
        # Eager
        eager_res = MyModule()(torch.ones(10, 10))

        # Compile
        optim_res = torch._dynamo.optimize(cnt)(MyModule())(torch.ones(10, 10))
        self.assertEqual(eager_res, optim_res)
        self.assertEqual(cnt.frame_count, 1)

    def test_module_dict_iter_keys(self):
        """Iterating ModuleDict.keys() inside forward compiles to one frame."""
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.activations = torch.nn.ModuleDict(
                    [["lrelu", torch.nn.LeakyReLU()], ["prelu", torch.nn.PReLU()]]
                )

            def forward(self, x):
                for activation_name in self.activations.keys():
                    x = self.activations[activation_name](x)
                return x

        cnt = torch._dynamo.testing.CompileCounter()
        # Eager
        eager_res = MyModule()(torch.ones(10, 10))

        # Compile
        optim_res = torch._dynamo.optimize(cnt)(MyModule())(torch.ones(10, 10))
        self.assertEqual(eager_res, optim_res)
        self.assertEqual(cnt.frame_count, 1)

    def test_module_setattr(self):
        """Attribute writes on a module inside a fullgraph-compiled fn are applied."""
        models = torch.nn.Sequential(torch.nn.Linear(3, 3))
        models[0].abc = False

        def run():
            # Mutation traced by dynamo; must be visible on the real module afterwards.
            models[0].abc = True
            x = torch.randn(1, 3)
            return models(x)

        run = torch.compile(run, fullgraph=True)
        run()
        self.assertTrue(models[0].abc)

    def test_assign_does_not_exist(self):
        """Assigning a brand-new attribute in forward is replayed onto the module."""
        class MyModule(torch.nn.Module):
            def forward(self, x):
                # 'text_encoding' does not exist before this assignment.
                self.text_encoding = x + 1
                return self.text_encoding

        mod = MyModule()
        out = torch.compile(mod, fullgraph=True)(torch.randn(10))
        # Identity check: the attribute must be the very tensor that was returned.
        assert mod.text_encoding is out

    def test_module_dict_iter_values(self):
        """Iterating ModuleDict.values() inside forward compiles to one frame."""
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.activations = torch.nn.ModuleDict(
                    [["lrelu", torch.nn.LeakyReLU()], ["prelu", torch.nn.PReLU()]]
                )

            def forward(self, x):
                for activation in self.activations.values():
                    x = activation(x)
                return x

        cnt = torch._dynamo.testing.CompileCounter()
        # Eager
        eager_res = MyModule()(torch.ones(10, 10))

        # Compile
        optim_res = torch._dynamo.optimize(cnt)(MyModule())(torch.ones(10, 10))
        self.assertEqual(eager_res, optim_res)
        self.assertEqual(cnt.frame_count, 1)

    def test_unspecialized_seq(self):
        """Mutating a submodule's .training flag inside the compiled fn matches eager."""
        models = torch.nn.Sequential(torch.nn.Linear(3, 3))

        def fn(x):
            models[0].training = False
            return models(x)

        opt_fn = torch._dynamo.optimize("eager")(fn)
        x = torch.randn(1, 3)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_no_op_assignment(self):
        """A self-assignment via .to() should not demote the buffer from static input."""
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # Plain tensor attribute, deliberately not a registered buffer.
                self.buffer = torch.rand([4])

            def forward(self, x):
                # should be a no-op, but causes dynamo to lose the static input
                x = x + 1
                self.buffer = self.buffer.to(x)
                return self.buffer + x

        compiles_without_buffers = 0

        def debug_compile(gm, *args, **kwargs):
            # Counts compilations whose captured graph carries no buffers;
            # `len(...) == 0` is a bool, added as 0/1.
            nonlocal compiles_without_buffers
            compiles_without_buffers += len(list(gm.buffers())) == 0
            return gm

        @torch.compile(backend=debug_compile)
        def foo(mod, x):
            return mod(x)

        mod = Mod()
        foo(mod, torch.rand([4]))
        # Expected counts differ by config: inlined nn modules lift params/buffers
        # to graph inputs, so the graph itself holds no buffers.
        if torch._dynamo.config.inline_inbuilt_nn_modules:
            self.assertEqual(compiles_without_buffers, 1)
        else:
            self.assertEqual(compiles_without_buffers, 0)

        # dtype change of the input forces a recompile.
        foo(mod, torch.rand([4], dtype=torch.half))
        if torch._dynamo.config.inline_inbuilt_nn_modules:
            self.assertEqual(compiles_without_buffers, 2)
        else:
            self.assertEqual(compiles_without_buffers, 1)
        class Mod2(Mod):
            # Custom __setattr__ (even a trivial pass-through) is not supported by
            # dynamo's attribute-mutation tracking, forcing extra compilations.
            def __setattr__(self, name, value):
                return super().__setattr__(name, value)

        foo(Mod2(), torch.rand([4]))
        # causes two compilations, bc unimplemented custom setattr
        self.assertTrue(compiles_without_buffers >= 2)

    def test_unspec_non_inlinable_module(self):
        """A module dynamo cannot inline still produces eager-equivalent results."""
        # UnspecNonInlinableModule is defined elsewhere in this file.
        mod = UnspecNonInlinableModule()
        opt_fn = torch._dynamo.optimize("eager")(mod)
        x = torch.randn(100)
        actual = opt_fn(x)
        expected = mod(x)
        self.assertEqual(actual, expected)

    @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
    def test_mark_static_previously_seen_tensor(self):
        # This test verifies that dynamo will mark
        # the buffers/params of a module as static
        # even if this param was previously seen
        # (ex. as a different input)
        num_compiles = 0

        def debug_compiler(gm, _):
            # Backend that inspects the captured graph: the buffer placeholder
            # ("l_b_") must be tagged as an unguarded static input.
            nonlocal num_compiles
            num_compiles += 1

            input_nodes = [
                n for n in gm.graph.nodes if n.op == "placeholder" and n.name == "l_b_"
            ]

            self.assertGreater(len(input_nodes), 0)
            for input_node in input_nodes:
                self.assertEqual(
                    input_node.meta["tensor_dict"]["_dynamo_static_input_type"],
                    "unguarded",
                )

            return gm

        class TestModule(torch.nn.Module):
            def __init__(self, buf) -> None:
                super().__init__()
                # Changing this one to nn.Buffer fails because `nn.Buffer` does a .detach()
                # so the value in self.tx.output.side_effects will no longer evaluate to True
                self.register_buffer("buf", buf)

            def forward(self, x):
                return self.buf * x
        @torch._dynamo.optimize(backend=debug_compiler)
        def fn(x, b, mod):
            # `b` is the same tensor as mod's buffer — seen first as a plain input,
            # then again as a module buffer; the buffer must still be marked static.
            z = b + 1
            return z * mod(x)

        buf = torch.ones(2, 2)
        inp = torch.ones(2)
        mod = TestModule(buf)
        fn(inp, buf, mod)
        self.assertEqual(num_compiles, 1)

    @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
    def test_mark_static_nn_module_tensor(self):
        # This test verifies that dynamo will mark
        # the nn module tensor attributes as static
        num_compiles = 0

        def debug_compiler(gm, _):
            # Backend asserting the module tensor attribute's placeholder
            # ("l_mod_buf") is tagged as an unguarded static input.
            nonlocal num_compiles
            num_compiles += 1

            input_nodes = [
                n
                for n in gm.graph.nodes
                if n.op == "placeholder" and n.name == "l_mod_buf"
            ]

            self.assertGreater(len(input_nodes), 0)
            for input_node in input_nodes:
                self.assertEqual(
                    input_node.meta["tensor_dict"]["_dynamo_static_input_type"],
                    "unguarded",
                )

            return gm

        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # Plain tensor attribute (not a registered buffer/parameter).
                self.buf = torch.ones(2, 2)

            def forward(self, x):
                return self.buf * x

        mod = TestModule()

        @torch._dynamo.optimize(backend=debug_compiler)
        def fn(x):
            return x * mod(x)

        inp = torch.ones(2)
        fn(inp)
        self.assertEqual(num_compiles, 1)

    @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
    @torch._inductor.config.patch("freezing", True)
    @torch.no_grad()
    def test_mark_static_with_freezing(self):
        # This test verifies that dynamo will
        # add buffers/params as attributes of the
        # graph w/ guards if freezing is enabled
        num_compiles = 0

        def debug_compiler(gm, _):
            nonlocal num_compiles
            num_compiles += 1

            # With freezing, the buffer must NOT appear as a graph input...
            input_nodes = [
                n for n in gm.graph.nodes if n.op == "placeholder" and n.name == "l_b_"
            ]
            self.assertEqual(len(input_nodes), 0)
            # ...but as an attribute (buffer) baked into the graph module.
            self.assertEqual(len(list(gm.buffers())), 1)
            return gm

        class TestModule(torch.nn.Module):
            def __init__(self, buf) -> None:
                super().__init__()
                self.buf = torch.nn.Buffer(buf)

            def forward(self, x):
                return self.buf * x

        @torch._dynamo.optimize(backend=debug_compiler)
        def fn(x, mod):
            return mod(x)

        buf = torch.ones(2, 2)
        inp = torch.ones(2)
        mod = TestModule(buf)
        fn(inp, mod)
        self.assertEqual(num_compiles, 1)
        # Swapping the buffer object must invalidate the frozen graph and recompile.
        mod.buf = torch.rand_like(buf)
        fn(inp, mod)
        self.assertEqual(num_compiles, 2)

    @patch.object(torch._dynamo.config, "guard_nn_modules", True)
    def test_guard_on_torch_nn_modules(self):
        """Non-tensor module attributes are guarded: mutation triggers recompile."""
        # https://github.com/pytorch/pytorch/issues/110048

        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)
                self.multiplier = 10

            def forward(self, x):
                return self.linear(x) * self.multiplier

        mod = MockModule()

        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt)
        def generate(x, c):
            return mod(x) + c

        # Two frames total: one per distinct value of the int `c` (0 and 1);
        # repeated calls with the same values must not recompile.
        for _ in range(0, 10):
            generate(torch.randn(10, 10), 0)
            generate(torch.randn(10, 10), 1)
        self.assertEqual(cnt.frame_count, 2)

        # Ensure that modification in user module causes recompile
        mod.multiplier = 11
        generate(torch.randn(10, 10), 0)
        self.assertEqual(cnt.frame_count, 3)

    def test_setattr_on_compiled_module(self):
        """Setattr on a module that is itself wrapped by torch.compile."""
        # https://github.com/pytorch/pytorch/issues/114844

        class ReplayMutation(torch.nn.Module):
            def __init__(self, inp_size, out_size, inner_size):
                super().__init__()
                self.Linear1 = torch.nn.Linear(inp_size, inner_size)
                self.Linear2 = torch.nn.Linear(inner_size, out_size)
                # Captures the intermediate activation on every forward.
                self.x = None

            def forward(self, inp):
                res = self.Linear1(inp)
                self.x = res
                return self.Linear2(res)

        N, D_in, H, D_out, inner = 2, 2, 2, 2, 4
        model = ReplayMutation(D_in, H, inner)
        model2 =
copy.deepcopy(model) 2906*da0073e9SAndroid Build Coastguard Worker input = torch.ones(N, D_in) 2907*da0073e9SAndroid Build Coastguard Worker 2908*da0073e9SAndroid Build Coastguard Worker # Keep some intermediate value in model.x 2909*da0073e9SAndroid Build Coastguard Worker model.x = torch.tensor([[100, 100, 100, 100], [200, 200, 200, 200]]) 2910*da0073e9SAndroid Build Coastguard Worker model(input) 2911*da0073e9SAndroid Build Coastguard Worker 2912*da0073e9SAndroid Build Coastguard Worker compiled_model = torch.compile(model2, backend="eager") 2913*da0073e9SAndroid Build Coastguard Worker compiled_model.x = torch.tensor([[100, 100, 100, 100], [200, 200, 200, 200]]) 2914*da0073e9SAndroid Build Coastguard Worker compiled_model(input) 2915*da0073e9SAndroid Build Coastguard Worker 2916*da0073e9SAndroid Build Coastguard Worker self.assertEqual(model.x, compiled_model.x) 2917*da0073e9SAndroid Build Coastguard Worker 2918*da0073e9SAndroid Build Coastguard Worker def test_globals_change_in_other_file(self): 2919*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager", fullgraph=True) 2920*da0073e9SAndroid Build Coastguard Worker def fn(x): 2921*da0073e9SAndroid Build Coastguard Worker update_global() 2922*da0073e9SAndroid Build Coastguard Worker a = test_functions.update_global(x) 2923*da0073e9SAndroid Build Coastguard Worker # Ensure that the updated global values are read 2924*da0073e9SAndroid Build Coastguard Worker return x * a * (_variable + _variable1 + test_functions._variable) 2925*da0073e9SAndroid Build Coastguard Worker 2926*da0073e9SAndroid Build Coastguard Worker res = fn(torch.ones(10)) 2927*da0073e9SAndroid Build Coastguard Worker self.assertEqual(_variable, 1) 2928*da0073e9SAndroid Build Coastguard Worker self.assertEqual(_variable1, 1) 2929*da0073e9SAndroid Build Coastguard Worker # Ensure that the reconstructed bytecode updates the global value in the 2930*da0073e9SAndroid Build Coastguard Worker # other file. 
2931*da0073e9SAndroid Build Coastguard Worker self.assertEqual(test_functions._variable, 1) 2932*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, 3 * torch.ones(10)) 2933*da0073e9SAndroid Build Coastguard Worker 2934*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 2935*da0073e9SAndroid Build Coastguard Worker "inductor" not in torch._dynamo.list_backends(), 2936*da0073e9SAndroid Build Coastguard Worker "inductor backend is not available", 2937*da0073e9SAndroid Build Coastguard Worker ) 2938*da0073e9SAndroid Build Coastguard Worker def test_save_and_load_inductor(self): 2939*da0073e9SAndroid Build Coastguard Worker mod = MockModule() 2940*da0073e9SAndroid Build Coastguard Worker opt_mod = torch.compile(mod, backend="inductor") 2941*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(10, 10) 2942*da0073e9SAndroid Build Coastguard Worker opt_mod(inp) 2943*da0073e9SAndroid Build Coastguard Worker 2944*da0073e9SAndroid Build Coastguard Worker with tempfile.TemporaryDirectory() as tmpdirname: 2945*da0073e9SAndroid Build Coastguard Worker torch.save(opt_mod, os.path.join(tmpdirname, "model.pt")) 2946*da0073e9SAndroid Build Coastguard Worker loaded_model = torch.load(os.path.join(tmpdirname, "model.pt")) 2947*da0073e9SAndroid Build Coastguard Worker loaded_model(inp) 2948*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same_two_models(loaded_model, mod, [inp])) 2949*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same_two_models(loaded_model, opt_mod, [inp])) 2950*da0073e9SAndroid Build Coastguard Worker 2951*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() # force recompiles 2952*da0073e9SAndroid Build Coastguard Worker torch._inductor.metrics.generated_kernel_count = 0 2953*da0073e9SAndroid Build Coastguard Worker loaded_model(inp) 2954*da0073e9SAndroid Build Coastguard Worker self.assertGreater(torch._inductor.metrics.generated_kernel_count, 0) 2955*da0073e9SAndroid Build Coastguard Worker 
    def test_save_and_load_all_backends(self):
        # For every registered dynamo backend: a compiled module saved and
        # re-loaded must behave the same as the original w.r.t. whether it
        # produces inductor kernels (i.e. save/load preserves "compiledness").
        mod = MockModule()
        inp = torch.randn(10, 10)
        for backend in torch._dynamo.list_backends():
            try:
                opt_mod = torch.compile(mod, backend=backend)
                with tempfile.TemporaryDirectory() as tmpdirname:
                    torch.save(opt_mod, os.path.join(tmpdirname, "model.pt"))
                    loaded_model = torch.load(os.path.join(tmpdirname, "model.pt"))
                    torch._dynamo.reset()  # force recompiles
                    torch._inductor.metrics.generated_kernel_count = 0
                    opt_mod(inp)
                    # True iff this backend produced no inductor kernels.
                    opt_success = torch._inductor.metrics.generated_kernel_count == 0
                    torch._dynamo.reset()  # force recompiles
                    torch._inductor.metrics.generated_kernel_count = 0
                    loaded_model(inp)
                    loaded_success = torch._inductor.metrics.generated_kernel_count == 0
                    # Original and round-tripped module must agree.
                    self.assertEqual(opt_success, loaded_success)
            except torch._dynamo.exc.BackendCompilerFailed:
                # Deliberate best-effort: skip backends that cannot compile
                # this module in the current environment.
                pass

    def test_monkeypatching_forward(self):
        # Replacing a module's bound `forward` after compilation must
        # invalidate the cached code and trigger a recompile.
        class FakeModule(torch.nn.Module):
            def forward(self, x):
                return torch.sin(x)

        class MyModule(torch.nn.Module):
            def __init__(self, x):
                super().__init__()

            def forward(self, x):
                return torch.cos(x)

        def helper():
            torch._dynamo.reset()
            mod = MyModule(3)

            def fn(x):
                return mod(x)

            cnt = torch._dynamo.testing.CompileCounter()
            opt_fn = torch._dynamo.optimize(cnt)(fn)
            x = torch.randn(10)

            # Second call must hit the cache: still one frame.
            opt_fn(x)
            opt_fn(x)
            self.assertEqual(cnt.frame_count, 1)

            # Monkeypatch forward
            mod.forward = types.MethodType(FakeModule.forward, mod)
            ref = fn(x)
            res = opt_fn(x)
            # Compiled result must track the new forward (sin, not cos)...
            self.assertEqual(ref, res)
            # ...via exactly one additional compilation.
            self.assertEqual(cnt.frame_count, 2)

        # Exercise both the default path and the inlined-nn-module path.
        helper()
        with torch._dynamo.config.patch(inline_inbuilt_nn_modules=True):
            helper()

    def test_user_defined_nn_module_dynamic(self):
        # User-defined Conv2d subclass: conv strides must be treated as
        # static, so each distinct stride forces its own compilation.
        class Conv2d(torch.nn.Conv2d):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)

            def forward(self, x):
                # Call the functional conv directly so the stride/padding
                # attributes are read inside the traced region.
                x = torch.nn.functional.conv2d(
                    x,
                    self.weight,
                    self.bias,
                    self.stride,
                    self.padding,
                    self.dilation,
                    self.groups,
                )
                return x

        cnts = torch._dynamo.testing.CompileCounter()
        mod1 = Conv2d(64, 64, kernel_size=(2, 2), stride=(1, 1))
        mod2 = Conv2d(64, 64, kernel_size=(2, 2), stride=(2, 2))
        mod3 = Conv2d(64, 64, kernel_size=(2, 2), stride=(3, 3))

        opt_mod1 = torch.compile(mod1, backend=cnts, fullgraph=True)
        opt_mod2 = torch.compile(mod2, backend=cnts, fullgraph=True)
        opt_mod3 = torch.compile(mod3, backend=cnts, fullgraph=True)

        x = torch.randn(1, 64, 64, 64)
        opt_mod1(x)
        opt_mod2(x)
        opt_mod3(x)

        # Must be 3 compilations. If not marked static there would be 2, because strides would be converted to symints.
        self.assertEqual(cnts.frame_count, 3)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()