# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# flake8: noqa: F401
import functools
import inspect
import os
import random
import unittest
from typing import Callable, Dict, Optional, Tuple, Type
from unittest import skip, skipUnless

import executorch.exir as exir

import executorch.exir.control_flow as control_flow

# @manual=//executorch/extension/pytree:pybindings
import executorch.extension.pytree as pytree
import torch

from executorch.exir import (
    CaptureConfig,
    EdgeCompileConfig,
    ExecutorchBackendConfig,
    memory,
)
from executorch.exir.dynamic_shape import DynamicMemoryPlanningMode
from executorch.exir.emit import emit_program
from executorch.exir.pass_manager import PassManager
from executorch.exir.passes import (
    DebugPass,
    MemoryPlanningPass,
    to_scratch_op_pass,
    ToOutVarPass,
)
from executorch.exir.print_program import pretty_print, print_program
from executorch.exir.tensor import make_tensor_value, TensorSpec
from executorch.exir.tests.control_flow_models import (
    FTCondBasic,
    FTCondDynShape,
    FTMapBasic,
    FTMapDynShape,
)
from executorch.exir.tests.dynamic_shape_models import BatchNormModel

from executorch.exir.tests.transformer import Transformer
from functorch.experimental.control_flow import cond

# Records which pybindings flavour could be imported: "lean" when
# portable_lib is available, "aten" when aten_lib is available. Exactly one
# of the two must succeed (enforced by the asserts below).
kernel_mode = None  # either aten mode or lean mode

try:
    from executorch.extension.pybindings.portable_lib import (
        _load_bundled_program_from_buffer,
        _load_for_executorch_from_buffer,
        _load_for_executorch_from_bundled_program,
    )
except ImportError as import_err:
    print(import_err)
else:
    kernel_mode = "lean"

try:
    from executorch.extension.pybindings.aten_lib import (
        _load_bundled_program_from_buffer,
        _load_for_executorch_from_buffer,
        _load_for_executorch_from_bundled_program,
    )
except ImportError as import_err:
    print(import_err)
else:
    # Having both flavours importable at once is not supported.
    assert kernel_mode is None
    kernel_mode = "aten"

# At least one flavour must have been importable.
assert kernel_mode is not None

is_aten_mode = kernel_mode == "aten"
is_lean_mode = kernel_mode == "lean"

from torch import nn
from torch.utils import _pytree as torch_pytree

from .exported_module import ExportedModule
# Integer parsed from the RUN_SKIPPED environment variable; 0 when unset.
RUN_SKIPPED = int(os.environ.get("RUN_SKIPPED", "0"))


class ModuleBasic(nn.Module):
    """Reduces sin(x) to its single largest element."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.sin(x).max()

    def get_random_inputs(self):
        return (torch.randn(100),)


class ModuleOpsReturnMulti(nn.Module):
    """Uses an op with multiple outputs (topk) and consumes only the values."""

    def __init__(self):
        super().__init__()

    def forward(self, a, b):
        # topk yields (values, indices); only the values feed the result.
        top_values, _top_indices = torch.topk(a, 3)
        return top_values * 2 + b

    def get_random_inputs(self):
        # `b` has length 3 so it lines up with the k=3 values from topk.
        return (torch.randn(10), torch.randn(3))


class ModuleAdd(nn.Module):
    """Element-wise sum of two tensors."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return torch.add(x, y)

    def get_random_inputs(self):
        shape = (2, 2)
        return (torch.randn(shape), torch.randn(shape))


class ModuleFloatAddWithAlpha(nn.Module):
    """Computes x + c * y where ``c`` is a Python float passed as ``alpha``."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor, c: float):
        return torch.add(x, y, alpha=c)

    def get_random_inputs(self):
        lhs = torch.randn(2, 2)
        rhs = torch.randn(2, 2)
        alpha = random.random()
        return (lhs, rhs, alpha)


class ModuleIntAddWithAlpha(nn.Module):
    """Computes x + c * y on integer tensors with a Python int ``alpha``."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor, c: int):
        return torch.add(x, y, alpha=c)

    def get_random_inputs(self):
        lhs = torch.randint(0, 10, (2, 2))
        rhs = torch.randint(0, 10, (2, 2))
        alpha = random.randint(0, 10)
        return (lhs, rhs, alpha)


class ModuleContainers(nn.Module):
    """Takes a dict input and returns a dict holding a tuple and a tensor."""

    def __init__(self):
        super().__init__()

    def forward(self, d):
        lhs, rhs = d["a"], d["b"]
        total = torch.add(lhs, rhs)
        return {"inputs": (lhs, rhs), "c": total}

    def get_random_inputs(self):
        tensors = {"a": torch.randn(2, 2), "b": torch.randn(2, 2)}
        return (tensors,)


class ToyModelForMemPlanning(nn.Module):
    """Three chained multiply-then-add steps: out = out * a + b."""

    def __init__(self):
        super().__init__()

    def forward(self, a, b):
        out = a
        for _ in range(3):
            out = out * a + b
        return out

    def get_random_inputs(self):
        return (torch.randn(10), torch.randn(10))
class MemPlanningWithScratchTensor(nn.Module):
    """Two independent 4->2 linear projections whose results are summed."""

    def __init__(self):
        super().__init__()
        # Creation order is significant: parameter initialization consumes
        # the global RNG in this sequence.
        self.linear1 = nn.Linear(4, 2)
        self.linear2 = nn.Linear(4, 2)

    def forward(self, a, b):
        # Left operand is evaluated first, so linear1 still runs before linear2.
        return self.linear1(a) + self.linear2(b)

    def get_random_inputs(self):
        batch = 10
        return (torch.randn(batch, 4), torch.randn(batch, 4))


class ModuleOpsReturnTensorList(nn.Module):
    """Calls an op that returns a list of tensors and keeps the first chunk."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        pieces = torch.ops.aten.tensor_split.sections(x, 3)
        first_piece = pieces[0]
        return first_piece

    def get_random_inputs(self):
        return (torch.randn(100),)
class ModuleReturnInput(nn.Module):
    """Returns its input several times, nested in a tuple, a dict and a list."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        as_dict = {"x": x, "y": x}
        as_list = [x, x, x]
        return (x, x, as_dict, as_list)

    def get_random_inputs(self):
        return (torch.randn(1),)


def _add_repeatedly(x, n):
    """Return ``x`` added to itself via ``n - 1`` chained additions."""
    total = x
    for _ in range(n - 1):
        total = total + x
    return total


class ModuleIfElse(nn.Module):
    """Squares x, branches on a tensor predicate via cond, then squares again.

    The true branch sums x three times, the false branch four times.
    """

    def __init__(self):
        super().__init__()

    def forward(self, c, x):
        x = x * x

        # cond passes the operand tuple (c, x) positionally to each branch.
        def true_branch(c, x):
            return _add_repeatedly(x, 3)

        def false_branch(c, x):
            return _add_repeatedly(x, 4)

        y = cond(c, true_branch, false_branch, (c, x))
        return y * y

    def get_random_inputs(self):
        predicate = torch.randint(2, [1]) == 0
        return (predicate, torch.randn(10))


class ModuleIfElseWithBoolInput(nn.Module):
    """Same computation as ModuleIfElse but with a Python bool predicate."""

    def __init__(self):
        super().__init__()

    def forward(self, c: bool, x: torch.Tensor):
        x = x * x

        def true_branch(c, x):
            return _add_repeatedly(x, 3)

        def false_branch(c, x):
            return _add_repeatedly(x, 4)

        y = cond(c, true_branch, false_branch, (c, x))

        return y * y

    def get_random_inputs(self):
        predicate = random.randint(0, 1) == 0
        return (predicate, torch.randn(10))
class ModuleWhileIf(nn.Module):
    """A cond nested inside a control_flow.while_loop body.

    ``loop_cond`` is True while ``cnt`` is non-zero. ``loop_body`` adds ``cnt``
    to ``accum`` through a cond whose predicate is always True, then decrements
    ``cnt``. The first loop output (``accum``) is returned.
    """

    def __init__(self):
        super().__init__()

    def forward(self, accum, cnt):
        @control_flow.tracing_context(
            inputs=(torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
        )
        def loop_cond(accum, cnt):
            return cnt != torch.zeros([1]).to(dtype=torch.long)

        @control_flow.tracing_context(
            inputs=(torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
        )
        def loop_body(accum, cnt):
            # Numerically the same as `accum + cnt`: the predicate below is
            # always True and true_branch returns cnt. The cond exists to put
            # an if inside the while body.
            @control_flow.tracing_context(
                inputs=(torch.zeros([1]).to(dtype=torch.long),)
            )
            def true_branch(cnt):
                return cnt

            @control_flow.tracing_context(
                inputs=(torch.zeros([1]).to(dtype=torch.long),)
            )
            def false_branch(cnt):
                return torch.zeros([1], dtype=torch.long)

            accum = accum + cond(
                torch.BoolTensor([True]), true_branch, false_branch, (cnt,)
            )
            # 'cnt - 1' does not work yet since the runtime does not expect
            # tensor to be mixed with scalar for sub op.
            return accum, cnt - torch.ones([1]).to(dtype=torch.long)

        y, _ = control_flow.while_loop(
            loop_cond,
            loop_body,
            (accum, cnt),
        )
        return y

    def get_random_inputs(self):
        return (torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))


class ModuleIfWhile(nn.Module):
    """A control_flow.while_loop nested inside the true branch of a cond.

    The cond predicate is always True. The true branch loops while ``cnt`` is
    non-zero, adding ``cnt`` to ``accum`` and decrementing ``cnt``. The false
    branch passes ``(accum, cnt)`` through. Element 0 of the cond result is
    returned.
    """

    def __init__(self):
        super().__init__()

    def forward(self, accum, cnt):
        @control_flow.tracing_context(
            inputs=(torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
        )
        def true_branch(accum, cnt):
            @control_flow.tracing_context(
                inputs=(
                    torch.zeros([1]).to(dtype=torch.long),
                    torch.randint(10, 100, [1]),
                )
            )
            def loop_cond(accum, cnt):
                return cnt != torch.zeros([1]).to(dtype=torch.long)

            @control_flow.tracing_context(
                inputs=(
                    torch.zeros([1]).to(dtype=torch.long),
                    torch.randint(10, 100, [1]),
                )
            )
            def loop_body(accum, cnt):
                return accum + cnt, cnt - torch.ones([1]).to(dtype=torch.long)

            return control_flow.while_loop(loop_cond, loop_body, (accum, cnt))

        @control_flow.tracing_context(
            inputs=(torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
        )
        def false_branch(accum, cnt):
            return accum, cnt

        result = cond(
            torch.BoolTensor([True]), true_branch, false_branch, (accum, cnt)
        )
        return result[0]

    def get_random_inputs(self):
        return (torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))


class ModuleContiguousTensor(nn.Module):
    """A single 8->32 linear layer."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 32)

    def forward(self, arg):
        return self.linear(arg)

    def get_random_inputs(self):
        return (torch.randn(3, 8),)


class ModuleInputDynamicShape(nn.Module):
    """Elementwise ops on a 1-D input whose length varies between 1 and 10."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        for _ in range(4):
            x = x + x
            x = x * x
        return x

    def get_upper_bound_inputs(self):
        # Largest input shape produced by get_random_inputs.
        return (torch.randn(10),)

    def get_random_inputs(self):
        n = random.randint(1, 10)
        return (torch.randn(n),)


class ModuleIntermediateDynamicShape(nn.Module):
    """Produces a data-dependent intermediate shape via torch.nonzero."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = x * x

        # We should use x[torch.nonzero(x)] ideally, but index op is not supported
        # in the runtime so far.
        x = torch.nonzero(x)
        return x + x

    def get_random_inputs(self):
        return (torch.randint(0, 2, (10,), dtype=torch.float),)


def allclose(lhs, rhs, rtol=1e-5, atol=1e-8):
    r"""
    Unlike torch.allclose which only handles Tensor arguments, allclose handles
    list, tuple, dict and nesting of these as well.

    Tuples and lists are interchangeable: only their lengths and elements are
    compared. Dicts must have identical key sets. Any other combination of
    types (including plain Python scalars) raises RuntimeError.
    """
    if isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor):
        return torch.allclose(lhs, rhs, rtol, atol)
    if isinstance(lhs, (tuple, list)) and isinstance(rhs, (tuple, list)):
        return len(lhs) == len(rhs) and all(
            allclose(a, b, rtol, atol) for a, b in zip(lhs, rhs)
        )
    if isinstance(lhs, dict) and isinstance(rhs, dict):
        if set(lhs.keys()) != set(rhs.keys()):
            return False
        return all(allclose(lhs[k], rhs[k], rtol, atol) for k in lhs)
    raise RuntimeError(
        f"Unexpected types: lhs type {type(lhs)}, rhs type {type(rhs)}"
    )


def validate_contiguous_tensors(program):
    """Assert that every tensor value in ``program`` has an identity dim_order.

    Walks each execution plan's ``values`` and checks every
    ``exir.schema.Tensor`` found there.

    Raises:
        AssertionError: if a tensor's dim_order is not ``0, 1, ..., ndim - 1``,
            or if its dim_order length differs from its number of sizes.
    """

    def _is_contiguous_tensor(tensor: exir.schema.Tensor) -> bool:
        """
        Ensure the tensor is pytorch contiguous (torch.memory_format=torch.contiguous)
        since the runtime can not handle non-contiguous tensors so far.
        """
        sizes = tensor.sizes
        dim_order = tensor.dim_order
        assert len(sizes) == len(dim_order)
        # Contiguous means dim_order is exactly 0, 1, ..., ndim - 1.
        return all(axis == dim for axis, dim in enumerate(dim_order))

    for execution_plan in program.execution_plan:
        for value in execution_plan.values:
            tensor = value.val
            if not isinstance(tensor, exir.schema.Tensor):
                continue
            # Report dim_order, the field the check actually inspects. The
            # previous message read `strides`, which the dim_order-based
            # schema Tensor is not expected to have, so a genuine failure
            # surfaced as an AttributeError instead of this message.
            # NOTE(review): confirm constant_buffer_idx against the schema
            # version in use.
            assert _is_contiguous_tensor(tensor), (
                f"Non-contiguous tensor found: size {tensor.sizes} "
                f"dim_order {tensor.dim_order}. "
                f"constant_buffer_idx {tensor.constant_buffer_idx}. "
                f"allocation_info {tensor.allocation_info}."
            )


class BoundMethod(object):
    """Callable that prepends a stored ``instance`` to the wrapped callable's args.

    ``BoundMethod(obj, fn)(*args, **kwargs)`` is equivalent to
    ``fn(obj, *args, **kwargs)``.
    """

    def __init__(self, instance, callable):
        self._instance = instance
        self._callable = callable

    def __call__(self, *args, **kwargs):
        # The bound instance lives in ``_instance`` (set in ``__init__``); it is
        # passed as the first positional argument to the wrapped callable.
        return self._callable(self._instance, *args, **kwargs)


def maketest(
    module_cls: Type[nn.Module],
    niter: int = 10,
    run_executor: bool = True,
    do_tree_flatten: bool = False,
    run_graph_module: bool = True,
    atol: float = 1e-8,
    rtol: float = 1e-5,
    ignore_to_out_var_failure: bool = False,
    allow_non_contiguous_tensor: bool = False,
    method: str = "forward",
    dynamic_memory_planning_mode: DynamicMemoryPlanningMode = DynamicMemoryPlanningMode.UPPER_BOUND,
    capture_config=None,
    verify_graph: Optional[Callable] = None,
) -> Callable[[unittest.TestCase], None]:
    r"""Returns a TestCase method to test the provided module class and method.

    Args:
        module_cls: The subclass of nn.Module to export.
        niter: The number of random input data sets to test with.
        run_executor: Whether to run the model on the executor. We may want to
            skip running a model through the executor since some kernels are
            not implemented.
        do_tree_flatten: Whether to flatten input and unflatten output.
        run_graph_module: Whether to run the traced and transformed GraphModule.
            One may want to skip this if some custom ops do not have an
            implementation in torch.ops but are implemented in the executor.
        atol: Absolute tolerance used in allclose and torch.allclose.
        rtol: Relative tolerance used in allclose and torch.allclose.
        ignore_to_out_var_failure: Whether to ignore the failure when a
            functional op does not have an out variant.
        allow_non_contiguous_tensor: If false, will validate that the emitted
            program only contains contiguous tensors.
        method: The name of the module_cls method to trace.
        dynamic_memory_planning_mode: The dynamic memory planning mode to use.
        capture_config: Capture config forwarded unchanged to
            ``ExportedModule.export``.
        verify_graph: Optional hook, called as
            ``verify_graph(test_case, graph_module)`` with the exported
            program's graph module right after export and before any inputs
            are run.

    Returns:
        A TestCase method that tests the provided module class and method.
    """

    def wrapper(self: unittest.TestCase) -> None:
        """A TestCase method that traces/exports/tests an nn.Module and method."""
        module = ExportedModule.export(
            module_class=module_cls,
            # testend2end only supports modules with single methods defined
            methods=(method,),
            ignore_to_out_var_failure=ignore_to_out_var_failure,
            dynamic_memory_planning_mode=dynamic_memory_planning_mode,
            capture_config=capture_config,
        )
        if verify_graph:
            verify_graph(self, module.exported_program.graph_module)
        print(f"inputs for tracing: {module.trace_inputs}")

        # compare the result between the eager module and graph module
        inputs_list = [module.get_random_inputs() for _ in range(niter)]

        if run_graph_module:
            for inputs in inputs_list:
                with torch.no_grad():
                    # only one method is supported so just grab that single method
                    expected = getattr(module.eager_module, module.methods[0])(*inputs)
                with torch.no_grad():
                    result = module.exported_program.module()(*inputs)
                self.assertTrue(allclose(expected, result, rtol, atol))

        program = module.executorch_program.executorch_program
        pretty_print(program)
        print_program(program, show_meminfo=True, mark_dynamic_shape_tensor=True)
        print(f"mem buffer sizes: {program.execution_plan[0].non_const_buffer_sizes}")
        if not allow_non_contiguous_tensor:
            validate_contiguous_tensors(program)
        # assertGreaterEqual reports the actual length on failure, unlike
        # assertTrue(len(...) >= 2).
        self.assertGreaterEqual(len(program.execution_plan[0].non_const_buffer_sizes), 2)
        # We should not enable the following assertion since for some models
        # that simply return a graph input, no mutable memory should be allocated
        # self.assertTrue(all(s > 0 for s in program.program.execution_plan[0].non_const_buffer_sizes[1:]))

        program.version = 0
        buff = module.executorch_program.buffer
        # Check that the magic version number is in the expected place, and
        # follows the expected pattern.
        self.assertRegex(buff[4:8].decode(errors="replace"), r"^ET[0-9][0-9]$")

        if run_executor:
            print("Running on the runtime")
            executorch_module = _load_for_executorch_from_buffer(buff)
            # compare the result between eager module and executor
            for idx, inputs in enumerate(inputs_list):
                with torch.no_grad():
                    expected = getattr(module.eager_module, method)(*inputs)

                if do_tree_flatten:
                    # NOTE(review): `*inputs` unpacks the argument tuple into
                    # tree_flatten, which only works when the method takes a
                    # single (container) argument -- confirm before relying on
                    # this for multi-argument modules.
                    # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
                    flatten_inputs, _ = pytree.tree_flatten(*inputs)
                    executorch_result = executorch_module.forward([*flatten_inputs])
                    # pyre-fixme[16]: Module `pytree` has no attribute `TreeSpec`.
                    executorch_result_unflatten = pytree.TreeSpec.from_str(
                        program.execution_plan[0].container_meta_type.encoded_out_str
                    ).tree_unflatten(executorch_result)
                    actual = executorch_result_unflatten
                else:
                    # Only the first output of the executor result is compared.
                    actual = executorch_module.forward(inputs)[0]
                is_close = allclose(expected, actual, rtol, atol)
                if not is_close:
                    print(f"Fail for {idx}th inputs: {inputs}")
                    print(f"expected result: {expected}")
                    print(f"actual result: {actual}")
                self.assertTrue(is_close)

    return wrapper


class E2ETest(unittest.TestCase):
    r"""
    When adding a new unittest, call maketest(ModuleName) if possible since
    maketest handles all the boilerplate. Ideally, we only need to define
    a new nn.Module and add one line calling maketest for new end2end test cases.
    """

    # don't run the model through the executor because aten::sin.out is not
    # defined in the executor currently.
    #
    # aten::max.default does not have an out variant. Thus we need to set
    # ignore_to_out_var_failure to be True.
    def test_basic(self):
        maketest(ModuleBasic, run_executor=False, ignore_to_out_var_failure=True)(self)

    # Make sure we can handle ops that return multiple values. E.g. topk
    # At one time we could not properly set up TensorSpec for an Fx node
    # returning multiple tensors
    #
    # don't run the model through the executor because aten::topk.values is
    # not defined in the executor currently
    def test_ops_return_multi(self):
        maketest(ModuleOpsReturnMulti, run_executor=False)(self)

    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) Fix this in both fbcode and oss")
    def test_mem_planning_toy_model(self):
        maketest(
            ToyModelForMemPlanning,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
            ),
        )(self)

    # TODO: add ops implementations and turn on 'run_executor'
    def test_mem_planning_scratch_tensor(self):
        maketest(
            MemPlanningWithScratchTensor,
            run_graph_module=False,
            run_executor=False,
            atol=1e-5,
        )(self)

    def test_executorch_forward(self):
        maketest(ModuleAdd)(self)

    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) Fix this in both fbcode and oss")
    def test_containers(self):
        maketest(
            ModuleContainers,
            do_tree_flatten=True,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
            ),
        )(self)

    # cannot run the graph module since the out variant with a tensor list out
    # argument returns None rather than a tensor list.
    #
    # cannot run in the executor since the kernel for tensor splitting is not
    # implemented.
    def test_ops_return_tensorlist(self):
        maketest(ModuleOpsReturnTensorList, run_graph_module=False, run_executor=False)(
            self
        )

    # Failed to produce a graph during tracing w/ dynamo because there are no torch ops
    # test_return_input = maketest(ModuleReturnInput, do_tree_flatten=True)

    # cannot run this on the executor because of the following missing ops:
    # aten::select_copy.int_out, aten::eq.Scalar_out
    # TODO(zhxchen17) re-enable these tests.
    # test_control_flow_cond = maketest(ControlFlowCond, run_executor=False)
    # fail to trace with functionalization enabled
    # test_ifelse = maketest(ModuleIfElse)

    # fail to trace with functionalization enabled
    # Fail with error: Missing out variants: {'aten::select', 'aten::_shape_as_tensor', 'aten::tensor_split'}
    # TODO(zhxchen17) re-enable these tests.
    # test_while_0 = maketest(
    #     ControlFlowWhile,
    #     ignore_to_out_var_failure=True,
    #     run_executor=False,
    # )

    # test_while = maketest(ModuleWhile)

    # test_while_if = maketest(ModuleWhileIf)
    # test_if_while = maketest(ModuleIfWhile)
    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) This fails on OSS macos job")
    def test_contiguous_tensor(self):
        maketest(ModuleContiguousTensor, run_executor=False)(self)


class DynamicModelE2ETest(unittest.TestCase):
    """
    End2end tests for dynamic models. By dynamic models we mean models with
    control flow or dynamic shape.
    """

    @skip("Revisit when unbacked symint is ready")
    def test_intermediate_dynamic_shape(self):
        maketest(
            ModuleIntermediateDynamicShape,
            run_graph_module=False,
            allow_non_contiguous_tensor=True,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
            ),
        )(self)

    # TODO(shunting): some non constant tensors for transformer are non-contiguous.
    # Ignore for now. Will debug more.
    # NOTE: can not run on runtime since missing these ops: P535190636
    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) This fails on OSS macos job")
    def test_transformer_encode(self):
        maketest(
            Transformer,
            method="encode",
            allow_non_contiguous_tensor=True,
            run_executor=False,
        )(self)

    # basic test for functorch torch.ops.higher_order.cond
    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) Fix this in both fbcode and oss")
    def test_ft_cond_basic(self):
        maketest(
            FTCondBasic,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
                enable_functionalization=False,  # TODO enable functionalization
            ),
        )(self)

    @skipUnless(RUN_SKIPPED, "Emitter is not ready yet")
    def test_ft_map_basic(self):
        maketest(
            FTMapBasic,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
                enable_functionalization=False,  # TODO enable functionalization
            ),
        )(self)

    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) Fix this in both fbcode and oss")
    def test_ft_cond_dynshape(self):
        maketest(
            FTCondDynShape,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
                enable_functionalization=False,  # TODO enable functionalization
            ),
        )(self)

    @skipUnless(RUN_SKIPPED, "Emitter is not ready yet")
    def test_ft_map_dynshape(self):
        maketest(
            FTMapDynShape,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
                enable_functionalization=False,  # TODO enable functionalization
            ),
        )(self)

    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) Fix this in both fbcode and oss")
    def test_batch_norm(self):
        maketest(
            BatchNormModel,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
            ),
            verify_graph=BatchNormModel.verify_graph,
            # TODO: lean mode does not have native_batch_norm.out implemented
            # run this on aten mode.
            run_executor=is_aten_mode,
        )(self)
