# Owner(s): ["module: fx"]

import builtins
import contextlib
import copy
import functools
import inspect
import math
import numbers
import io
import operator
import os
import pickle
import sys
import torch
import traceback
import typing
import types
import warnings
import unittest
from math import sqrt
from functorch.experimental import control_flow
from torch.multiprocessing import Process
from torch.testing import FileCheck
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests
import torch.utils._pytree as pytree
import torch.fx._pytree as fx_pytree
from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Interpreter, Tracer, Transformer, Graph, wrap, PH, CodeGen
from torch.fx.node import Target, Argument, _format_arg
from torch.fx.passes import shape_prop
from torch.fx.immutable_collections import immutable_dict, immutable_list
from torch.fx.experimental.rewriter import RewritingTracer
from torch.fx.operator_schemas import get_signature_for_torch_op
from copy import deepcopy
from collections import namedtuple

from torch.fx.proxy import TraceError
from torch.fx._compatibility import _BACK_COMPAT_OBJECTS, _MARKED_WITH_COMPATIBILITY
from torch.fx._symbolic_trace import PHBase, PHWithMeta
# The following test classes are imported but unused here (hence noqa: F401).
# NOTE(review): presumably so they are collected and run as part of this file — confirm.
from fx.test_subgraph_rewriter import TestSubgraphRewriter  # noqa: F401
from fx.test_dce_pass import TestDCE  # noqa: F401
from fx.test_fx_const_fold import TestConstFold  # noqa: F401
from fx.test_fx_param_shape_control_flow import TestConstParamShapeInControlFlow  # noqa: F401
from fx.test_pass_infra import TestPassManager  # noqa: F401
from fx.test_common_passes import TestCommonPass  # noqa: F401
from fx.test_cse_pass import TestCSEPass  # noqa: F401
from fx.test_matcher_utils import TestMatcher  # noqa: F401
from fx.test_source_matcher_utils import TestSourceMatcher  # noqa: F401

from fx.test_gradual_type import AnnotationsTest  # noqa: F401
from fx.test_gradual_type import TypeCheckerTest  # noqa: F401
from typing import Any, Callable, Dict, NamedTuple, List, Optional, Set, Tuple, Union
from torch.testing._internal.common_utils import (
    IS_FBCODE,
    IS_MACOS,
    IS_WINDOWS,
    find_library_location,
    run_tests,
    skipIfTorchDynamo,
)
from torch.testing._internal.jit_utils import JitTestCase

from fx.named_tup import MyNamedTup

# torchvision is optional: record whether it is importable and expose a skip
# decorator for tests that need it.
try:
    from torchvision import models as torchvision_models
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
from torch.testing._internal.common_quantization import skipIfNoDynamoSupport

# Minimal module: adds 3.0 to the input and applies ReLU.
class SimpleTest(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 3.0)

# Plain Python helper (not a torch op); test_custom_import inserts it into a
# hand-built Graph via graph.call_function.
def a_non_torch_leaf(a, b):
    return a + b

# Used for test_autowrap_function. Autowrapped functions need to be global
def fx_int(x: float) -> int:
    return int(x)

def fx_int_x2(x: float) -> int:
    return int(x) * 2

# used in test_pytree. It's all the way out here because pickling a GraphModule
# that uses Point errors out if Point is local to the function
Point = namedtuple('Point', ['x', 'y'])

# Test wrap() passing both a function name as well as a function
# directly
def a_lifted_leaf(a, b):
    return a[0] + a[1] + b

# NOTE(review): torch.fx.wrap is documented as needing to be called at the top
# level of a module, so these registrations must stay here.
wrap('a_lifted_leaf')
# Test wrapping twice doesn't break anything
wrap('a_lifted_leaf')

def a_lifted_leaf2(a, b):
    return a[0] + a[1] + b

# Same as above, but registering the function object instead of its name.
wrap(a_lifted_leaf2)
# Register the builtins `len` and `getattr` by name with torch.fx.wrap so calls
# to them during tracing of this module's code are recorded as graph calls.
# NOTE(review): wrap() must be called at module top level — do not move these.
wrap('len')

wrap('getattr')

# Wrapped function taking a positional and a keyword-only argument; both are
# read via attribute access (.x / .y).
def wrapped_named_tup(p1, *, p2):
    return p1.x + p2.y

wrap(wrapped_named_tup)

# wrap() also works as a decorator (it returns the function it was given).
@wrap
def wrapped_via_decorator(a):
    return a + 1

# Registered by name before the function below is defined; registration only
# records the name, and the function is looked up later.
wrap('wrapped_with_submodule')

def wrapped_with_submodule(x: torch.Tensor, batchnorm1d: torch.nn.BatchNorm1d):
    return batchnorm1d(x)

# Pass-through decorator. functools.wraps copies __name__ etc. onto the wrapper,
# so wrapping the decorated function below registers `wrapped_decorated_fn`
# rather than `wrapper_inside_decorator`.
def my_decorator(f):
    @functools.wraps(f)
    def wrapper_inside_decorator(*args, **kwargs):
        return f(*args, **kwargs)
    return wrapper_inside_decorator

@wrap
@my_decorator
def wrapped_decorated_fn(x):
    return x

# Keep references to the original function objects; the wrap tests use
# assertIs against these to check the module globals are the same objects
# after tracing. ("lifed" is a historical misspelling; the tests use these names.)
real_wrapped_via_decorator = wrapped_via_decorator
real_a_lifed_leaf = a_lifted_leaf
real_a_lifed_leaf2 = a_lifted_leaf2
# Module-level alias of math.sqrt.
# NOTE(review): presumably referenced by later tests in this file — not visible here.
_sqrt = sqrt

# Registered by name at module top level (required by torch.fx.wrap).
wrap('wrapper_fn')

# NOTE(review): `torch.foo` does not look like a real torch op; since this
# function is wrapped, tracing should record a call to it without running the
# body — presumably intentional, confirm.
def wrapper_fn(x):
    return torch.foo(x)

class Pair(NamedTuple):
    x : torch.Tensor
    y : torch.Tensor

    def _custom_fx_repr_fn(self) -> str:
        """Render this Pair for FX-generated code, formatting each field with _format_arg."""
        return f"Pair(x={_format_arg(self.x)}, y={_format_arg(self.y)})"

# for testing pytrees
class Foo:  # noqa: B209
    def __init__(self, a, b):
        self.a = a
        self.b = b

# Module that doubles its input by adding it to itself.
class Add(torch.nn.Module):
    def forward(self, x):
        return x + x

# Wrapped function whose only effect is printing; marked with
# torch.fx.has_side_effect.
# NOTE(review): presumably so passes such as dead-code elimination keep calls to
# it even though its result is unused — confirm.
@torch.fx.has_side_effect
@torch.fx.wrap
def side_effect_func(x: torch.Tensor):
    print(x)

class TestFX(JitTestCase):
    def setUp(self):
        super().setUp()
        # Checking for mutable operations while tracing is feature flagged.
        # Enable it in testing but not by default.
        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
        torch.fx.proxy.TracerBase.check_mutable_operations = True
        # Register the restore immediately. unittest runs cleanups even when
        # setUp raises (e.g. load_library below fails), whereas tearDown is
        # skipped in that case and the global flag would leak into other tests.
        self.addCleanup(self._restore_check_mutable_operations)

        if not (IS_FBCODE or IS_WINDOWS or IS_MACOS):
            lib_file_path = find_library_location('libtorchbind_test.so')
            torch.ops.load_library(str(lib_file_path))

    def _restore_check_mutable_operations(self):
        """Put back the TracerBase.check_mutable_operations value saved in setUp."""
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag

    def checkGraphModule(self, m: torch.nn.Module, args, kwargs=None):
        """Check that an nn.Module's results match the GraphModule version
        for a given set of args/kwargs.

        Runs ``m`` eagerly, symbolically traces it, lints the resulting graph,
        runs the GraphModule with the same inputs and asserts equal outputs.
        """
        kwargs = kwargs if kwargs else {}
        ref_outs = m(*args, **kwargs)
        gm = symbolic_trace(m)
        gm.graph.lint()
        test_outs = gm(*args, **kwargs)
        self.assertEqual(ref_outs, test_outs)

    def test_graph_module(self):
        class MySub(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w = torch.nn.Parameter(torch.rand(4, 3))

            def forward(self, x):
                return self.w + x

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin = torch.nn.Linear(4, 3)
                self.sub_mod = MySub()
                self.w = torch.nn.Parameter(torch.rand(3))

            def forward(self, A, B, c):
                t = torch.sigmoid(A) + self.lin(c)
                return self.sub_mod(t.data + self.w + t + 1 - A + B // A + -A + A.add(B, alpha=3))

        m = MyModule()
        gm = symbolic_trace(m)

        # Scripting must succeed on the traced module; the result is not inspected.
        ms = torch.jit.script(gm)

        # Tuple unpacking of a traced op and returning a tuple.
        class M2(torch.nn.Module):
            def forward(self, A):
                m, idx = torch.max(A, 0)
                return m + 1, idx + 1

        m2 = M2()
        gm2 = symbolic_trace(m2)

        # Signature mixing defaults, *args, keyword-only args and **kwargs;
        # only checks that tracing does not raise.
        class T(torch.nn.Module):

            def forward(self, A, b=4, *args, c=5, **kwargs):
                x = A + 1 + args[0] + kwargs['3']
                return x

        t = T()
        symbolic_trace(t)

        # test for issue described at https://github.com/pytorch/pytorch/issues/63883
        class M3(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x)

        m3 = M3()
        gm3 = symbolic_trace(m3)
        new_instance = gm3.__new__(type(gm3))
        new_instance.__init__(gm3, gm3.graph)

        x = torch.randn(5, 3)
        torch.testing.assert_close(new_instance(x), torch.relu(x))

    def test_informative_co_filename(self):
        # The generated forward's code object should name this test file.
        class MyModule(torch.nn.Module):
            def forward(self, a):
                return a * 2

        gm = symbolic_trace(MyModule())
        self.assertIn(os.path.basename(__file__), gm.forward.__code__.co_filename)

    def test_custom_import(self):
        # A hand-built graph calling a non-torch global function must still
        # produce runnable code.
        graph = torch.fx.Graph()
        a = graph.placeholder('x')
        b = graph.placeholder('y')
        c = graph.call_function(a_non_torch_leaf, (a, b))
        d = graph.call_function(torch.sin, (c,))
        graph.output(d)
        gm = GraphModule(torch.nn.Module(), graph)
        x, y = torch.rand(1), torch.rand(1)
        self.assertEqual(torch.sin(x + y), gm(x, y))

    def test_args_kwargs(self):
        class T(torch.nn.Module):
            def forward(self, *args, **kwargs):
                x = args[0] + kwargs['foo']
                return x

        t = T()
        self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)})

    def test_varargs_concrete(self):
        # *args can be traced when concrete_args supplies a placeholder (PH)
        # for each positional argument.
        class T(torch.nn.Module):
            def forward(self, *args, **kwargs):
                x = args[0] + args[1]
                return x

        args = (torch.rand(1), torch.rand(1))

        t = T()
        ref_outs = t(*args)
        gm = symbolic_trace(t, concrete_args=(torch.fx.PH, torch.fx.PH))
        gm.graph.lint()
        test_outs = gm(*args)
        self.assertEqual(ref_outs, test_outs)

    def test_args_kwargs_no_self(self):
        # forward without an explicit `self` is rejected by the tracer.
        class T(torch.nn.Module):
            def forward(*args, **kwargs):  # noqa: B902
                self = args[0]
                return torch.relu(args[1])

        t = T()
        with self.assertRaisesRegex(RuntimeError, r'cannot be part of \*args expansion'):
            self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)})
Build Coastguard Worker 302*da0073e9SAndroid Build Coastguard Worker def test_fx_shifts(self): 303*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.nn.Module): 304*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 305*da0073e9SAndroid Build Coastguard Worker return x << 3, x >> 3 306*da0073e9SAndroid Build Coastguard Worker 307*da0073e9SAndroid Build Coastguard Worker input = torch.LongTensor(10).random_(0, 1024) 308*da0073e9SAndroid Build Coastguard Worker 309*da0073e9SAndroid Build Coastguard Worker m = MyModule() 310*da0073e9SAndroid Build Coastguard Worker self.checkGraphModule(m, (input,)) 311*da0073e9SAndroid Build Coastguard Worker 312*da0073e9SAndroid Build Coastguard Worker def test_fx_and_or(self): 313*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.nn.Module): 314*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 315*da0073e9SAndroid Build Coastguard Worker return x & x, x | x 316*da0073e9SAndroid Build Coastguard Worker 317*da0073e9SAndroid Build Coastguard Worker input = torch.LongTensor(10).random_(0, 1024) 318*da0073e9SAndroid Build Coastguard Worker 319*da0073e9SAndroid Build Coastguard Worker m = MyModule() 320*da0073e9SAndroid Build Coastguard Worker self.checkGraphModule(m, (input,)) 321*da0073e9SAndroid Build Coastguard Worker 322*da0073e9SAndroid Build Coastguard Worker def test_dict(self): 323*da0073e9SAndroid Build Coastguard Worker class MyDictMod(torch.nn.Module): 324*da0073e9SAndroid Build Coastguard Worker def forward(self, d): 325*da0073e9SAndroid Build Coastguard Worker return d['3'].relu(), {'4' : d['3'].neg()} 326*da0073e9SAndroid Build Coastguard Worker 327*da0073e9SAndroid Build Coastguard Worker input_dict = {'3': torch.rand(3, 4)} 328*da0073e9SAndroid Build Coastguard Worker m = MyDictMod() 329*da0073e9SAndroid Build Coastguard Worker 330*da0073e9SAndroid Build Coastguard Worker self.checkGraphModule(m, (input_dict,)) 331*da0073e9SAndroid Build Coastguard Worker 
332*da0073e9SAndroid Build Coastguard Worker def test_matmul_tracing(self): 333*da0073e9SAndroid Build Coastguard Worker const = torch.randn(3) 334*da0073e9SAndroid Build Coastguard Worker 335*da0073e9SAndroid Build Coastguard Worker def matmul_f(x): 336*da0073e9SAndroid Build Coastguard Worker return x @ const 337*da0073e9SAndroid Build Coastguard Worker 338*da0073e9SAndroid Build Coastguard Worker mod = symbolic_trace(matmul_f) 339*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(3) 340*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mod(inp), matmul_f(inp)) 341*da0073e9SAndroid Build Coastguard Worker 342*da0073e9SAndroid Build Coastguard Worker def rmatmul_f(x): 343*da0073e9SAndroid Build Coastguard Worker return const @ x 344*da0073e9SAndroid Build Coastguard Worker 345*da0073e9SAndroid Build Coastguard Worker mod = symbolic_trace(rmatmul_f) 346*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(3) 347*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mod(inp), rmatmul_f(inp)) 348*da0073e9SAndroid Build Coastguard Worker 349*da0073e9SAndroid Build Coastguard Worker @skipIfNoDynamoSupport 350*da0073e9SAndroid Build Coastguard Worker def test_control_flow_tracing(self): 351*da0073e9SAndroid Build Coastguard Worker def true(x, y): 352*da0073e9SAndroid Build Coastguard Worker return x + y 353*da0073e9SAndroid Build Coastguard Worker 354*da0073e9SAndroid Build Coastguard Worker def false(x, y): 355*da0073e9SAndroid Build Coastguard Worker return x - y 356*da0073e9SAndroid Build Coastguard Worker 357*da0073e9SAndroid Build Coastguard Worker def f(x, y): 358*da0073e9SAndroid Build Coastguard Worker x = control_flow.cond(x[0] == 0, true, false, [x, y]) 359*da0073e9SAndroid Build Coastguard Worker 360*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"Expected pred to be bool or tensor, but got Proxy\(eq\)"): 361*da0073e9SAndroid Build Coastguard Worker _ = symbolic_trace(f) 362*da0073e9SAndroid 
def test_disallow_override(self):
    """A Tracer subclass can veto in-place operations by overriding ``create_node``.

    Three cases are checked: an in-place Tensor method on the input, an
    in-place free function, and an in-place method invoked on a concrete
    tensor with a symbolic (Proxy) argument.
    """
    # Custom delegate to disallow in-place tensor operations
    class NoMutableCallTracer(Tracer):
        def create_node(self, kind : str, target : Union[str, Callable],
                        args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
                        type_expr : Optional[Any] = None) -> Node:
            # Keep the op's name in its own local so the caller-supplied node
            # ``name`` is forwarded intact instead of being overwritten.
            target_name = target if isinstance(target, str) else torch.typename(target)
            # In-place torch ops are spelled with a trailing underscore (add_, log_).
            # ``endswith`` also avoids an IndexError on an empty name.
            if target_name.endswith('_'):
                raise RuntimeError('In-place operations are not supported')
            # Forward ``type_expr`` too, so annotations are not silently dropped.
            return super().create_node(kind, target, args, kwargs, name, type_expr)

    # Test method
    class MyInplaceMod(torch.nn.Module):
        def forward(self, x):
            x.add_(3.0)
            return x

    m = MyInplaceMod()

    with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
        NoMutableCallTracer().trace(m)

    # Test free function
    class MyInplaceMod2(torch.nn.Module):
        def forward(self, x):
            torch.log_(x)
            return x

    m2 = MyInplaceMod2()
    with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
        NoMutableCallTracer().trace(m2)

    # Test symbolic node as an arg: the in-place op is called on a concrete
    # tensor, but its argument is the traced input.
    class MyInplaceMod3(torch.nn.Module):
        def forward(self, x):
            y = torch.ones(3, 4)
            y.add_(x)
            return x

    m3 = MyInplaceMod3()
    with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
        NoMutableCallTracer().trace(m3)

def test_leaf_module(self):
    """With no leaf modules, submodules are traced through, so no ``call_module`` nodes remain."""
    # Custom delegate to make it so that there are no leaf modules, everything
    # should get traced through
    class NoLeafModulesTracer(Tracer):
        def is_leaf_module(self, m, qualname):
            return False

    class MyReluMod(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(x)

    mrm = MyReluMod()
    sym = NoLeafModulesTracer().trace(mrm)
    for node in sym.nodes:
        self.assertNotEqual(node.op, 'call_module')
    sym.lint()

def test_wrap(self):
    """A wrapped free function shows up by name in the generated code and still computes correctly."""
    # a_lifted_leaf / real_a_lifed_leaf are module-level helpers defined elsewhere
    # in this file (presumably registered with torch.fx.wrap -- confirm there).
    self.assertEqual(3 + 4 + 5, a_lifted_leaf((3, 4), 5))

    def to_trace(y):
        return a_lifted_leaf((4, y), 3) + a_lifted_leaf((3, 4), 5) + a_lifted_leaf((y, y), y)

    m = symbolic_trace(to_trace)
    self.assertIn('a_lifted_leaf', m.code)
    self.assertEqual(27, m(2))
    # After tracing, the module-level name must refer to the original function again.
    self.assertIs(a_lifted_leaf, real_a_lifed_leaf)

def test_wrap_fn_directly(self):
    """Same as ``test_wrap``, but for ``a_lifted_leaf2`` (wrapped by a different mechanism elsewhere in this file)."""
    self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5))

    def to_trace(y):
        return a_lifted_leaf2((4, y), 3) + a_lifted_leaf2((3, 4), 5) + a_lifted_leaf2((y, y), y)

    m = symbolic_trace(to_trace)
    self.assertIn('a_lifted_leaf2', m.code)
    self.assertEqual(27, m(2))
    # After tracing, the module-level name must refer to the original function again.
    self.assertIs(a_lifted_leaf2, real_a_lifed_leaf2)

def test_wrapped_via_decorator(self):
    """A decorator-wrapped function traces to a call and is left unpatched afterwards."""
    self.assertEqual(wrapped_via_decorator(0), 1)

    def to_trace(y):
        return wrapped_via_decorator(y)

    m = symbolic_trace(to_trace)
    self.assertIn('wrapped_via_decorator', m.code)
    self.assertEqual(m(0), 1)
    self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
    # Tracing must not leave its patch marker on the function object.
    self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))

def test_wrapped_via_decorator_and_transformed(self):
    """The wrapped call survives a ``Transformer`` pass and the function stays unpatched."""
    self.assertEqual(wrapped_via_decorator(0), 1)

    def to_trace(y):
        return wrapped_via_decorator(y)

    m = symbolic_trace(to_trace)
    self.assertIn('wrapped_via_decorator', m.code)
    self.assertEqual(m(0), 1)
    self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
    self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))

    transformed = torch.fx.Transformer(m).transform()
    self.assertIn('wrapped_via_decorator', transformed.code)
    self.assertEqual(transformed(0), 1)
    self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
    self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))

def test_wrap_with_submodule(self):
    """A wrapped function that receives a submodule as an argument traces and runs correctly."""

    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)

        def forward(self, x: torch.Tensor):
            return wrapped_with_submodule(x, self.batchnorm1d)

    m = symbolic_trace(M())

    self.assertIn("wrapped_with_submodule", m.code)

    input = torch.rand(3, 2)
    ref_batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)
    self.assertEqual(ref_batchnorm1d(input), m(input))

def test_wrapped_retrace(self):
    """Re-tracing a GraphModule keeps the wrapped call intact."""
    def to_trace(y):
        return wrapped_via_decorator(y)

    m = symbolic_trace(to_trace)
    self.assertIn('wrapped_via_decorator', m.code)
    self.assertEqual(m(0), 1)

    retraced = symbolic_trace(m)
    self.assertIn('wrapped_via_decorator', retraced.code)
    self.assertEqual(retraced(0), 1)

def test_wrap_decorated_function(self):
    """A wrapped function that carries another decorator still traces to a call."""
    def to_trace(y):
        return wrapped_decorated_fn(y)

    m = symbolic_trace(to_trace)
    self.assertIn('wrapped_decorated_fn', m.code)
    self.assertEqual(m(1), 1)

def test_graph_edit_with_proxy(self):
    """Proxies over nodes of a copied Graph can emit further nodes into that Graph."""
    class M(torch.nn.Module):
        def forward(self, a, b):
            return a + b

    m = M()
    g = symbolic_trace(m).graph
    new_g = torch.fx.Graph()
    val_map : Dict[Node, Node] = {}
    output_val = new_g.graph_copy(g, val_map)
    t = Proxy(output_val)
    # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules.
    new_g.output((t + t).node)
    gm = GraphModule(m, new_g)
    gm.graph.lint()
    # (3 + 4) + (3 + 4)
    self.assertEqual(gm(3, 4), 14)

def test_proxy_deepcopy_without_tracer(self):
    """Deep-copying a Proxy with no tracer yields a Proxy over a distinct copy of the node."""
    class MyModule(torch.nn.Module):
        def forward(self, x):
            return 2 * x

    module = MyModule()
    traced = symbolic_trace(module)
    # Second-to-last node: the last node before ``output``.
    node = list(traced.graph.nodes)[-2]
    p = torch.fx.Proxy(node, None)
    node.proxy = p
    p2 = copy.deepcopy(p)
    self.assertIsInstance(p2, torch.fx.Proxy)
    self.assertEqual(p2.node.name, node.name)
    self.assertEqual(p2.node.target, node.target)
    self.assertIsNot(p2.node, node)

def test_proxy_deepcopy_with_tracer(self):
    """Deep-copying a Proxy also deep-copies its tracer, preserving the tracer's attributes."""
    class TestTracer(Tracer):
        def __init__(self, name):
            super().__init__()
            self.name = name

        def is_leaf_module(self, module, name):
            return True

    class MyModule(torch.nn.Module):
        def forward(self, x):
            return 2 * x

    module = MyModule()
    tracer = TestTracer("mytracer")
    # ``tracer`` is only attached to the Proxy below; the graph itself comes
    # from the default ``symbolic_trace``.
    traced = symbolic_trace(module)
    node = list(traced.graph.nodes)[-2]
    p = torch.fx.Proxy(node, tracer)
    node.proxy = p
    p2 = copy.deepcopy(p)
    self.assertIsInstance(p2, torch.fx.Proxy)
    self.assertIsInstance(p2.tracer, torch.fx._symbolic_trace.Tracer)
    self.assertEqual(p2.tracer.name, "mytracer")
    self.assertEqual(p2.node.name, node.name)
    self.assertEqual(p2.node.target, node.target)
    self.assertIsNot(p2.node, node)
    self.assertIsNot(p2.tracer, tracer)
def test_concrete_arg_none_assert(self):
    """Specializing ``val`` to None via ``concrete_args`` makes the traced module reject other values."""
    class Foo(torch.nn.Module):
        def forward(self, x, val=None):
            return x if val is None else x + val

    f = Foo()
    traced = torch.fx.symbolic_trace(f, concrete_args={'val' : None})
    with self.assertRaisesRegex(AssertionError, 'val has been specialized to have value None'):
        traced(torch.randn(5), torch.randn(5))

    x = torch.randn(5)
    torch.testing.assert_close(traced(x), f(x))

def test_trace_multiple_funcs(self):
    """``Tracer.traced_func_name`` selects which method of the root is traced."""
    class Foo(torch.nn.Module):
        def forward(self, x, y):
            return x + y

        def minus_forward(self, x, y):
            return x - y

        def multiply_forward(self, x, y):
            return x * y

    f = Foo()
    x, y = torch.randn(5), torch.randn(5)

    tracer = Tracer()
    torch.testing.assert_close(GraphModule(f, tracer.trace(f))(x, y), f(x, y))

    tracer.traced_func_name = "minus_forward"
    torch.testing.assert_close(
        GraphModule(f, tracer.trace(f))(x, y),
        f.minus_forward(x, y),
    )

    tracer.traced_func_name = "multiply_forward"
    torch.testing.assert_close(
        GraphModule(f, tracer.trace(f))(x, y),
        f.multiply_forward(x, y),
    )

    # A name that is not a method of the root must be rejected.
    tracer.traced_func_name = "add_forward"
    with self.assertRaisesRegex(AssertionError, "doesn't exist in"):
        tracer.trace(f)

def test_graph_unique_names(self):
    """Nodes added through a Proxy after ``graph_copy`` still get unique names."""
    class M(torch.nn.Module):
        def forward(self, a, b):
            return a + b

    m = M()
    g = symbolic_trace(m).graph
    new_g = torch.fx.Graph()
    val_map : Dict[Node, Node] = {}
    output_val = new_g.graph_copy(g, val_map)
    t = Proxy(output_val)
    # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules.
    new_g.output((t + t).node)
    gm = GraphModule(m, new_g)
    seen_names : Set[str] = set()
    for node in gm.graph.nodes:
        self.assertNotIn(node.name, seen_names)
        seen_names.add(node.name)

def test_stack_traces(self):
    """With ``record_stack_traces`` on, traced nodes and their copies carry a stack trace into this file."""
    class M(torch.nn.Module):
        def forward(self, a, b):
            return a + b

    tracer = torch.fx.Tracer()
    tracer.record_stack_traces = True

    graph = tracer.trace(M())
    # saving the original list because we will insert new nodes as a part of a test
    orig_graph_nodes = list(graph.nodes)
    for node in orig_graph_nodes:
        if node.op == 'output':
            continue
        self.assertIsNotNone(node.stack_trace)
        self.assertIn('test_fx.py', node.stack_trace)

        # verify that copying the node does not lose the stack trace
        new_node = graph.node_copy(node)
        self.assertIsNotNone(new_node.stack_trace)
        self.assertIn('test_fx.py', new_node.stack_trace)

def test_stack_traces_with_transformer(self):
    """Nodes produced by a ``Transformer`` keep the stack traces of the nodes they came from."""
    class M(torch.nn.Module):
        def forward(self, a, b):
            return a + b

    tracer = torch.fx.Tracer()
    tracer.record_stack_traces = True

    graph = tracer.trace(M())
    gm = GraphModule(tracer.root, graph)
    new_gm = Transformer(gm).transform()

    # nodes after Transformer should still preserve the original node's stack trace
    for node in new_gm.graph.nodes:
        if node.op in {'placeholder', 'output'}:
            continue
        self.assertIsNotNone(node.stack_trace)
        self.assertIn('test_fx.py', node.stack_trace)

def test_lineno_map(self):
    """``_lineno_map`` keys track generated-code line numbers, including with custom codegen.

    Prepending one line via ``on_generate_code`` shifts every key by one while
    the values stay the same (presumably node indices -- confirm).
    """
    class M(torch.nn.Module):
        def forward(self, a, b):
            a = torch.sin(a)
            b = torch.cos(b)
            return a + b

    tracer = torch.fx.Tracer()
    graph = tracer.trace(M())
    gm = GraphModule(tracer.root, graph)
    expected = {1: 2, 2: 3, 3: 4, 4: 5}
    self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))

    # test custom codegen
    def transform_code(code):
        return ["print('hello!')\n", *code]
    gm.graph.on_generate_code(lambda _: transform_code)
    gm.recompile()
    expected = {2: 2, 3: 3, 4: 4, 5: 5}
    self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))

def test_graph_unique_names_manual(self):
    """``graph_copy`` keeps node names unique even when explicit names look like generated suffixes."""
    graph : torch.fx.Graph = torch.fx.Graph()
    a : torch.fx.Node = graph.create_node('placeholder', 'x')
    b : torch.fx.Node = graph.create_node('call_module', 'linear_mod', args=(a,), name='foo_1_1')
    c : torch.fx.Node = graph.create_node('get_attr', 'y_attr', name='foo_1')
    d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
    graph.output(d)
    graph2 = torch.fx.Graph()
    val_map : Dict[Node, Node] = {}
    graph2.graph_copy(graph, val_map)
    seen_names : Set[str] = set()
    for node in graph2.nodes:
        self.assertNotIn(node.name, seen_names)
        seen_names.add(node.name)

def test_unpack(self):
    """Tuple unpacking of an input argument traces correctly."""
    class M(torch.nn.Module):
        def forward(self, a, b):
            c, d = a
            return c + d + b

    a = (torch.rand(1), torch.rand(1))
    b = torch.rand(1)
    m = M()
    self.checkGraphModule(m, (a, b))

def test_native_callable(self):
    """Lower an FX graph to a TorchBind interpreter and check eager, scripted and reloaded results."""
    if IS_FBCODE or IS_WINDOWS or IS_MACOS:
        raise unittest.SkipTest("non-portable load_library call used in test")
    # This test exercises the case where we use FX to translate from Python
    # code to some native callable object
    #
    # For the purposes of testing, we use ElementwiseInterpreter defined
    # in test_custom_class.cpp.
    #
    # We test that we can
    # 1) Construct a native callable from FX IR
    # 2) Construct a drop-in replacement module that delegates to the
    #    native callable rather than the original code
    # 3) Run both the original code and native callable wrapper with
    #    equivalent results
    # 4) TorchScript compile the native callable wrapper and confirm
    #    equivalent results with the reference
    # 5) TorchScript serialize and deserialize the native callable
    #    and confirm equivalent results with the reference

    # We use this simple Module as a reference computation
    class MySimpleMod(torch.nn.Module):
        def forward(self, x):
            return 3.0 * x + x

    msm = MySimpleMod()

    # This is what a lowering pass might look like: a function that takes
    # a valid nn.Module, symbolically traces it, lowers the Module to some
    # representation, and wraps that representation up into another
    # nn.Module instance that handles dispatch to the compiled/lowered code.
    def lower_to_elementwise_interpreter(orig_mod : torch.nn.Module) -> torch.nn.Module:
        # ===== Stage 1: Symbolic trace the module =====
        mod = symbolic_trace(orig_mod)

        # ===== Stage 2: Lower GraphModule representation to the C++
        #       interpreter's instruction format ======
        instructions = []
        constant_idx = 0
        constants = {}
        fn_input_names = []

        target_to_name = {
            operator.add : "add",
            operator.mul : "mul"
        }

        output_node : Optional[Node] = None
        # For each instruction, create a triple
        # (instruction_name : str, inputs : List[str], output : str)
        # to feed into the C++ interpreter
        for n in mod.graph.nodes:
            target, args, out_name = n.target, n.args, n.name
            assert len(n.kwargs) == 0, "kwargs currently not supported"

            if n.op == 'placeholder':
                # Placeholders specify function argument names. Save these
                # for later when we generate the wrapper GraphModule
                fn_input_names.append(target)
            elif n.op == 'call_function':
                # ``target`` is a callable here, so the message must be built
                # with an f-string; str + callable would raise TypeError.
                assert target in target_to_name, f"Unsupported call target {target}"
                arg_names = []
                for arg in args:
                    if not isinstance(arg, Node):
                        # Pull out constants. These constants will later be
                        # fed to the interpreter C++ object via add_constant()
                        arg_name = f'constant_{constant_idx}'
                        constants[arg_name] = torch.tensor(
                            [arg] if isinstance(arg, numbers.Number) else arg)
                        arg_names.append(arg_name)
                        constant_idx += 1
                    else:
                        arg_names.append(arg.name)
                instructions.append((target_to_name[target], arg_names, out_name))
            elif n.op == 'output':
                if output_node is not None:
                    raise RuntimeError('Multiple output nodes!')
                output_node = n
            else:
                raise RuntimeError('Unsupported opcode ' + n.op)

        # Fail with a clear message instead of an AttributeError on None below.
        assert output_node is not None, "Graph has no output node"

        interpreter = torch.classes._TorchScriptTesting._ElementwiseInterpreter()
        # Load constants
        for k, v in constants.items():
            interpreter.add_constant(k, v)
        # Specify names for positional input arguments
        interpreter.set_input_names(fn_input_names)
        # Load instructions
        interpreter.set_instructions(instructions)
        # Specify name for single output
        assert isinstance(output_node.args[0], torch.fx.Node)
        interpreter.set_output_name(output_node.args[0].name)

        # ===== Stage 3: Create a wrapper GraphModule around the interpreter =====
        class WrapperModule(torch.nn.Module):
            def __init__(self, interpreter):
                super().__init__()
                self.interpreter = interpreter

        wrapper = WrapperModule(interpreter)

        # Create a graph that: 1) Takes function arguments 2) Invokes the interpreter
        # 3) Returns the specified return value

        # FIXME: The following code could be greatly simplified by symbolic_trace'ing
        # the wrapper with a Tracer that considers the Wrapper instance a root
        # module, however, I can't get `__call__` exposed on TorchBind classes
        # without it messing up Python `hasattr` for some reason. More digging
        # into CPython's implementation of hasattr is probably in order...

        graph = torch.fx.Graph()
        # Add placeholders for fn inputs
        placeholder_nodes = [graph.create_node('placeholder', name) for name in fn_input_names]

        # Get the interpreter object
        interpreter_node = graph.create_node('get_attr', 'interpreter')

        # Add a node to call the interpreter instance
        output_node = graph.create_node(
            op='call_method', target='__call__', args=(interpreter_node, placeholder_nodes))

        # Register output
        graph.output(output_node)

        graph.lint()

        # Return final GraphModule!!!
        return GraphModule(wrapper, graph)

    # Lower GraphModule to C++ interpreter
    lowered = lower_to_elementwise_interpreter(msm)

    # Compare correctness with original module
    x = torch.rand(3, 4)
    ref_out = msm(x)
    test_out = lowered(x)
    torch.testing.assert_close(test_out, ref_out)

    # Test TorchScript compilation
    scripted_lowered = torch.jit.script(lowered)
    script_out = scripted_lowered(x)
    torch.testing.assert_close(script_out, ref_out)

    # Test TorchScript ser/de
    import_copy = self.getExportImportCopy(scripted_lowered)
    imported_out = import_copy(x)
    torch.testing.assert_close(imported_out, ref_out)

def test_reserved_getattr(self):
    """Ensure that we do not name any nodes with a reserved builtin like `getattr`"""
    class M(torch.nn.Module):
        def forward(self, a):
            return a.foo.bar.baz

    m = M()
    m_g = symbolic_trace(m)
    m_g.graph.lint()
    for node in m_g.graph.nodes:
        self.assertNotEqual(node.name, "getattr")

@unittest.skip("Hotfix for SEV remediation")
def test_trace_buffer_slice(self):
    bs, d_hid = 10, 23

    class ExampleCode(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid))
            self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
            self.lin = torch.nn.Linear(d_hid, d_hid)
            self.buffer = torch.nn.Buffer(torch.randn(bs + 100, d_hid))

        def forward(self, x):
            x = torch.mm(x, self.mm_param)
            skip_connection = x
            x = torch.relu(x)
            x = torch.mm(x, self.mm_param) + self.buffer[:x.shape[0]]
            x = self.lin(x)
            x = torch.relu(x)
            x = x + skip_connection
            x = torch.mm(x, self.mm_param2)
918*da0073e9SAndroid Build Coastguard Worker x = self.lin(x) 919*da0073e9SAndroid Build Coastguard Worker return x 920*da0073e9SAndroid Build Coastguard Worker 921*da0073e9SAndroid Build Coastguard Worker ec = ExampleCode() 922*da0073e9SAndroid Build Coastguard Worker 923*da0073e9SAndroid Build Coastguard Worker traced = torch.fx.symbolic_trace(ec) 924*da0073e9SAndroid Build Coastguard Worker 925*da0073e9SAndroid Build Coastguard Worker x = torch.randn(bs, d_hid) 926*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(ec(x), traced(x)) 927*da0073e9SAndroid Build Coastguard Worker 928*da0073e9SAndroid Build Coastguard Worker def test_node_tagging(self): 929*da0073e9SAndroid Build Coastguard Worker class TaggingTracer(Tracer): 930*da0073e9SAndroid Build Coastguard Worker def create_node(self, kind : str, target : Union[str, Callable], 931*da0073e9SAndroid Build Coastguard Worker args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None, 932*da0073e9SAndroid Build Coastguard Worker type_expr : Optional[Any] = None) -> Node: 933*da0073e9SAndroid Build Coastguard Worker n = super().create_node(kind, target, args, kwargs, name) 934*da0073e9SAndroid Build Coastguard Worker n.tag = 'foo' 935*da0073e9SAndroid Build Coastguard Worker return n 936*da0073e9SAndroid Build Coastguard Worker 937*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 938*da0073e9SAndroid Build Coastguard Worker def forward(self, a, b): 939*da0073e9SAndroid Build Coastguard Worker return a + b 940*da0073e9SAndroid Build Coastguard Worker 941*da0073e9SAndroid Build Coastguard Worker m = M() 942*da0073e9SAndroid Build Coastguard Worker g = TaggingTracer().trace(m) 943*da0073e9SAndroid Build Coastguard Worker g.lint() 944*da0073e9SAndroid Build Coastguard Worker for n in g.nodes: 945*da0073e9SAndroid Build Coastguard Worker self.assertTrue(hasattr(n, 'tag')) 946*da0073e9SAndroid Build Coastguard Worker self.assertEqual(n.tag, 'foo') 
947*da0073e9SAndroid Build Coastguard Worker 948*da0073e9SAndroid Build Coastguard Worker def test_tensor_attribute(self): 949*da0073e9SAndroid Build Coastguard Worker class TensorAttribute(torch.nn.Module): 950*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 951*da0073e9SAndroid Build Coastguard Worker super().__init__() 952*da0073e9SAndroid Build Coastguard Worker self.tensor = torch.rand(3, 4) 953*da0073e9SAndroid Build Coastguard Worker 954*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 955*da0073e9SAndroid Build Coastguard Worker return torch.nn.functional.linear(x, self.tensor) 956*da0073e9SAndroid Build Coastguard Worker 957*da0073e9SAndroid Build Coastguard Worker ta = TensorAttribute() 958*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(ta) 959*da0073e9SAndroid Build Coastguard Worker traced(torch.rand(4, 4)) 960*da0073e9SAndroid Build Coastguard Worker 961*da0073e9SAndroid Build Coastguard Worker class WrapperForQualname(torch.nn.Module): 962*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 963*da0073e9SAndroid Build Coastguard Worker super().__init__() 964*da0073e9SAndroid Build Coastguard Worker self.ta = TensorAttribute() 965*da0073e9SAndroid Build Coastguard Worker 966*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 967*da0073e9SAndroid Build Coastguard Worker return torch.nn.functional.linear(x, self.ta.tensor) 968*da0073e9SAndroid Build Coastguard Worker 969*da0073e9SAndroid Build Coastguard Worker wfq = WrapperForQualname() 970*da0073e9SAndroid Build Coastguard Worker traced2 = symbolic_trace(wfq) 971*da0073e9SAndroid Build Coastguard Worker traced2.graph.lint() 972*da0073e9SAndroid Build Coastguard Worker traced2(torch.rand(4, 4)) 973*da0073e9SAndroid Build Coastguard Worker 974*da0073e9SAndroid Build Coastguard Worker def test_tensor_attribute_coalseced(self): 975*da0073e9SAndroid Build Coastguard Worker 976*da0073e9SAndroid Build Coastguard Worker def 
count_attrs(fx_module): 977*da0073e9SAndroid Build Coastguard Worker targets = set() 978*da0073e9SAndroid Build Coastguard Worker for node in traced.graph.nodes: 979*da0073e9SAndroid Build Coastguard Worker if node.op == 'get_attr': 980*da0073e9SAndroid Build Coastguard Worker targets.add(node.target) 981*da0073e9SAndroid Build Coastguard Worker return len(targets) 982*da0073e9SAndroid Build Coastguard Worker 983*da0073e9SAndroid Build Coastguard Worker val = torch.tensor(5) 984*da0073e9SAndroid Build Coastguard Worker 985*da0073e9SAndroid Build Coastguard Worker def f(x): 986*da0073e9SAndroid Build Coastguard Worker return x + val + val 987*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(f) 988*da0073e9SAndroid Build Coastguard Worker traced.graph.lint() 989*da0073e9SAndroid Build Coastguard Worker self.assertEqual(count_attrs(traced), 1) 990*da0073e9SAndroid Build Coastguard Worker 991*da0073e9SAndroid Build Coastguard Worker val2 = torch.tensor(5) 992*da0073e9SAndroid Build Coastguard Worker 993*da0073e9SAndroid Build Coastguard Worker def f(x): 994*da0073e9SAndroid Build Coastguard Worker val = torch.tensor(5) 995*da0073e9SAndroid Build Coastguard Worker return x + val + val2 996*da0073e9SAndroid Build Coastguard Worker 997*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(f) 998*da0073e9SAndroid Build Coastguard Worker traced.graph.lint() 999*da0073e9SAndroid Build Coastguard Worker self.assertEqual(count_attrs(traced), 2) 1000*da0073e9SAndroid Build Coastguard Worker 1001*da0073e9SAndroid Build Coastguard Worker def test_symbolic_trace_sequential(self): 1002*da0073e9SAndroid Build Coastguard Worker class Simple(torch.nn.Module): 1003*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1004*da0073e9SAndroid Build Coastguard Worker return torch.neg(x) 1005*da0073e9SAndroid Build Coastguard Worker 1006*da0073e9SAndroid Build Coastguard Worker seq = torch.nn.Sequential( 1007*da0073e9SAndroid Build Coastguard Worker 
Simple(), 1008*da0073e9SAndroid Build Coastguard Worker Simple(), 1009*da0073e9SAndroid Build Coastguard Worker Simple() 1010*da0073e9SAndroid Build Coastguard Worker ) 1011*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(seq) 1012*da0073e9SAndroid Build Coastguard Worker traced.graph.lint() 1013*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4) 1014*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced(x), seq(x)) 1015*da0073e9SAndroid Build Coastguard Worker 1016*da0073e9SAndroid Build Coastguard Worker def test_tensor_constant(self): 1017*da0073e9SAndroid Build Coastguard Worker class ConstTensor(torch.nn.Module): 1018*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1019*da0073e9SAndroid Build Coastguard Worker return torch.nn.functional.linear(x, torch.zeros(3, 4)) 1020*da0073e9SAndroid Build Coastguard Worker 1021*da0073e9SAndroid Build Coastguard Worker ct = ConstTensor() 1022*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(ct) 1023*da0073e9SAndroid Build Coastguard Worker traced.graph.lint() 1024*da0073e9SAndroid Build Coastguard Worker traced(torch.rand(4, 4)) 1025*da0073e9SAndroid Build Coastguard Worker 1026*da0073e9SAndroid Build Coastguard Worker def test_pickle_graphmodule(self): 1027*da0073e9SAndroid Build Coastguard Worker class Nested(torch.nn.Module): 1028*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1029*da0073e9SAndroid Build Coastguard Worker super().__init__() 1030*da0073e9SAndroid Build Coastguard Worker self.st = torch.nn.Linear(4, 4) 1031*da0073e9SAndroid Build Coastguard Worker 1032*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1033*da0073e9SAndroid Build Coastguard Worker return self.st(x) 1034*da0073e9SAndroid Build Coastguard Worker 1035*da0073e9SAndroid Build Coastguard Worker n = Nested() 1036*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(n) 1037*da0073e9SAndroid Build Coastguard Worker traced.graph.lint() 
1038*da0073e9SAndroid Build Coastguard Worker pickled = pickle.dumps(traced) 1039*da0073e9SAndroid Build Coastguard Worker loaded = pickle.loads(pickled) 1040*da0073e9SAndroid Build Coastguard Worker loaded.graph.lint() 1041*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4) 1042*da0073e9SAndroid Build Coastguard Worker self.assertEqual(loaded(x), traced(x)) 1043*da0073e9SAndroid Build Coastguard Worker 1044*da0073e9SAndroid Build Coastguard Worker def test_pickle_custom_import(self): 1045*da0073e9SAndroid Build Coastguard Worker graph = torch.fx.Graph() 1046*da0073e9SAndroid Build Coastguard Worker a = graph.placeholder('x') 1047*da0073e9SAndroid Build Coastguard Worker b = graph.placeholder('y') 1048*da0073e9SAndroid Build Coastguard Worker c = graph.call_function(a_non_torch_leaf, (a, b)) 1049*da0073e9SAndroid Build Coastguard Worker d = graph.call_function(torch.sin, (c,)) 1050*da0073e9SAndroid Build Coastguard Worker graph.output(d) 1051*da0073e9SAndroid Build Coastguard Worker gm = GraphModule(torch.nn.Module(), graph) 1052*da0073e9SAndroid Build Coastguard Worker pickled = pickle.dumps(gm) 1053*da0073e9SAndroid Build Coastguard Worker loaded = pickle.loads(pickled) 1054*da0073e9SAndroid Build Coastguard Worker loaded.graph.lint() 1055*da0073e9SAndroid Build Coastguard Worker x, y = torch.rand(1), torch.rand(1) 1056*da0073e9SAndroid Build Coastguard Worker self.assertEqual(loaded(x, y), gm(x, y)) 1057*da0073e9SAndroid Build Coastguard Worker 1058*da0073e9SAndroid Build Coastguard Worker def test_all_input_nodes(self): 1059*da0073e9SAndroid Build Coastguard Worker graph : torch.fx.Graph = torch.fx.Graph() 1060*da0073e9SAndroid Build Coastguard Worker a : torch.fx.Node = graph.placeholder('x') 1061*da0073e9SAndroid Build Coastguard Worker b : torch.fx.Node = graph.call_module('linear_mod', args=(a,)) 1062*da0073e9SAndroid Build Coastguard Worker c : torch.fx.Node = graph.get_attr('y_attr') 1063*da0073e9SAndroid Build Coastguard Worker d : 
torch.fx.Node = graph.call_function(operator.add, args=(b, c)) 1064*da0073e9SAndroid Build Coastguard Worker e : torch.fx.Node = graph.call_function(torch.unsqueeze, args=(d, 0)) 1065*da0073e9SAndroid Build Coastguard Worker graph.output(e) 1066*da0073e9SAndroid Build Coastguard Worker graph.lint() 1067*da0073e9SAndroid Build Coastguard Worker 1068*da0073e9SAndroid Build Coastguard Worker self.assertEqual(b.all_input_nodes, [a]) 1069*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c.all_input_nodes, []) 1070*da0073e9SAndroid Build Coastguard Worker self.assertEqual(d.all_input_nodes, [b, c]) 1071*da0073e9SAndroid Build Coastguard Worker self.assertEqual(e.all_input_nodes, [d]) 1072*da0073e9SAndroid Build Coastguard Worker 1073*da0073e9SAndroid Build Coastguard Worker def test_deepcopy_graphmodule_with_transform(self): 1074*da0073e9SAndroid Build Coastguard Worker st = SimpleTest() 1075*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(st) 1076*da0073e9SAndroid Build Coastguard Worker traced.graph.lint() 1077*da0073e9SAndroid Build Coastguard Worker 1078*da0073e9SAndroid Build Coastguard Worker def transform(traced): 1079*da0073e9SAndroid Build Coastguard Worker new_graph = torch.fx.Graph() 1080*da0073e9SAndroid Build Coastguard Worker val_map : Dict[Node, Node] = {} 1081*da0073e9SAndroid Build Coastguard Worker output_value = new_graph.graph_copy(traced.graph, val_map) 1082*da0073e9SAndroid Build Coastguard Worker relu_out = new_graph.create_node( 1083*da0073e9SAndroid Build Coastguard Worker op='call_method', target='neg', args=(output_value,), kwargs={}) 1084*da0073e9SAndroid Build Coastguard Worker new_graph.output(relu_out) 1085*da0073e9SAndroid Build Coastguard Worker return GraphModule(traced, new_graph) 1086*da0073e9SAndroid Build Coastguard Worker transformed = transform(traced) 1087*da0073e9SAndroid Build Coastguard Worker transformed.graph.lint() 1088*da0073e9SAndroid Build Coastguard Worker copied = copy.deepcopy(transformed) 
1089*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(id(type(transformed)), id(type(copied))) 1090*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 4) 1091*da0073e9SAndroid Build Coastguard Worker self.assertEqual(copied(x), transformed(x)) 1092*da0073e9SAndroid Build Coastguard Worker 1093*da0073e9SAndroid Build Coastguard Worker def test_deepcopy_with_submods_params(self): 1094*da0073e9SAndroid Build Coastguard Worker class Bar(torch.nn.Module): 1095*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1096*da0073e9SAndroid Build Coastguard Worker super().__init__() 1097*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(3, 4)) 1098*da0073e9SAndroid Build Coastguard Worker 1099*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1100*da0073e9SAndroid Build Coastguard Worker return torch.relu(x) + self.param 1101*da0073e9SAndroid Build Coastguard Worker 1102*da0073e9SAndroid Build Coastguard Worker class Baz(torch.nn.Module): 1103*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1104*da0073e9SAndroid Build Coastguard Worker super().__init__() 1105*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(3, 4)) 1106*da0073e9SAndroid Build Coastguard Worker self.bar = Bar() 1107*da0073e9SAndroid Build Coastguard Worker 1108*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1109*da0073e9SAndroid Build Coastguard Worker return self.bar(x) - self.param 1110*da0073e9SAndroid Build Coastguard Worker 1111*da0073e9SAndroid Build Coastguard Worker baz = Baz() 1112*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(baz) 1113*da0073e9SAndroid Build Coastguard Worker traced.graph.lint() 1114*da0073e9SAndroid Build Coastguard Worker copied = copy.deepcopy(traced) 1115*da0073e9SAndroid Build Coastguard Worker copied.graph.lint() 1116*da0073e9SAndroid Build Coastguard Worker 1117*da0073e9SAndroid Build Coastguard Worker def 
test_deepcopy_graph_with_tracer_cls(self): 1118*da0073e9SAndroid Build Coastguard Worker class TestTracer(Tracer): 1119*da0073e9SAndroid Build Coastguard Worker def is_leaf_module(self, module, name): 1120*da0073e9SAndroid Build Coastguard Worker return True 1121*da0073e9SAndroid Build Coastguard Worker 1122*da0073e9SAndroid Build Coastguard Worker g = Graph(tracer_cls=TestTracer) 1123*da0073e9SAndroid Build Coastguard Worker x = g.placeholder("x") 1124*da0073e9SAndroid Build Coastguard Worker g.output(x) 1125*da0073e9SAndroid Build Coastguard Worker 1126*da0073e9SAndroid Build Coastguard Worker h = copy.deepcopy(g) 1127*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(h._tracer_cls) 1128*da0073e9SAndroid Build Coastguard Worker self.assertTrue(g._tracer_cls == h._tracer_cls) 1129*da0073e9SAndroid Build Coastguard Worker 1130*da0073e9SAndroid Build Coastguard Worker def test_unpack_list_better_error(self): 1131*da0073e9SAndroid Build Coastguard Worker class SomeArgs(torch.nn.Module): 1132*da0073e9SAndroid Build Coastguard Worker def forward(self, a, b): 1133*da0073e9SAndroid Build Coastguard Worker return torch.rand(3, 4) 1134*da0073e9SAndroid Build Coastguard Worker 1135*da0073e9SAndroid Build Coastguard Worker class UnpacksList(torch.nn.Module): 1136*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1137*da0073e9SAndroid Build Coastguard Worker super().__init__() 1138*da0073e9SAndroid Build Coastguard Worker self.sa = SomeArgs() 1139*da0073e9SAndroid Build Coastguard Worker 1140*da0073e9SAndroid Build Coastguard Worker def forward(self, x : list): 1141*da0073e9SAndroid Build Coastguard Worker return self.sa(*x) 1142*da0073e9SAndroid Build Coastguard Worker 1143*da0073e9SAndroid Build Coastguard Worker ul = UnpacksList() 1144*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'): 1145*da0073e9SAndroid Build Coastguard Worker symbolic_trace(ul) 1146*da0073e9SAndroid Build 
Coastguard Worker 1147*da0073e9SAndroid Build Coastguard Worker def test_unpack_dict_better_error(self): 1148*da0073e9SAndroid Build Coastguard Worker class SomeKwargs(torch.nn.Module): 1149*da0073e9SAndroid Build Coastguard Worker def forward(self, x=3, y=4): 1150*da0073e9SAndroid Build Coastguard Worker return torch.rand(3, 4) 1151*da0073e9SAndroid Build Coastguard Worker 1152*da0073e9SAndroid Build Coastguard Worker class UnpacksDict(torch.nn.Module): 1153*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1154*da0073e9SAndroid Build Coastguard Worker super().__init__() 1155*da0073e9SAndroid Build Coastguard Worker self.sk = SomeKwargs() 1156*da0073e9SAndroid Build Coastguard Worker 1157*da0073e9SAndroid Build Coastguard Worker def forward(self, x : dict): 1158*da0073e9SAndroid Build Coastguard Worker return self.sk(**x) 1159*da0073e9SAndroid Build Coastguard Worker 1160*da0073e9SAndroid Build Coastguard Worker ud = UnpacksDict() 1161*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'): 1162*da0073e9SAndroid Build Coastguard Worker symbolic_trace(ud) 1163*da0073e9SAndroid Build Coastguard Worker 1164*da0073e9SAndroid Build Coastguard Worker def test_pretty_print_targets(self): 1165*da0073e9SAndroid Build Coastguard Worker # Test that Graph pretty-print prints friendly name for targets 1166*da0073e9SAndroid Build Coastguard Worker # in `operator` and `builtins` 1167*da0073e9SAndroid Build Coastguard Worker 1168*da0073e9SAndroid Build Coastguard Worker class SomeMod(torch.nn.Module): 1169*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1170*da0073e9SAndroid Build Coastguard Worker return torch.add(x.foo + x.bar, 3.0) 1171*da0073e9SAndroid Build Coastguard Worker 1172*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(SomeMod()) 1173*da0073e9SAndroid Build Coastguard Worker graph_str = str(traced.graph) 1174*da0073e9SAndroid Build Coastguard Worker 
self.assertIn('builtins.getattr', graph_str) 1175*da0073e9SAndroid Build Coastguard Worker self.assertIn('operator.add', graph_str) 1176*da0073e9SAndroid Build Coastguard Worker self.assertIn('torch.add', graph_str) 1177*da0073e9SAndroid Build Coastguard Worker 1178*da0073e9SAndroid Build Coastguard Worker def test_pretty_print_node(self): 1179*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 1180*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1181*da0073e9SAndroid Build Coastguard Worker super().__init__() 1182*da0073e9SAndroid Build Coastguard Worker self.param: torch.nn.Parameter = torch.nn.Parameter( 1183*da0073e9SAndroid Build Coastguard Worker torch.rand(3, 4)) 1184*da0073e9SAndroid Build Coastguard Worker self.linear = torch.nn.Linear(4, 5) 1185*da0073e9SAndroid Build Coastguard Worker 1186*da0073e9SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor, y: int = 2): 1187*da0073e9SAndroid Build Coastguard Worker return self.linear(x[y] + self.param).clamp(min=0.0, max=1.0) 1188*da0073e9SAndroid Build Coastguard Worker 1189*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(M()) 1190*da0073e9SAndroid Build Coastguard Worker 1191*da0073e9SAndroid Build Coastguard Worker all_formatted = "\n".join([n.format_node() for n in traced.graph.nodes]) 1192*da0073e9SAndroid Build Coastguard Worker 1193*da0073e9SAndroid Build Coastguard Worker FileCheck().check("x").check("placeholder") \ 1194*da0073e9SAndroid Build Coastguard Worker .check("y").check("placeholder") \ 1195*da0073e9SAndroid Build Coastguard Worker .check("getitem").check("call_function") \ 1196*da0073e9SAndroid Build Coastguard Worker .check("param").check("get_attr") \ 1197*da0073e9SAndroid Build Coastguard Worker .check("add").check("call_function") \ 1198*da0073e9SAndroid Build Coastguard Worker .check("linear").check("call_module") \ 1199*da0073e9SAndroid Build Coastguard Worker .check("clamp").check("call_method") \ 1200*da0073e9SAndroid 
Build Coastguard Worker .run(all_formatted) 1201*da0073e9SAndroid Build Coastguard Worker 1202*da0073e9SAndroid Build Coastguard Worker def test_script_tensor_constant(self): 1203*da0073e9SAndroid Build Coastguard Worker # TorchScript seems to ignore attributes that start with `__`. 1204*da0073e9SAndroid Build Coastguard Worker # We used to call anonymous Tensor values `__tensor_constant*`, but 1205*da0073e9SAndroid Build Coastguard Worker # they were getting ignored by script. Now they're called 1206*da0073e9SAndroid Build Coastguard Worker # `_tensor_constant*` 1207*da0073e9SAndroid Build Coastguard Worker class IHaveATensorConstant(torch.nn.Module): 1208*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1209*da0073e9SAndroid Build Coastguard Worker return x + torch.rand(3, 4) 1210*da0073e9SAndroid Build Coastguard Worker 1211*da0073e9SAndroid Build Coastguard Worker traced = torch.fx.symbolic_trace(IHaveATensorConstant()) 1212*da0073e9SAndroid Build Coastguard Worker torch.jit.script(traced) 1213*da0073e9SAndroid Build Coastguard Worker 1214*da0073e9SAndroid Build Coastguard Worker def test_autowrap_functions(self): 1215*da0073e9SAndroid Build Coastguard Worker class AutowrapFnTest(torch.nn.Module): 1216*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1217*da0073e9SAndroid Build Coastguard Worker return fx_int(x.shape[0] / 2) 1218*da0073e9SAndroid Build Coastguard Worker 1219*da0073e9SAndroid Build Coastguard Worker class AutowrapFnTest2(torch.nn.Module): 1220*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1221*da0073e9SAndroid Build Coastguard Worker return fx_int(x.shape[0] / 2) + fx_int_x2(x.shape[0] / 2) 1222*da0073e9SAndroid Build Coastguard Worker 1223*da0073e9SAndroid Build Coastguard Worker # Check function(s) are wrapped 1224*da0073e9SAndroid Build Coastguard Worker # `int` would normally throw a TypeError as argument can't be `Proxy` 1225*da0073e9SAndroid Build Coastguard Worker tracer = 
Tracer(autowrap_functions=(fx_int,)) 1226*da0073e9SAndroid Build Coastguard Worker graph = tracer.trace(AutowrapFnTest()) 1227*da0073e9SAndroid Build Coastguard Worker traced = GraphModule(tracer.root, graph, 'test') 1228*da0073e9SAndroid Build Coastguard Worker tracer_2 = Tracer(autowrap_functions=(fx_int, fx_int_x2)) 1229*da0073e9SAndroid Build Coastguard Worker tracer_2.trace(AutowrapFnTest2()) 1230*da0073e9SAndroid Build Coastguard Worker 1231*da0073e9SAndroid Build Coastguard Worker # Test scriptability 1232*da0073e9SAndroid Build Coastguard Worker traced_scripted = torch.jit.script(traced) 1233*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced_scripted(torch.rand(4)), 2) 1234*da0073e9SAndroid Build Coastguard Worker 1235*da0073e9SAndroid Build Coastguard Worker def test_tuple_no_subscript(self): 1236*da0073e9SAndroid Build Coastguard Worker def foo(x : Tuple): 1237*da0073e9SAndroid Build Coastguard Worker return x[0] 1238*da0073e9SAndroid Build Coastguard Worker 1239*da0073e9SAndroid Build Coastguard Worker traced = torch.fx.symbolic_trace(foo) 1240*da0073e9SAndroid Build Coastguard Worker x = (torch.randn(5, 3),) 1241*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(traced(x), x[0]) 1242*da0073e9SAndroid Build Coastguard Worker 1243*da0073e9SAndroid Build Coastguard Worker bio = io.BytesIO() 1244*da0073e9SAndroid Build Coastguard Worker 1245*da0073e9SAndroid Build Coastguard Worker torch.save(traced, bio) 1246*da0073e9SAndroid Build Coastguard Worker 1247*da0073e9SAndroid Build Coastguard Worker bio.seek(0) 1248*da0073e9SAndroid Build Coastguard Worker 1249*da0073e9SAndroid Build Coastguard Worker # weights_only=False as this loads a GraphModule 1250*da0073e9SAndroid Build Coastguard Worker # GLOBAL torch.fx.graph_module.reduce_graph_module was not an allowed global by default 1251*da0073e9SAndroid Build Coastguard Worker loaded = torch.load(bio, weights_only=False) 1252*da0073e9SAndroid Build Coastguard Worker 
1253*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(loaded(x), x[0]) 1254*da0073e9SAndroid Build Coastguard Worker 1255*da0073e9SAndroid Build Coastguard Worker def test_torch_fx_len(self): 1256*da0073e9SAndroid Build Coastguard Worker class FXLenTest(torch.nn.Module): 1257*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1258*da0073e9SAndroid Build Coastguard Worker return len(x) 1259*da0073e9SAndroid Build Coastguard Worker 1260*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(FXLenTest()) 1261*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced(torch.rand(3, 4)), 3) 1262*da0073e9SAndroid Build Coastguard Worker 1263*da0073e9SAndroid Build Coastguard Worker # Test scriptability 1264*da0073e9SAndroid Build Coastguard Worker scripted = torch.jit.script(FXLenTest()) 1265*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted(torch.rand(3)), 3) 1266*da0073e9SAndroid Build Coastguard Worker 1267*da0073e9SAndroid Build Coastguard Worker traced_scripted = torch.jit.script(traced) 1268*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced_scripted(torch.rand(3)), 3) 1269*da0073e9SAndroid Build Coastguard Worker 1270*da0073e9SAndroid Build Coastguard Worker # Test non-proxy len 1271*da0073e9SAndroid Build Coastguard Worker class FXLenTest2(torch.nn.Module): 1272*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1273*da0073e9SAndroid Build Coastguard Worker super().__init__() 1274*da0073e9SAndroid Build Coastguard Worker self.l = [3, 4, 5] 1275*da0073e9SAndroid Build Coastguard Worker 1276*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1277*da0073e9SAndroid Build Coastguard Worker return x + len(self.l) 1278*da0073e9SAndroid Build Coastguard Worker 1279*da0073e9SAndroid Build Coastguard Worker traced2 = symbolic_trace(FXLenTest2()) 1280*da0073e9SAndroid Build Coastguard Worker inp = torch.rand(3, 4) 1281*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(traced2(inp), inp + 3.0)
self.assertIs(len, builtins.len)

def test_torch_fx_getattr(self):
    """getattr() with a default on a traced tensor returns the default value."""
    class FXGetattrTest(torch.nn.Module):
        def forward(self, x):
            return getattr(x, 'nonexistent_attr', torch.Tensor([2, 3]))

    traced = symbolic_trace(FXGetattrTest())
    self.assertEqual(traced(torch.rand(3, 4)), torch.Tensor([2, 3]))

def test_sqrt(self):
    """``sqrt`` and ``math.sqrt`` trace correctly, on sizes and on constants.

    Afterwards both names must still be the ``_sqrt`` object (``_sqrt`` is
    defined outside this view; presumably the saved original function).
    """
    class Sqrt1(torch.nn.Module):
        def forward(self, x):
            return sqrt(x.size(0))

    class Sqrt2(torch.nn.Module):
        def forward(self, x):
            return math.sqrt(x.size(0))

    class Sqrt3(torch.nn.Module):
        def forward(self, x):
            return x + math.sqrt(2) + sqrt(2)

    self.checkGraphModule(Sqrt1(), [torch.zeros(8)])
    self.checkGraphModule(Sqrt2(), [torch.zeros(8)])
    self.checkGraphModule(Sqrt3(), [torch.zeros(8)])
    self.assertIs(sqrt, _sqrt)
    self.assertIs(math.sqrt, _sqrt)

def test_torch_custom_ops(self):
    """Calls through the ``torch.ops.aten`` namespace trace and match eager output."""
    class M(torch.nn.Module):
        def forward(self, a):
            b = torch.ops.aten.sigmoid(a)
            c = torch.ops.aten.cat([a, b])
            return torch.ops.aten.cat((c, c))
    m = M()
    input = torch.randn(3)
    ref_out = m(input)
    gm = symbolic_trace(m)
    gm.graph.lint()
    out = gm(input)
    self.assertEqual(out, ref_out)

def test_torch_op_overloads(self):
    """An explicit overload (``aten.add.Tensor``) is kept as an ``OpOverload`` node target."""
    class M(torch.nn.Module):
        def forward(self, a):
            b = torch.ops.aten.add.Tensor(a, a)
            return b
    m = M()
    input = torch.randn(3)
    ref_out = m(input)
    gm = symbolic_trace(m)
    gm.graph.lint()
    out = gm(input)
    self.assertEqual(out, ref_out)

    call_function_nodes = [node for node in gm.graph.nodes if node.op == 'call_function']
    # Guard against a vacuous pass: with no call_function nodes the loop
    # below would check nothing.
    self.assertTrue(call_function_nodes)
    for node in call_function_nodes:
        # self.assert* instead of bare `assert` so the check survives `python -O`.
        self.assertIsInstance(node.target, torch._ops.OpOverload)
        self.assertEqual(node.target.__name__, 'add.Tensor')

def test_pickle_torch_custom_ops(self):
    """A traced module calling ``torch.ops.aten`` ops survives a pickle round trip."""
    class M(torch.nn.Module):
        def forward(self, a):
            b = torch.ops.aten.sigmoid(a)
            c = torch.ops.aten.cat([a, b])
            return torch.ops.aten.cat((c, c))
    m = M()
    input = torch.randn(3)
    ref_out = m(input)
    gm = symbolic_trace(m)
    gm.graph.lint()
    pickled = pickle.dumps(gm)
    loaded = pickle.loads(pickled)
    self.assertEqual(loaded(input), gm(input))

def test_pretty_print(self):
    """``str()`` of a traced ``SimpleTest`` mentions ``SimpleTest()`` and ``torch.relu``."""
    st = SimpleTest()
    traced = symbolic_trace(st)
    traced.graph.lint()
    printed = str(traced)
    self.assertIn('SimpleTest()', printed)
    self.assertIn('torch.relu', printed)
def test_pretty_print_graph(self):
    """``str(Graph)`` contains the ``args``, ``kwargs`` and ``num_users`` fields."""
    class KwargPrintTest(torch.nn.Module):
        def forward(self, x):
            return torch.squeeze(x + 3.0, dim=2)
    st = KwargPrintTest()
    traced = symbolic_trace(st)
    traced.graph.lint()
    stringed = str(traced.graph)
    for s in ['args', 'kwargs', 'num_users']:
        self.assertIn(s, stringed)

def test_custom_proxy_type(self):
    """Tracing a function over a plain (non-proxyable) ``TensorPair`` matches eager results."""
    class TensorPair:
        def __init__(self, left, right):
            self.left, self.right = left, right

        def add(self, other):
            l = self.left + other.left
            r = self.right + other.right
            return TensorPair(l, r)

        def mul(self, other):
            l = self.left * other.left
            r = self.right * other.right
            return TensorPair(l, r)

    def use_tensor_pair(x : TensorPair, y : TensorPair):
        s = x.add(y)
        return s.mul(x)

    x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
    y = TensorPair(torch.randn(5, 3), torch.randn(5, 3))

    ref_out = use_tensor_pair(x, y)

    traced = symbolic_trace(use_tensor_pair)

    traced_out = traced(x, y)
    self.assertEqual(traced_out.left, ref_out.left)
    self.assertEqual(traced_out.right, ref_out.right)

def test_custom_proxy_type_literal(self):
    """A ``ProxyableClassMeta`` class built from constant tensors inside the traced
    function reproduces eager results."""
    class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
        def __init__(self, left, right):
            self.left, self.right = left, right

        def add(self, other):
            l = self.left + other.left
            r = self.right + other.right
            return TensorPair(l, r)

        def mul(self, other):
            l = self.left * other.left
            r = self.right * other.right
            return TensorPair(l, r)

    def use_tensor_pair_literal(x : TensorPair):
        s = x.add(TensorPair(torch.zeros(5, 3), torch.zeros(5, 3)))
        return s.mul(x)

    x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))

    ref_out = use_tensor_pair_literal(x)

    traced = symbolic_trace(use_tensor_pair_literal)

    traced_out = traced(x)
    self.assertEqual(traced_out.left, ref_out.left)
    self.assertEqual(traced_out.right, ref_out.right)

def test_custom_proxy_dynamic_value(self):
    """A ``ProxyableClassMeta`` class constructed from a traced input reproduces eager results."""
    class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
        def __init__(self, left, right):
            self.left, self.right = left, right

        def add(self, other):
            l = self.left + other.left
            r = self.right + other.right
            return TensorPair(l, r)

        def mul(self, other):
            l = self.left * other.left
            r = self.right * other.right
            return TensorPair(l, r)

    def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor):
        s = x.add(TensorPair(y, y))
        return s.mul(x)

    x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
    y = torch.randn(5, 3)
    ref_out = use_tensor_pair_ctor(x, y)

    traced = symbolic_trace(use_tensor_pair_ctor)

    traced_out = traced(x, y)
    self.assertEqual(traced_out.left, ref_out.left)
    self.assertEqual(traced_out.right, ref_out.right)

def test_custom_proxy_input_dependent_control_flow(self):
    """Data-dependent branching in a ``ProxyableClassMeta`` constructor does not break
    tracing, and the traced result matches eager."""
    class ZeroTensor(metaclass=torch.fx.ProxyableClassMeta):
        def __init__(self, inp):
            if inp.sum() == 0:
                self.is_zero = True
                self.tensor = torch.tensor([])
            else:
                self.is_zero = False
                self.tensor = inp

        # NOTE: not exercised by this test. It implicitly returns None when
        # neither operand is zero.
        def add(self, other):
            if self.is_zero:
                return ZeroTensor(other.tensor)
            elif other.is_zero:
                return self

    def use_zero_tensor(x : torch.Tensor, y : torch.Tensor):
        return ZeroTensor(x + y)

    x, y = torch.randn(5, 3), torch.randn(5, 3)

    ref_out = use_zero_tensor(x, y)

    traced = symbolic_trace(use_zero_tensor)

    traced_out = traced(x, y)

    self.assertEqual(traced_out.is_zero, ref_out.is_zero)
    self.assertEqual(traced_out.tensor, ref_out.tensor)

def test_graph_fns(self):
    """A hand-built graph using every node-creation helper runs as a GraphModule and
    matches the equivalent eager computation."""
    g = Graph()
    a = g.placeholder('a')
    b = g.call_module('linear', (a,))
    c = g.get_attr('bias')
    d = g.call_method('add', (b, c))
    e = g.call_function(torch.sin, (d,))
    g.output(e)
    mod = torch.nn.Module()
    mod.linear = torch.nn.Linear(3, 4)
    mod.bias = torch.rand(4)
    gm = GraphModule(mod, g)
    gm.graph.lint()
    input = torch.rand(3)
    r = gm(input)
    ref = torch.sin(mod.linear(input) + mod.bias)
    self.assertEqual(r, ref)

def test_remove_uses(self):
    """After ``neg`` is replaced by ``relu`` and erased, ``neg`` is no longer a user of ``relu``."""
    g : torch.fx.Graph = Graph()
    x : torch.fx.Node = g.placeholder('x')
    relu : torch.fx.Node = g.call_function(torch.relu, (x,))
    neg : torch.fx.Node = g.call_function(torch.neg, (relu,))
    g.output(neg)

    neg.replace_all_uses_with(relu)
    g.erase_node(neg)

    self.assertNotIn(neg, relu.users)

def test_remove_uses_with_custom_filter(self):
    """When ``replace_all_uses_with`` gets a filter callback and ``neg`` is not erased,
    ``neg`` remains a user of ``relu``."""
    g : torch.fx.Graph = Graph()
    x : torch.fx.Node = g.placeholder('x')
    relu : torch.fx.Node = g.call_function(torch.relu, (x,))
    neg : torch.fx.Node = g.call_function(torch.neg, (relu,))
    g.output(neg)

    # Name the callback parameter `user` so it does not shadow the node `x` above.
    neg.replace_all_uses_with(relu, lambda user: user != neg)

    self.assertIn(neg, relu.users)

def test_nonetype_annotation(self):
    """``torch.nn.EmbeddingBag`` can be symbolically traced.

    Presumably this targets None/Optional annotations on its ``forward``
    (TODO confirm).
    """
    eb = torch.nn.EmbeddingBag(3, 4)
    symbolic_trace(eb)

def test_pickle_nonetype_annotation(self):
    """A traced ``EmbeddingBag`` survives a pickle round trip, lints, and matches the original."""
    eb = torch.nn.EmbeddingBag(10, 3, mode='sum')
    traced = symbolic_trace(eb)
    pickled = pickle.dumps(traced)
    loaded = pickle.loads(pickled)
    loaded.graph.lint()
    input = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
    offsets = torch.LongTensor([0, 4])
    self.assertEqual(loaded(input, offsets), traced(input, offsets))

def test_return_tuple(self):
    """A ``forward`` annotated to return a ``Tuple`` traces and returns the same values."""
    class M(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            return (x, x + x)

    original = M()
    traced = symbolic_trace(original)
    self.assertEqual(traced(torch.ones(1)), original.forward(torch.ones(1)))

def test_construct_root_dict(self):
    """A GraphModule can be built from a dict root keyed by dotted qualified names."""
    graph : torch.fx.Graph = torch.fx.Graph()
    a : torch.fx.Node = graph.create_node('placeholder', 'x')
    b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,))
    c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam')
    d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
    graph.output(d)

    linear_mod : torch.nn.Module = torch.nn.Linear(3, 4)
    add_param : torch.Tensor = torch.rand(3, 4)
    gm : torch.fx.GraphModule = torch.fx.GraphModule(
        {'foo.bar.baz': linear_mod, 'zip.zap.zam' : add_param}, graph)
    gm.graph.lint()

    self.assertIn('self.foo.bar.baz', gm.code)

    x : torch.Tensor = torch.rand(3, 3)
    out : torch.Tensor = gm(x)
    ref_out : torch.Tensor = linear_mod(x) + add_param
    self.assertEqual(out, ref_out)

def test_symbolic_trace_assert(self):
    """``torch._assert`` traces and fires at runtime; the eager module also fires it under TorchScript."""

    class AssertsTensorShape(torch.nn.Module):
        def forward(self, x):
            torch._assert(x.shape[1] > 4, "assert_foobar")
            return x

    m = AssertsTensorShape()
    # verify traceability
    traced = symbolic_trace(m)
    # verify assertion on traced model works correctly at runtime
    traced(torch.rand(4, 5))
    with self.assertRaisesRegex(AssertionError, "assert_foobar"):
        traced(torch.rand(4, 3))
    # NOTE: this scripts the original module `m`, not `traced`; check that
    # its assert also fires under TorchScript.
    ms = torch.jit.script(m)
    with self.assertRaisesRegex(torch.jit.Error, "assert_foobar"):
        ms(torch.rand(4, 3))

def test_fx_create_arg(self):
    """An object's ``__fx_create_arg__`` hook decides how it is recorded when it is
    passed to a leaf module."""
    class CustomArgObject:
        def __init__(self, x, y):
            self.x = x
            self.y = y

        def __fx_create_arg__(self, tracer: torch.fx.Tracer):
            return tracer.create_node(
                "call_function",
                CustomArgObject,
                args=(
                    tracer.create_arg(self.x),
                    tracer.create_arg(self.y),
                ),
                kwargs={},
            )

    class HasCustomArgObjectWhenLeaf(torch.nn.Module):
        def forward(self, o: CustomArgObject):
            # Not normally traceable; good reason to make
            # this module a leaf.
            for x in o.x:
                o.y += x
            return o.y

    class Root(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.inner = HasCustomArgObjectWhenLeaf()

        def forward(self, x, y):
            o = CustomArgObject(x, y)
            return self.inner(o)

    class CreateArgTracer(torch.fx.Tracer):
        def is_leaf_module(self, m, module_qualified_name):
            return type(m) is HasCustomArgObjectWhenLeaf

    m = Root()
    graph = CreateArgTracer().trace(m)
    gm = torch.fx.GraphModule(m, graph)
    self.assertIn("CustomArgObject(", gm.code)
def test_trace_fn_constant(self):
    """A free function that closes over a tensor constant traces and matches eager output."""
    some_constant = torch.rand(3, 4)

    def add_const(x):
        return some_constant + x

    traced = symbolic_trace(add_const)

    input = torch.rand(3, 4)
    self.assertEqual(traced(input), add_const(input))

def test_copy_no_remap(self):
    """Copying nodes without remapping their args leaves references to the source graph,
    which ``lint()`` rejects."""
    traced = symbolic_trace(SimpleTest())
    g = traced.graph
    copied = torch.fx.Graph()
    for node in g.nodes:
        copied.node_copy(node)
    with self.assertRaisesRegex(RuntimeError, 'does not belong to this Graph'):
        copied.lint()

def test_wrong_topo(self):
    """Moving a node after its user makes ``lint()`` report a use before definition."""
    graph : torch.fx.Graph = torch.fx.Graph()
    a : torch.fx.Node = graph.create_node('placeholder', 'x')
    b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,))
    c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam')
    d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
    graph.output(d)
    nodes = list(graph.nodes)
    # Move `c` (nodes[2]) after its user `d` (nodes[3]).
    nodes[3].append(nodes[2])
    with self.assertRaisesRegex(RuntimeError, 'was used before it has been defined'):
        graph.lint()

def test_wrong_target_type(self):
    """A ``call_function`` Node with a string target is rejected with ``ValueError``."""
    graph : torch.fx.Graph = torch.fx.Graph()
    with self.assertRaises(ValueError):
        torch.fx.Node(graph=graph, name='foo', op='call_function', target='foo',
                      args=(), kwargs={})

def test_example_shape_prop(self):
    """ShapeProp records the output shape and stride of a graph covering every opcode."""
    class TestCase(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.attr = torch.randn(3, 4)
            self.submod = torch.nn.Linear(4, 4)

        def forward(self, x):
            return torch.neg(self.submod(x.relu() + self.attr))
    tc = TestCase()
    tc_traced = symbolic_trace(tc)
    ref_out = tc_traced(torch.rand(3, 4))
    shape_prop.ShapeProp(tc_traced).propagate(torch.rand(3, 4))

    # Make sure we're testing all opcodes
    opcodes = set()
    output_shape : Optional[torch.Shape] = None
    output_stride : Optional[Tuple[int, ...]] = None
    for node in tc_traced.graph.nodes:
        opcodes.add(node.op)
        if node.op == 'output':
            output_shape = node.args[0].meta['tensor_meta'].shape
            output_stride = node.args[0].meta['tensor_meta'].stride
    self.assertEqual(opcodes, {'placeholder', 'get_attr', 'call_function', 'call_method',
                               'call_module', 'output'})

    # Test shape propagation and make sure results match actual
    self.assertEqual(output_shape, ref_out.shape)
    self.assertEqual(output_stride, ref_out.stride())

def test_shape_prop_layout(self):
    """ShapeProp records contiguous vs. channels-last memory format for a 2-D conv."""
    class ConvTest(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv_mod = torch.nn.Conv2d(5, 5, 3)

        def forward(self, x):
            return self.conv_mod(x)

    # contiguous layout
    test_mod = ConvTest()
    traced = symbolic_trace(test_mod)
    x = torch.randn(5, 5, 224, 224)
    shape_prop.ShapeProp(traced).propagate(x)

    # Per-node assertion (instead of a bare `assert all(...)`) survives
    # `python -O` and reports which node failed.
    for node in traced.graph.nodes:
        self.assertIs(node.meta['tensor_meta'].memory_format, torch.contiguous_format)

    x_channels_last = x.contiguous(memory_format=torch.channels_last)
    traced.to(memory_format=torch.channels_last)
    shape_prop.ShapeProp(traced).propagate(x_channels_last)
    for node in traced.graph.nodes:
        # NB: the implementation of conv may not preserve the memory format,
        # unfortunately. The best we can do is just check that the placeholder
        # node is channels-last
        if node.op == 'placeholder':
            self.assertEqual(node.meta['tensor_meta'].memory_format, torch.channels_last)

def test_shape_prop_aggregate(self):
    """ShapeProp stores per-element metadata for a leaf module that returns a tuple."""
    class ReturnTwo(torch.nn.Module):
        def forward(self, x):
            return (3, torch.sum(x))

    class UnderTest(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.rt = ReturnTwo()

        def forward(self, x):
            return self.rt(x)

    ut = UnderTest()

    class RTTracer(torch.fx.Tracer):
        def is_leaf_module(self, m, module_qualified_name):
            return type(m) is ReturnTwo

    graph = RTTracer().trace(ut)
    mod = torch.fx.GraphModule(ut, graph)

    shape_prop.ShapeProp(mod).propagate(torch.rand(3, 4))

    for node in mod.graph.nodes:
        if node.op == 'call_module':
            self.assertIn('tensor_meta', node.meta)
            tensor_meta = node.meta['tensor_meta']
            self.assertEqual(tensor_meta[0], 3)
            self.assertEqual(tensor_meta[1].shape, torch.Size([]))

def test_shape_prop_layout_3d(self):
    """ShapeProp records contiguous vs. channels-last-3d memory format for a 3-D conv."""
    class ConvTest3d(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv_mod = torch.nn.Conv3d(5, 5, 3)

        def forward(self, x):
            return self.conv_mod(x)

    test_mod_3d = ConvTest3d()
    traced_3d = symbolic_trace(test_mod_3d)
    x_3d = torch.randn(5, 5, 224, 224, 15)
    shape_prop.ShapeProp(traced_3d).propagate(x_3d)
    for node in traced_3d.graph.nodes:
        self.assertIs(node.meta['tensor_meta'].memory_format, torch.contiguous_format)

    x_channels_last_3d = x_3d.contiguous(memory_format=torch.channels_last_3d)
    traced_3d.to(memory_format=torch.channels_last_3d)
    shape_prop.ShapeProp(traced_3d).propagate(x_channels_last_3d)
    for node in traced_3d.graph.nodes:
        # NB: the implementation of conv may not preserve the memory format,
        # unfortunately. The best we can do is just check that the placeholder
        # node is channels-last
        if node.op == 'placeholder':
            self.assertEqual(node.meta['tensor_meta'].memory_format, torch.channels_last_3d)

def test_nn_module_stack(self):
    """The first node carrying ``nn_module_stack`` lists the submodule path down to the conv."""
    class SubModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv_mod = torch.nn.Conv2d(64, 64, (3, 3), padding=1, bias=False)

        def forward(self, x):
            return self.conv_mod(x)

    class MyModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.sub_mod = SubModule()

        def forward(self, x):
            return self.sub_mod(x)

    m = MyModule()
    gm = torch.fx.symbolic_trace(m)

    mod_stack = {}
    expected_stack = [('sub_mod', ('sub_mod', type(m.sub_mod))),
                      ('sub_mod.conv_mod', ('sub_mod.conv_mod', type(m.sub_mod.conv_mod)))]
    for node in gm.graph.nodes:
        mod_stack = node.meta.get('nn_module_stack', {})
        if mod_stack:
            break
    stack_list = list(mod_stack.items())
    self.assertEqual(stack_list, expected_stack)

def test_transformer_preserves_nn_module_stack_for_get_attr(self):
    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(1, 1))

        def forward(self, x):
            return self.weight + x

    tracer = torch.fx.Tracer()
    graph = tracer.trace(M())
    gm = GraphModule(tracer.root, graph)
    for node in gm.graph.nodes:
        if node.op == 'get_attr':
            node.meta["nn_module_stack"] = "self"
node.meta["stack_trace"] = "stack_trace" 1841*da0073e9SAndroid Build Coastguard Worker node.meta["source_fn_stack"] = "source_fn_stack" 1842*da0073e9SAndroid Build Coastguard Worker new_gm = Transformer(gm).transform() 1843*da0073e9SAndroid Build Coastguard Worker for node in new_gm.graph.nodes: 1844*da0073e9SAndroid Build Coastguard Worker if node.op == 'get_attr': 1845*da0073e9SAndroid Build Coastguard Worker self.assertEqual(node.meta["nn_module_stack"], "self") 1846*da0073e9SAndroid Build Coastguard Worker self.assertEqual(node.meta["stack_trace"], "stack_trace") 1847*da0073e9SAndroid Build Coastguard Worker self.assertEqual(node.meta["source_fn_stack"], "source_fn_stack") 1848*da0073e9SAndroid Build Coastguard Worker 1849*da0073e9SAndroid Build Coastguard Worker def test_interpreter(self): 1850*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.nn.Module): 1851*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1852*da0073e9SAndroid Build Coastguard Worker super().__init__() 1853*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(3, 4)) 1854*da0073e9SAndroid Build Coastguard Worker self.linear = torch.nn.Linear(4, 5) 1855*da0073e9SAndroid Build Coastguard Worker 1856*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1857*da0073e9SAndroid Build Coastguard Worker return self.linear(x + self.param).clamp(min=0.0, max=1.0) 1858*da0073e9SAndroid Build Coastguard Worker 1859*da0073e9SAndroid Build Coastguard Worker m = MyModule() 1860*da0073e9SAndroid Build Coastguard Worker gm = torch.fx.symbolic_trace(m) 1861*da0073e9SAndroid Build Coastguard Worker 1862*da0073e9SAndroid Build Coastguard Worker interpreter = Interpreter(gm) 1863*da0073e9SAndroid Build Coastguard Worker input = torch.randn(3, 4) 1864*da0073e9SAndroid Build Coastguard Worker self.assertEqual(interpreter.run(input), gm(input)) 1865*da0073e9SAndroid Build Coastguard Worker self.assertEqual(interpreter.run(input), m(input)) 
1866*da0073e9SAndroid Build Coastguard Worker 1867*da0073e9SAndroid Build Coastguard Worker def test_interpreter_other_graph(self): 1868*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.nn.Module): 1869*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1870*da0073e9SAndroid Build Coastguard Worker super().__init__() 1871*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(3, 4)) 1872*da0073e9SAndroid Build Coastguard Worker self.linear = torch.nn.Linear(4, 5) 1873*da0073e9SAndroid Build Coastguard Worker 1874*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1875*da0073e9SAndroid Build Coastguard Worker return self.linear(x + self.param).clamp(min=0.0, max=1.0) 1876*da0073e9SAndroid Build Coastguard Worker 1877*da0073e9SAndroid Build Coastguard Worker m = MyModule() 1878*da0073e9SAndroid Build Coastguard Worker gm = torch.fx.symbolic_trace(m) 1879*da0073e9SAndroid Build Coastguard Worker 1880*da0073e9SAndroid Build Coastguard Worker interpreter = Interpreter(gm, graph=gm.graph) 1881*da0073e9SAndroid Build Coastguard Worker input = torch.randn(3, 4) 1882*da0073e9SAndroid Build Coastguard Worker self.assertEqual(interpreter.run(input), gm(input)) 1883*da0073e9SAndroid Build Coastguard Worker self.assertEqual(interpreter.run(input), m(input)) 1884*da0073e9SAndroid Build Coastguard Worker 1885*da0073e9SAndroid Build Coastguard Worker def test_interpreter_run_node_override(self): 1886*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.nn.Module): 1887*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1888*da0073e9SAndroid Build Coastguard Worker super().__init__() 1889*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(3, 4)) 1890*da0073e9SAndroid Build Coastguard Worker self.linear = torch.nn.Linear(4, 5) 1891*da0073e9SAndroid Build Coastguard Worker 1892*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 
1893*da0073e9SAndroid Build Coastguard Worker return self.linear(x + self.param).clamp(min=0.0, max=1.0) 1894*da0073e9SAndroid Build Coastguard Worker 1895*da0073e9SAndroid Build Coastguard Worker m = MyModule() 1896*da0073e9SAndroid Build Coastguard Worker gm = torch.fx.symbolic_trace(m) 1897*da0073e9SAndroid Build Coastguard Worker 1898*da0073e9SAndroid Build Coastguard Worker class RunNodeInterpreter(Interpreter): 1899*da0073e9SAndroid Build Coastguard Worker def __init__(self, module): 1900*da0073e9SAndroid Build Coastguard Worker super().__init__(module) 1901*da0073e9SAndroid Build Coastguard Worker 1902*da0073e9SAndroid Build Coastguard Worker def run_node(self, n : Node) -> Any: 1903*da0073e9SAndroid Build Coastguard Worker result = super().run_node(n) 1904*da0073e9SAndroid Build Coastguard Worker n.cached_value = result 1905*da0073e9SAndroid Build Coastguard Worker return result 1906*da0073e9SAndroid Build Coastguard Worker 1907*da0073e9SAndroid Build Coastguard Worker input = torch.randn(3, 4) 1908*da0073e9SAndroid Build Coastguard Worker RunNodeInterpreter(gm).run(input) 1909*da0073e9SAndroid Build Coastguard Worker for node in gm.graph.nodes: 1910*da0073e9SAndroid Build Coastguard Worker assert hasattr(node, 'cached_value') 1911*da0073e9SAndroid Build Coastguard Worker 1912*da0073e9SAndroid Build Coastguard Worker def test_interpreter_onthefly_swap(self): 1913*da0073e9SAndroid Build Coastguard Worker 1914*da0073e9SAndroid Build Coastguard Worker def fn(x): 1915*da0073e9SAndroid Build Coastguard Worker return torch.sigmoid(x).neg() 1916*da0073e9SAndroid Build Coastguard Worker 1917*da0073e9SAndroid Build Coastguard Worker gm = torch.fx.symbolic_trace(fn) 1918*da0073e9SAndroid Build Coastguard Worker 1919*da0073e9SAndroid Build Coastguard Worker class NegSigmSwapInterpreter(Interpreter): 1920*da0073e9SAndroid Build Coastguard Worker def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any: 1921*da0073e9SAndroid Build Coastguard Worker 
if target == torch.sigmoid: 1922*da0073e9SAndroid Build Coastguard Worker return torch.neg(*args, **kwargs) 1923*da0073e9SAndroid Build Coastguard Worker return super().call_function(n) # noqa: F821 1924*da0073e9SAndroid Build Coastguard Worker 1925*da0073e9SAndroid Build Coastguard Worker def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any: 1926*da0073e9SAndroid Build Coastguard Worker if target == 'neg': 1927*da0073e9SAndroid Build Coastguard Worker call_self, *args_tail = args 1928*da0073e9SAndroid Build Coastguard Worker return call_self.sigmoid(*args_tail, **kwargs) 1929*da0073e9SAndroid Build Coastguard Worker return super().call_method(n) # noqa: F821 1930*da0073e9SAndroid Build Coastguard Worker 1931*da0073e9SAndroid Build Coastguard Worker input = torch.randn(3, 4) 1932*da0073e9SAndroid Build Coastguard Worker result = NegSigmSwapInterpreter(gm).run(input) 1933*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, torch.neg(input).sigmoid()) 1934*da0073e9SAndroid Build Coastguard Worker 1935*da0073e9SAndroid Build Coastguard Worker def test_interpreter_partial_eval(self): 1936*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.nn.Module): 1937*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1938*da0073e9SAndroid Build Coastguard Worker super().__init__() 1939*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(3, 4)) 1940*da0073e9SAndroid Build Coastguard Worker self.linear = torch.nn.Linear(4, 5) 1941*da0073e9SAndroid Build Coastguard Worker 1942*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1943*da0073e9SAndroid Build Coastguard Worker return self.linear(x + self.param).clamp(min=0.0, max=1.0) 1944*da0073e9SAndroid Build Coastguard Worker 1945*da0073e9SAndroid Build Coastguard Worker gm = torch.fx.symbolic_trace(MyModule()) 1946*da0073e9SAndroid Build Coastguard Worker interp = Interpreter(gm) 1947*da0073e9SAndroid Build Coastguard Worker 
env = {} 1948*da0073e9SAndroid Build Coastguard Worker for node in gm.graph.nodes: 1949*da0073e9SAndroid Build Coastguard Worker if node.op == 'call_module' and node.target == 'linear': 1950*da0073e9SAndroid Build Coastguard Worker env[node] = torch.arange(0, 12, 1).reshape(3, 4) - 6.0 1951*da0073e9SAndroid Build Coastguard Worker break 1952*da0073e9SAndroid Build Coastguard Worker assert len(env) == 1 1953*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 4) 1954*da0073e9SAndroid Build Coastguard Worker result = interp.run(x, initial_env=env) 1955*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, (torch.arange(0, 12, 1).reshape(3, 4) - 6.0).clamp(0.0, 1.0)) 1956*da0073e9SAndroid Build Coastguard Worker 1957*da0073e9SAndroid Build Coastguard Worker def test_interpreter_star_args(self): 1958*da0073e9SAndroid Build Coastguard Worker def with_star_args(x, *args): 1959*da0073e9SAndroid Build Coastguard Worker return x + args[0] 1960*da0073e9SAndroid Build Coastguard Worker 1961*da0073e9SAndroid Build Coastguard Worker gm = torch.fx.symbolic_trace(with_star_args) 1962*da0073e9SAndroid Build Coastguard Worker interp = Interpreter(gm) 1963*da0073e9SAndroid Build Coastguard Worker result = interp.run(torch.ones(3, 4), torch.ones(3, 4), torch.rand(3, 4)) 1964*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, torch.ones(3, 4) * 2.0) 1965*da0073e9SAndroid Build Coastguard Worker 1966*da0073e9SAndroid Build Coastguard Worker @skipIfNoTorchVision 1967*da0073e9SAndroid Build Coastguard Worker def test_interpreter_noop_resnet18(self): 1968*da0073e9SAndroid Build Coastguard Worker rn18 = torchvision_models.resnet18() 1969*da0073e9SAndroid Build Coastguard Worker transformed = torch.fx.Transformer(symbolic_trace(rn18)).transform() 1970*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(5, 3, 224, 224) 1971*da0073e9SAndroid Build Coastguard Worker self.assertEqual(transformed(inp), rn18(inp)) 1972*da0073e9SAndroid Build Coastguard 
Worker 1973*da0073e9SAndroid Build Coastguard Worker @skipIfNoTorchVision 1974*da0073e9SAndroid Build Coastguard Worker def test_interpreter_gc_values(self): 1975*da0073e9SAndroid Build Coastguard Worker rn18 = torchvision_models.resnet18() 1976*da0073e9SAndroid Build Coastguard Worker interp = Interpreter(symbolic_trace(rn18)) 1977*da0073e9SAndroid Build Coastguard Worker inp = torch.rand(5, 3, 224, 224) 1978*da0073e9SAndroid Build Coastguard Worker out = interp.run(inp) 1979*da0073e9SAndroid Build Coastguard Worker env_key_names = {n.name for n in interp.env.keys()} 1980*da0073e9SAndroid Build Coastguard Worker self.assertEqual(env_key_names, {'output'}) 1981*da0073e9SAndroid Build Coastguard Worker 1982*da0073e9SAndroid Build Coastguard Worker def test_interpreter_default_args(self): 1983*da0073e9SAndroid Build Coastguard Worker class Model(torch.nn.Module): 1984*da0073e9SAndroid Build Coastguard Worker def forward(self, x, y=3.14159): 1985*da0073e9SAndroid Build Coastguard Worker return x + y 1986*da0073e9SAndroid Build Coastguard Worker 1987*da0073e9SAndroid Build Coastguard Worker model = Model() 1988*da0073e9SAndroid Build Coastguard Worker gm = torch.fx.symbolic_trace(model) 1989*da0073e9SAndroid Build Coastguard Worker 1990*da0073e9SAndroid Build Coastguard Worker interp = Interpreter(gm) 1991*da0073e9SAndroid Build Coastguard Worker x = torch.randn(5, 3) 1992*da0073e9SAndroid Build Coastguard Worker out = interp.run(x) 1993*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(out, x + 3.14159) 1994*da0073e9SAndroid Build Coastguard Worker 1995*da0073e9SAndroid Build Coastguard Worker def test_interpreter_not_enough_args(self): 1996*da0073e9SAndroid Build Coastguard Worker class Model(torch.nn.Module): 1997*da0073e9SAndroid Build Coastguard Worker def forward(self, x, y): 1998*da0073e9SAndroid Build Coastguard Worker return x + y 1999*da0073e9SAndroid Build Coastguard Worker 2000*da0073e9SAndroid Build Coastguard Worker model = Model() 
2001*da0073e9SAndroid Build Coastguard Worker gm = torch.fx.symbolic_trace(model) 2002*da0073e9SAndroid Build Coastguard Worker 2003*da0073e9SAndroid Build Coastguard Worker interp = Interpreter(gm) 2004*da0073e9SAndroid Build Coastguard Worker x = torch.randn(5, 3) 2005*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 2006*da0073e9SAndroid Build Coastguard Worker 'Expected positional argument for parameter y, but one was not passed in'): 2007*da0073e9SAndroid Build Coastguard Worker out = interp.run(x) 2008*da0073e9SAndroid Build Coastguard Worker 2009*da0073e9SAndroid Build Coastguard Worker def test_transformer_noop(self): 2010*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.nn.Module): 2011*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2012*da0073e9SAndroid Build Coastguard Worker super().__init__() 2013*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(3, 4)) 2014*da0073e9SAndroid Build Coastguard Worker self.linear = torch.nn.Linear(4, 5) 2015*da0073e9SAndroid Build Coastguard Worker 2016*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2017*da0073e9SAndroid Build Coastguard Worker return self.linear(x + self.param).clamp(min=0.0, max=1.0) 2018*da0073e9SAndroid Build Coastguard Worker 2019*da0073e9SAndroid Build Coastguard Worker m = MyModule() 2020*da0073e9SAndroid Build Coastguard Worker gm = torch.fx.symbolic_trace(m) 2021*da0073e9SAndroid Build Coastguard Worker 2022*da0073e9SAndroid Build Coastguard Worker new_gm = Transformer(gm).transform() 2023*da0073e9SAndroid Build Coastguard Worker 2024*da0073e9SAndroid Build Coastguard Worker input = torch.randn(3, 4) 2025*da0073e9SAndroid Build Coastguard Worker self.assertEqual(new_gm(input), gm(input)) 2026*da0073e9SAndroid Build Coastguard Worker 2027*da0073e9SAndroid Build Coastguard Worker def test_transformer_op_swap(self): 2028*da0073e9SAndroid Build Coastguard Worker 2029*da0073e9SAndroid 
Build Coastguard Worker def fn(x): 2030*da0073e9SAndroid Build Coastguard Worker return torch.sigmoid(x).neg() 2031*da0073e9SAndroid Build Coastguard Worker 2032*da0073e9SAndroid Build Coastguard Worker gm = torch.fx.symbolic_trace(fn) 2033*da0073e9SAndroid Build Coastguard Worker 2034*da0073e9SAndroid Build Coastguard Worker class NegSigmSwapXformer(Transformer): 2035*da0073e9SAndroid Build Coastguard Worker def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any: 2036*da0073e9SAndroid Build Coastguard Worker if target == torch.sigmoid: 2037*da0073e9SAndroid Build Coastguard Worker return torch.neg(*args, **kwargs) 2038*da0073e9SAndroid Build Coastguard Worker return super().call_function(n) # noqa: F821 2039*da0073e9SAndroid Build Coastguard Worker 2040*da0073e9SAndroid Build Coastguard Worker def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any: 2041*da0073e9SAndroid Build Coastguard Worker if target == 'neg': 2042*da0073e9SAndroid Build Coastguard Worker call_self, *args_tail = args 2043*da0073e9SAndroid Build Coastguard Worker return call_self.sigmoid(*args_tail, **kwargs) 2044*da0073e9SAndroid Build Coastguard Worker return super().call_method(n) # noqa: F821 2045*da0073e9SAndroid Build Coastguard Worker 2046*da0073e9SAndroid Build Coastguard Worker transformed = NegSigmSwapXformer(gm).transform() 2047*da0073e9SAndroid Build Coastguard Worker input = torch.randn(3, 4) 2048*da0073e9SAndroid Build Coastguard Worker self.assertEqual(transformed(input), torch.neg(input).sigmoid()) 2049*da0073e9SAndroid Build Coastguard Worker 2050*da0073e9SAndroid Build Coastguard Worker def test_transformer_multi_outputs(self): 2051*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.nn.Module): 2052*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2053*da0073e9SAndroid Build Coastguard Worker super().__init__() 2054*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(3, 4)) 
2055*da0073e9SAndroid Build Coastguard Worker self.linear = torch.nn.Linear(4, 5) 2056*da0073e9SAndroid Build Coastguard Worker 2057*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2058*da0073e9SAndroid Build Coastguard Worker x = x + self.param 2059*da0073e9SAndroid Build Coastguard Worker out = self.linear(x) 2060*da0073e9SAndroid Build Coastguard Worker return x, out 2061*da0073e9SAndroid Build Coastguard Worker 2062*da0073e9SAndroid Build Coastguard Worker m = MyModule() 2063*da0073e9SAndroid Build Coastguard Worker gm = torch.fx.symbolic_trace(m) 2064*da0073e9SAndroid Build Coastguard Worker 2065*da0073e9SAndroid Build Coastguard Worker new_gm = Transformer(gm).transform() 2066*da0073e9SAndroid Build Coastguard Worker 2067*da0073e9SAndroid Build Coastguard Worker input = torch.randn(3, 4) 2068*da0073e9SAndroid Build Coastguard Worker self.assertEqual(new_gm(input), gm(input)) 2069*da0073e9SAndroid Build Coastguard Worker 2070*da0073e9SAndroid Build Coastguard Worker def test_fn_type_annotations(self): 2071*da0073e9SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 2072*da0073e9SAndroid Build Coastguard Worker def forward(self, p : Pair, z : torch.Tensor, i : int) -> Dict[str, torch.Tensor]: 2073*da0073e9SAndroid Build Coastguard Worker return {'a': p.x + p.y + z + i} 2074*da0073e9SAndroid Build Coastguard Worker 2075*da0073e9SAndroid Build Coastguard Worker foo_scripted = torch.jit.script(Foo()) 2076*da0073e9SAndroid Build Coastguard Worker foo_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3) 2077*da0073e9SAndroid Build Coastguard Worker 2078*da0073e9SAndroid Build Coastguard Worker fxed = symbolic_trace(Foo()) 2079*da0073e9SAndroid Build Coastguard Worker fxed_scripted = torch.jit.script(fxed) 2080*da0073e9SAndroid Build Coastguard Worker fxed_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3) 2081*da0073e9SAndroid Build Coastguard Worker 2082*da0073e9SAndroid Build Coastguard Worker def 
test_fn_type_annotation_empty(self): 2083*da0073e9SAndroid Build Coastguard Worker def forward(a : List[torch.Tensor]): 2084*da0073e9SAndroid Build Coastguard Worker return a[0] 2085*da0073e9SAndroid Build Coastguard Worker torch.jit.script(symbolic_trace(forward)) 2086*da0073e9SAndroid Build Coastguard Worker 2087*da0073e9SAndroid Build Coastguard Worker def test_wrapped_method(self): 2088*da0073e9SAndroid Build Coastguard Worker def wrap_with_relu(fn): 2089*da0073e9SAndroid Build Coastguard Worker @functools.wraps(fn) 2090*da0073e9SAndroid Build Coastguard Worker def wrapper(*args, **kwargs): 2091*da0073e9SAndroid Build Coastguard Worker return torch.relu(fn(*args, **kwargs)) 2092*da0073e9SAndroid Build Coastguard Worker return wrapper 2093*da0073e9SAndroid Build Coastguard Worker 2094*da0073e9SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 2095*da0073e9SAndroid Build Coastguard Worker @wrap_with_relu 2096*da0073e9SAndroid Build Coastguard Worker def forward(self, x, w): 2097*da0073e9SAndroid Build Coastguard Worker return torch.matmul(x, w) 2098*da0073e9SAndroid Build Coastguard Worker 2099*da0073e9SAndroid Build Coastguard Worker f = Foo() 2100*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(f) 2101*da0073e9SAndroid Build Coastguard Worker x, w = torch.rand(3, 4), torch.rand(4, 4) 2102*da0073e9SAndroid Build Coastguard Worker self.assertTrue(any(n.target == torch.relu for n in traced.graph.nodes)) 2103*da0073e9SAndroid Build Coastguard Worker 2104*da0073e9SAndroid Build Coastguard Worker def test_empty_graph_codegen(self): 2105*da0073e9SAndroid Build Coastguard Worker graph = torch.fx.Graph() 2106*da0073e9SAndroid Build Coastguard Worker gm = torch.fx.GraphModule(torch.nn.Module(), graph) 2107*da0073e9SAndroid Build Coastguard Worker self.assertEqual(gm(), None) 2108*da0073e9SAndroid Build Coastguard Worker 2109*da0073e9SAndroid Build Coastguard Worker def test_sequential(self): 2110*da0073e9SAndroid Build Coastguard Worker m = 
torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1)) 2111*da0073e9SAndroid Build Coastguard Worker gm = torch.fx.symbolic_trace(m) 2112*da0073e9SAndroid Build Coastguard Worker gm_copy = copy.deepcopy(gm) 2113*da0073e9SAndroid Build Coastguard Worker 2114*da0073e9SAndroid Build Coastguard Worker def test_ctx_mgr(self): 2115*da0073e9SAndroid Build Coastguard Worker @contextlib.contextmanager 2116*da0073e9SAndroid Build Coastguard Worker def do_nothing(): 2117*da0073e9SAndroid Build Coastguard Worker yield 2118*da0073e9SAndroid Build Coastguard Worker 2119*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 2120*da0073e9SAndroid Build Coastguard Worker @do_nothing() 2121*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2122*da0073e9SAndroid Build Coastguard Worker return torch.relu(x) 2123*da0073e9SAndroid Build Coastguard Worker 2124*da0073e9SAndroid Build Coastguard Worker m = M() 2125*da0073e9SAndroid Build Coastguard Worker self.checkGraphModule(m, (torch.rand(3, 4),)) 2126*da0073e9SAndroid Build Coastguard Worker 2127*da0073e9SAndroid Build Coastguard Worker def test_typename_print(self): 2128*da0073e9SAndroid Build Coastguard Worker graph : torch.fx.Graph = torch.fx.Graph() 2129*da0073e9SAndroid Build Coastguard Worker x : torch.fx.Node = graph.create_node('placeholder', 'x') 2130*da0073e9SAndroid Build Coastguard Worker b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,), 2131*da0073e9SAndroid Build Coastguard Worker type_expr=List[float]) 2132*da0073e9SAndroid Build Coastguard Worker output : torch.fx.Node = graph.output(b) 2133*da0073e9SAndroid Build Coastguard Worker 2134*da0073e9SAndroid Build Coastguard Worker self.assertTrue('typing.List[float]' in str(graph)) 2135*da0073e9SAndroid Build Coastguard Worker 2136*da0073e9SAndroid Build Coastguard Worker def test_layout(self): 2137*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 2138*da0073e9SAndroid Build Coastguard Worker def 
forward(self, x): 2139*da0073e9SAndroid Build Coastguard Worker return torch.empty_like(x, layout=torch.strided, pin_memory=False).fill_(0) 2140*da0073e9SAndroid Build Coastguard Worker 2141*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(M()) 2142*da0073e9SAndroid Build Coastguard Worker x = torch.rand(5, 9, 3, 4) 2143*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced(x), torch.zeros_like(x)) 2144*da0073e9SAndroid Build Coastguard Worker 2145*da0073e9SAndroid Build Coastguard Worker def test_ellipsis(self): 2146*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 2147*da0073e9SAndroid Build Coastguard Worker def forward(self, x, y): 2148*da0073e9SAndroid Build Coastguard Worker return x + y[:, 1:10, ...] 2149*da0073e9SAndroid Build Coastguard Worker 2150*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(M()) 2151*da0073e9SAndroid Build Coastguard Worker x, y = torch.rand(5, 9, 3, 4), torch.rand(5, 15, 3, 4) 2152*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced(x, y), x + y[:, 1:10, ...]) 2153*da0073e9SAndroid Build Coastguard Worker 2154*da0073e9SAndroid Build Coastguard Worker def test_inf_nan(self): 2155*da0073e9SAndroid Build Coastguard Worker class FooMod(torch.nn.Module): 2156*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2157*da0073e9SAndroid Build Coastguard Worker return x + float('inf'), x + float('-inf'), x + float('nan') 2158*da0073e9SAndroid Build Coastguard Worker 2159*da0073e9SAndroid Build Coastguard Worker fm = FooMod() 2160*da0073e9SAndroid Build Coastguard Worker self.checkGraphModule(fm, (torch.rand(3, 4),)) 2161*da0073e9SAndroid Build Coastguard Worker 2162*da0073e9SAndroid Build Coastguard Worker def test_inf_nan_kwds(self): 2163*da0073e9SAndroid Build Coastguard Worker graph : torch.fx.Graph = torch.fx.Graph() 2164*da0073e9SAndroid Build Coastguard Worker x : torch.fx.Node = graph.create_node('placeholder', 'x') 2165*da0073e9SAndroid Build 
Coastguard Worker b : torch.fx.Node = graph.create_node('call_function', operator.add, (x, float('inf')), {}, name='inf') 2166*da0073e9SAndroid Build Coastguard Worker c : torch.fx.Node = graph.create_node('call_function', operator.add, (x, float('nan')), {}, name='nan') 2167*da0073e9SAndroid Build Coastguard Worker graph.output((b, c)) 2168*da0073e9SAndroid Build Coastguard Worker 2169*da0073e9SAndroid Build Coastguard Worker gm = torch.fx.GraphModule(torch.nn.Module(), graph) 2170*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4) 2171*da0073e9SAndroid Build Coastguard Worker self.assertEqual(gm(x), (x + float('inf'), x + float('nan'))) 2172*da0073e9SAndroid Build Coastguard Worker 2173*da0073e9SAndroid Build Coastguard Worker def test_deepcopy_recursion_depth(self): 2174*da0073e9SAndroid Build Coastguard Worker depth = sys.getrecursionlimit() + 20 2175*da0073e9SAndroid Build Coastguard Worker 2176*da0073e9SAndroid Build Coastguard Worker g = torch.fx.Graph() 2177*da0073e9SAndroid Build Coastguard Worker x = g.placeholder('x') 2178*da0073e9SAndroid Build Coastguard Worker for i in range(depth): 2179*da0073e9SAndroid Build Coastguard Worker x = g.call_function(torch.relu, (x,)) 2180*da0073e9SAndroid Build Coastguard Worker g.output(x) 2181*da0073e9SAndroid Build Coastguard Worker 2182*da0073e9SAndroid Build Coastguard Worker copied_graph = copy.deepcopy(g) 2183*da0073e9SAndroid Build Coastguard Worker 2184*da0073e9SAndroid Build Coastguard Worker val_map = {} 2185*da0073e9SAndroid Build Coastguard Worker for orig_node, new_node in zip(g.nodes, copied_graph.nodes): 2186*da0073e9SAndroid Build Coastguard Worker val_map[orig_node] = new_node 2187*da0073e9SAndroid Build Coastguard Worker 2188*da0073e9SAndroid Build Coastguard Worker for orig_node, new_node in zip(g.nodes, copied_graph.nodes): 2189*da0073e9SAndroid Build Coastguard Worker orig_users = set(orig_node.users.keys()) 2190*da0073e9SAndroid Build Coastguard Worker orig_users_equiv = {val_map[u] for 
u in orig_users} 2191*da0073e9SAndroid Build Coastguard Worker new_users = set(new_node.users.keys()) 2192*da0073e9SAndroid Build Coastguard Worker self.assertEqual(orig_users_equiv, new_users) 2193*da0073e9SAndroid Build Coastguard Worker 2194*da0073e9SAndroid Build Coastguard Worker @skipIfNoTorchVision 2195*da0073e9SAndroid Build Coastguard Worker def test_replace_uses(self): 2196*da0073e9SAndroid Build Coastguard Worker rn18 = torchvision_models.resnet18() 2197*da0073e9SAndroid Build Coastguard Worker 2198*da0073e9SAndroid Build Coastguard Worker class LowerReluTracer(torch.fx.Tracer): 2199*da0073e9SAndroid Build Coastguard Worker def is_leaf_module(self, m : torch.nn.Module, qualname : str): 2200*da0073e9SAndroid Build Coastguard Worker if isinstance(m, torch.nn.ReLU): 2201*da0073e9SAndroid Build Coastguard Worker return False 2202*da0073e9SAndroid Build Coastguard Worker return super().is_leaf_module(m, qualname) 2203*da0073e9SAndroid Build Coastguard Worker 2204*da0073e9SAndroid Build Coastguard Worker rn18_traced = GraphModule(rn18, LowerReluTracer().trace(rn18)) 2205*da0073e9SAndroid Build Coastguard Worker 2206*da0073e9SAndroid Build Coastguard Worker to_erase = [] 2207*da0073e9SAndroid Build Coastguard Worker for node in rn18_traced.graph.nodes: 2208*da0073e9SAndroid Build Coastguard Worker if node.op == 'call_function' and node.target in [torch.relu, torch.nn.functional.relu]: 2209*da0073e9SAndroid Build Coastguard Worker kwargs = node.kwargs.copy() 2210*da0073e9SAndroid Build Coastguard Worker # Neg doesn't have in-place 2211*da0073e9SAndroid Build Coastguard Worker kwargs.pop('inplace') 2212*da0073e9SAndroid Build Coastguard Worker with rn18_traced.graph.inserting_before(node): 2213*da0073e9SAndroid Build Coastguard Worker new_node = rn18_traced.graph.call_function( 2214*da0073e9SAndroid Build Coastguard Worker the_function=torch.neg, args=node.args, kwargs=node.kwargs) 2215*da0073e9SAndroid Build Coastguard Worker 
node.replace_all_uses_with(replace_with=new_node) 2216*da0073e9SAndroid Build Coastguard Worker to_erase.append(node) 2217*da0073e9SAndroid Build Coastguard Worker 2218*da0073e9SAndroid Build Coastguard Worker for node in to_erase: 2219*da0073e9SAndroid Build Coastguard Worker rn18_traced.graph.erase_node(node) 2220*da0073e9SAndroid Build Coastguard Worker 2221*da0073e9SAndroid Build Coastguard Worker def test_replace_input(self): 2222*da0073e9SAndroid Build Coastguard Worker graph : torch.fx.Graph = torch.fx.Graph() 2223*da0073e9SAndroid Build Coastguard Worker x : torch.fx.Node = graph.create_node('placeholder', 'x') 2224*da0073e9SAndroid Build Coastguard Worker y : torch.fx.Node = graph.create_node('placeholder', 'y') 2225*da0073e9SAndroid Build Coastguard Worker b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,)) 2226*da0073e9SAndroid Build Coastguard Worker output : torch.fx.Node = graph.output(b) 2227*da0073e9SAndroid Build Coastguard Worker 2228*da0073e9SAndroid Build Coastguard Worker b.replace_input_with(x, y) 2229*da0073e9SAndroid Build Coastguard Worker 2230*da0073e9SAndroid Build Coastguard Worker gm = torch.fx.GraphModule(torch.nn.Module(), graph) 2231*da0073e9SAndroid Build Coastguard Worker 2232*da0073e9SAndroid Build Coastguard Worker input_x = torch.randn(33, 44) 2233*da0073e9SAndroid Build Coastguard Worker input_y = torch.randn(11, 22) 2234*da0073e9SAndroid Build Coastguard Worker self.assertEqual(gm(input_x, input_y), torch.relu(input_y)) 2235*da0073e9SAndroid Build Coastguard Worker 2236*da0073e9SAndroid Build Coastguard Worker def test_insertion_point(self): 2237*da0073e9SAndroid Build Coastguard Worker graph : torch.fx.Graph = torch.fx.Graph() 2238*da0073e9SAndroid Build Coastguard Worker x : torch.fx.Node = graph.create_node('placeholder', 'x') 2239*da0073e9SAndroid Build Coastguard Worker b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,)) 2240*da0073e9SAndroid Build 
Coastguard Worker output : torch.fx.Node = graph.output(b) 2241*da0073e9SAndroid Build Coastguard Worker 2242*da0073e9SAndroid Build Coastguard Worker with graph.inserting_before(b): 2243*da0073e9SAndroid Build Coastguard Worker neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,)) 2244*da0073e9SAndroid Build Coastguard Worker _, *relu_args = b.args 2245*da0073e9SAndroid Build Coastguard Worker b.args = (neg, *relu_args) 2246*da0073e9SAndroid Build Coastguard Worker 2247*da0073e9SAndroid Build Coastguard Worker gm = torch.fx.GraphModule(torch.nn.Module(), graph) 2248*da0073e9SAndroid Build Coastguard Worker 2249*da0073e9SAndroid Build Coastguard Worker input = torch.randn(33, 44) 2250*da0073e9SAndroid Build Coastguard Worker self.assertEqual(gm(input), torch.relu(torch.neg(input))) 2251*da0073e9SAndroid Build Coastguard Worker 2252*da0073e9SAndroid Build Coastguard Worker def test_update_args_api(self): 2253*da0073e9SAndroid Build Coastguard Worker graph : torch.fx.Graph = torch.fx.Graph() 2254*da0073e9SAndroid Build Coastguard Worker x : torch.fx.Node = graph.create_node('placeholder', 'x') 2255*da0073e9SAndroid Build Coastguard Worker y : torch.fx.Node = graph.create_node('placeholder', 'y') 2256*da0073e9SAndroid Build Coastguard Worker b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,)) 2257*da0073e9SAndroid Build Coastguard Worker output : torch.fx.Node = graph.output(b) 2258*da0073e9SAndroid Build Coastguard Worker 2259*da0073e9SAndroid Build Coastguard Worker orig_gm = torch.fx.GraphModule(torch.nn.Module(), graph) 2260*da0073e9SAndroid Build Coastguard Worker inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5) 2261*da0073e9SAndroid Build Coastguard Worker self.assertEqual(orig_gm(inp_x, inp_y), torch.relu(inp_x)) 2262*da0073e9SAndroid Build Coastguard Worker 2263*da0073e9SAndroid Build Coastguard Worker b.update_arg(0, y) 2264*da0073e9SAndroid Build Coastguard Worker new_gm = 
torch.fx.GraphModule(torch.nn.Module(), graph) 2265*da0073e9SAndroid Build Coastguard Worker self.assertEqual(new_gm(inp_x, inp_y), torch.relu(inp_y)) 2266*da0073e9SAndroid Build Coastguard Worker 2267*da0073e9SAndroid Build Coastguard Worker def test_update_kwargs_api(self): 2268*da0073e9SAndroid Build Coastguard Worker graph : torch.fx.Graph = torch.fx.Graph() 2269*da0073e9SAndroid Build Coastguard Worker x : torch.fx.Node = graph.create_node('placeholder', 'x') 2270*da0073e9SAndroid Build Coastguard Worker y : torch.fx.Node = graph.create_node('placeholder', 'y') 2271*da0073e9SAndroid Build Coastguard Worker b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, kwargs={'input': x}) 2272*da0073e9SAndroid Build Coastguard Worker output : torch.fx.Node = graph.output(b) 2273*da0073e9SAndroid Build Coastguard Worker 2274*da0073e9SAndroid Build Coastguard Worker orig_gm = torch.fx.GraphModule(torch.nn.Module(), graph) 2275*da0073e9SAndroid Build Coastguard Worker inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5) 2276*da0073e9SAndroid Build Coastguard Worker self.assertEqual(orig_gm(inp_x, inp_y), torch.relu(inp_x)) 2277*da0073e9SAndroid Build Coastguard Worker 2278*da0073e9SAndroid Build Coastguard Worker b.update_kwarg('input', y) 2279*da0073e9SAndroid Build Coastguard Worker new_gm = torch.fx.GraphModule(torch.nn.Module(), graph) 2280*da0073e9SAndroid Build Coastguard Worker self.assertEqual(new_gm(inp_x, inp_y), torch.relu(inp_y)) 2281*da0073e9SAndroid Build Coastguard Worker 2282*da0073e9SAndroid Build Coastguard Worker def test_immutable_list_pytree_ops(self): 2283*da0073e9SAndroid Build Coastguard Worker rand_tensor = torch.randn(5, 3) 2284*da0073e9SAndroid Build Coastguard Worker l = immutable_list([3, [rand_tensor, 42]]) 2285*da0073e9SAndroid Build Coastguard Worker 2286*da0073e9SAndroid Build Coastguard Worker flattened, spec = pytree.tree_flatten(l) 2287*da0073e9SAndroid Build Coastguard Worker assert flattened == [3, rand_tensor, 42] 
2288*da0073e9SAndroid Build Coastguard Worker 2289*da0073e9SAndroid Build Coastguard Worker unflattened = pytree.tree_unflatten(flattened, spec) 2290*da0073e9SAndroid Build Coastguard Worker assert unflattened == l 2291*da0073e9SAndroid Build Coastguard Worker assert isinstance(unflattened, immutable_list) 2292*da0073e9SAndroid Build Coastguard Worker 2293*da0073e9SAndroid Build Coastguard Worker def test_immutable_dict_pytree_ops(self): 2294*da0073e9SAndroid Build Coastguard Worker rand_tensor = torch.randn(5, 3) 2295*da0073e9SAndroid Build Coastguard Worker d = immutable_dict({'a': 3, 'b': [rand_tensor, 42]}) 2296*da0073e9SAndroid Build Coastguard Worker 2297*da0073e9SAndroid Build Coastguard Worker flattened, spec = pytree.tree_flatten(d) 2298*da0073e9SAndroid Build Coastguard Worker assert flattened == [3, rand_tensor, 42] 2299*da0073e9SAndroid Build Coastguard Worker 2300*da0073e9SAndroid Build Coastguard Worker unflattened = pytree.tree_unflatten(flattened, spec) 2301*da0073e9SAndroid Build Coastguard Worker assert unflattened == d 2302*da0073e9SAndroid Build Coastguard Worker assert isinstance(unflattened, immutable_dict) 2303*da0073e9SAndroid Build Coastguard Worker 2304*da0073e9SAndroid Build Coastguard Worker def test_move_before(self): 2305*da0073e9SAndroid Build Coastguard Worker graph : torch.fx.Graph = torch.fx.Graph() 2306*da0073e9SAndroid Build Coastguard Worker x : torch.fx.Node = graph.create_node('placeholder', 'x') 2307*da0073e9SAndroid Build Coastguard Worker b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,)) 2308*da0073e9SAndroid Build Coastguard Worker output : torch.fx.Node = graph.output(b) 2309*da0073e9SAndroid Build Coastguard Worker 2310*da0073e9SAndroid Build Coastguard Worker neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,)) 2311*da0073e9SAndroid Build Coastguard Worker _, *relu_args = b.args 2312*da0073e9SAndroid Build Coastguard Worker b.args = (neg, *relu_args) 
2313*da0073e9SAndroid Build Coastguard Worker b.prepend(neg) 2314*da0073e9SAndroid Build Coastguard Worker 2315*da0073e9SAndroid Build Coastguard Worker gm = torch.fx.GraphModule(torch.nn.Module(), graph) 2316*da0073e9SAndroid Build Coastguard Worker 2317*da0073e9SAndroid Build Coastguard Worker input = torch.randn(33, 44) 2318*da0073e9SAndroid Build Coastguard Worker self.assertEqual(gm(input), torch.relu(torch.neg(input))) 2319*da0073e9SAndroid Build Coastguard Worker 2320*da0073e9SAndroid Build Coastguard Worker def test_prepend_self(self): 2321*da0073e9SAndroid Build Coastguard Worker graph : torch.fx.Graph = torch.fx.Graph() 2322*da0073e9SAndroid Build Coastguard Worker x : torch.fx.Node = graph.create_node('placeholder', 'x') 2323*da0073e9SAndroid Build Coastguard Worker b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,)) 2324*da0073e9SAndroid Build Coastguard Worker output : torch.fx.Node = graph.output(b) 2325*da0073e9SAndroid Build Coastguard Worker 2326*da0073e9SAndroid Build Coastguard Worker b.prepend(b) 2327*da0073e9SAndroid Build Coastguard Worker x.append(b) 2328*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(graph.nodes), 3) 2329*da0073e9SAndroid Build Coastguard Worker 2330*da0073e9SAndroid Build Coastguard Worker def test_erase_node_error(self): 2331*da0073e9SAndroid Build Coastguard Worker st = SimpleTest() 2332*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(st) 2333*da0073e9SAndroid Build Coastguard Worker 2334*da0073e9SAndroid Build Coastguard Worker for node in traced.graph.nodes: 2335*da0073e9SAndroid Build Coastguard Worker # Test deleting with uses both in another Node and at the output 2336*da0073e9SAndroid Build Coastguard Worker if node.target in [operator.add, torch.relu]: 2337*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'but it still had .* users in the graph'): 2338*da0073e9SAndroid Build Coastguard Worker 
traced.graph.erase_node(node) 2339*da0073e9SAndroid Build Coastguard Worker 2340*da0073e9SAndroid Build Coastguard Worker def test_copy_it(self): 2341*da0073e9SAndroid Build Coastguard Worker d = immutable_dict([(3, 4), (5, 6)]) 2342*da0073e9SAndroid Build Coastguard Worker l = immutable_list([(3, 4), (5, 6)]) 2343*da0073e9SAndroid Build Coastguard Worker 2344*da0073e9SAndroid Build Coastguard Worker self.assertEqual(d, deepcopy(d)) 2345*da0073e9SAndroid Build Coastguard Worker self.assertEqual(l, deepcopy(l)) 2346*da0073e9SAndroid Build Coastguard Worker 2347*da0073e9SAndroid Build Coastguard Worker def test_get_torch_func_signature(self): 2348*da0073e9SAndroid Build Coastguard Worker for key in dir(torch): 2349*da0073e9SAndroid Build Coastguard Worker obj = getattr(torch, key) 2350*da0073e9SAndroid Build Coastguard Worker if callable(obj): 2351*da0073e9SAndroid Build Coastguard Worker schemas = get_signature_for_torch_op(obj) 2352*da0073e9SAndroid Build Coastguard Worker 2353*da0073e9SAndroid Build Coastguard Worker def test_find_uses(self): 2354*da0073e9SAndroid Build Coastguard Worker graph = torch.fx.Graph() 2355*da0073e9SAndroid Build Coastguard Worker x = torch.fx.Proxy(graph.placeholder('x')) 2356*da0073e9SAndroid Build Coastguard Worker 2357*da0073e9SAndroid Build Coastguard Worker y = torch.relu(x) 2358*da0073e9SAndroid Build Coastguard Worker z = x + x 2359*da0073e9SAndroid Build Coastguard Worker u = torch.neg(x) 2360*da0073e9SAndroid Build Coastguard Worker graph.output((y + z + u).node) 2361*da0073e9SAndroid Build Coastguard Worker graph.lint() 2362*da0073e9SAndroid Build Coastguard Worker 2363*da0073e9SAndroid Build Coastguard Worker users_of_x = x.node.users 2364*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(users_of_x), 3) 2365*da0073e9SAndroid Build Coastguard Worker expected_ops = {'relu', 'add', 'neg'} 2366*da0073e9SAndroid Build Coastguard Worker for use in users_of_x: 2367*da0073e9SAndroid Build Coastguard Worker assert 
any(use.name.startswith(prefix) for prefix in expected_ops) 2368*da0073e9SAndroid Build Coastguard Worker 2369*da0073e9SAndroid Build Coastguard Worker def test_inline_graph(self): 2370*da0073e9SAndroid Build Coastguard Worker class InlineInto(torch.nn.Module): 2371*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2372*da0073e9SAndroid Build Coastguard Worker return torch.relu(x) 2373*da0073e9SAndroid Build Coastguard Worker 2374*da0073e9SAndroid Build Coastguard Worker class ToInline(torch.nn.Module): 2375*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2376*da0073e9SAndroid Build Coastguard Worker return torch.neg(x) 2377*da0073e9SAndroid Build Coastguard Worker 2378*da0073e9SAndroid Build Coastguard Worker inline_into = symbolic_trace(InlineInto()) 2379*da0073e9SAndroid Build Coastguard Worker to_inline = symbolic_trace(ToInline()) 2380*da0073e9SAndroid Build Coastguard Worker 2381*da0073e9SAndroid Build Coastguard Worker combined_graph = torch.fx.Graph() 2382*da0073e9SAndroid Build Coastguard Worker output_node = combined_graph.graph_copy(inline_into.graph, {}) 2383*da0073e9SAndroid Build Coastguard Worker 2384*da0073e9SAndroid Build Coastguard Worker input_node = next(iter(to_inline.graph.nodes)) 2385*da0073e9SAndroid Build Coastguard Worker assert input_node and input_node.op == 'placeholder' 2386*da0073e9SAndroid Build Coastguard Worker 2387*da0073e9SAndroid Build Coastguard Worker val_map = {input_node : output_node} 2388*da0073e9SAndroid Build Coastguard Worker output = combined_graph.graph_copy(to_inline.graph, val_map) 2389*da0073e9SAndroid Build Coastguard Worker combined_graph.output(output) 2390*da0073e9SAndroid Build Coastguard Worker 2391*da0073e9SAndroid Build Coastguard Worker combined_module = torch.fx.GraphModule(torch.nn.Module(), combined_graph) 2392*da0073e9SAndroid Build Coastguard Worker 2393*da0073e9SAndroid Build Coastguard Worker input = torch.rand(3, 4) 2394*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(combined_module(input), input.relu().neg()) 2395*da0073e9SAndroid Build Coastguard Worker 2396*da0073e9SAndroid Build Coastguard Worker def test_multi_insert_point(self): 2397*da0073e9SAndroid Build Coastguard Worker graph = torch.fx.Graph() 2398*da0073e9SAndroid Build Coastguard Worker x = torch.fx.Proxy(graph.placeholder('x')) 2399*da0073e9SAndroid Build Coastguard Worker relu = torch.relu(x) 2400*da0073e9SAndroid Build Coastguard Worker 2401*da0073e9SAndroid Build Coastguard Worker with graph.inserting_before(relu.node): 2402*da0073e9SAndroid Build Coastguard Worker y = torch.neg(x) 2403*da0073e9SAndroid Build Coastguard Worker z = torch.tanh(y) 2404*da0073e9SAndroid Build Coastguard Worker 2405*da0073e9SAndroid Build Coastguard Worker graph.output((relu.node, z.node)) 2406*da0073e9SAndroid Build Coastguard Worker graph.lint() 2407*da0073e9SAndroid Build Coastguard Worker 2408*da0073e9SAndroid Build Coastguard Worker expected_ops = ['x', 'neg', 'tanh', 'relu'] 2409*da0073e9SAndroid Build Coastguard Worker for node, expected in zip(graph.nodes, expected_ops): 2410*da0073e9SAndroid Build Coastguard Worker assert expected in node.name 2411*da0073e9SAndroid Build Coastguard Worker 2412*da0073e9SAndroid Build Coastguard Worker def test_reassign_args_kwargs_uses(self): 2413*da0073e9SAndroid Build Coastguard Worker graph = torch.fx.Graph() 2414*da0073e9SAndroid Build Coastguard Worker x, y = Proxy(graph.placeholder('x')), Proxy(graph.placeholder('y')) 2415*da0073e9SAndroid Build Coastguard Worker z = x + y 2416*da0073e9SAndroid Build Coastguard Worker zed = z + z + z 2417*da0073e9SAndroid Build Coastguard Worker graph.output(zed.node) 2418*da0073e9SAndroid Build Coastguard Worker graph.lint() 2419*da0073e9SAndroid Build Coastguard Worker 2420*da0073e9SAndroid Build Coastguard Worker # zed = z + z + z -> zed = z + z + x 2421*da0073e9SAndroid Build Coastguard Worker zed.node.args = (zed.node.args[0], x.node) 2422*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(list(x.node.users.keys()), [z.node, zed.node]) 2423*da0073e9SAndroid Build Coastguard Worker 2424*da0073e9SAndroid Build Coastguard Worker # z = x + y -> z = y + y 2425*da0073e9SAndroid Build Coastguard Worker z.node.args = (y.node, y.node) 2426*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(x.node.users.keys()), [zed.node]) 2427*da0073e9SAndroid Build Coastguard Worker 2428*da0073e9SAndroid Build Coastguard Worker def test_trace_function(self): 2429*da0073e9SAndroid Build Coastguard Worker def foo(x, y): 2430*da0073e9SAndroid Build Coastguard Worker return torch.relu(x) + y 2431*da0073e9SAndroid Build Coastguard Worker 2432*da0073e9SAndroid Build Coastguard Worker x, y = torch.randn(3, 4), torch.randn(3, 4) 2433*da0073e9SAndroid Build Coastguard Worker self.checkGraphModule(foo, (x, y)) 2434*da0073e9SAndroid Build Coastguard Worker 2435*da0073e9SAndroid Build Coastguard Worker def test_trace_return_dataclass(self): 2436*da0073e9SAndroid Build Coastguard Worker """ 2437*da0073e9SAndroid Build Coastguard Worker Test case for Module that return dataclass 2438*da0073e9SAndroid Build Coastguard Worker """ 2439*da0073e9SAndroid Build Coastguard Worker from dataclasses import dataclass 2440*da0073e9SAndroid Build Coastguard Worker 2441*da0073e9SAndroid Build Coastguard Worker @dataclass 2442*da0073e9SAndroid Build Coastguard Worker class MyOutput: 2443*da0073e9SAndroid Build Coastguard Worker foo: torch.Tensor 2444*da0073e9SAndroid Build Coastguard Worker bar: torch.Tensor 2445*da0073e9SAndroid Build Coastguard Worker 2446*da0073e9SAndroid Build Coastguard Worker class ModuleReturnDataclass(torch.nn.Module): 2447*da0073e9SAndroid Build Coastguard Worker def forward(self, d : torch.Tensor): 2448*da0073e9SAndroid Build Coastguard Worker return MyOutput(foo=d + d, bar=d * 3) 2449*da0073e9SAndroid Build Coastguard Worker 2450*da0073e9SAndroid Build Coastguard Worker module = ModuleReturnDataclass() 2451*da0073e9SAndroid Build Coastguard 
Worker traced_graph = symbolic_trace(module).graph 2452*da0073e9SAndroid Build Coastguard Worker print(traced_graph) 2453*da0073e9SAndroid Build Coastguard Worker 2454*da0073e9SAndroid Build Coastguard Worker gm = GraphModule(module, traced_graph) 2455*da0073e9SAndroid Build Coastguard Worker x = torch.rand(1) 2456*da0073e9SAndroid Build Coastguard Worker 2457*da0073e9SAndroid Build Coastguard Worker self.assertEqual(module(x), gm(x)) 2458*da0073e9SAndroid Build Coastguard Worker 2459*da0073e9SAndroid Build Coastguard Worker def test_trace_return_dataclass_nested(self): 2460*da0073e9SAndroid Build Coastguard Worker """ 2461*da0073e9SAndroid Build Coastguard Worker Test case for Module that return dataclass 2462*da0073e9SAndroid Build Coastguard Worker """ 2463*da0073e9SAndroid Build Coastguard Worker from dataclasses import dataclass 2464*da0073e9SAndroid Build Coastguard Worker 2465*da0073e9SAndroid Build Coastguard Worker @dataclass 2466*da0073e9SAndroid Build Coastguard Worker class MyOutput: 2467*da0073e9SAndroid Build Coastguard Worker foo: torch.Tensor 2468*da0073e9SAndroid Build Coastguard Worker bar: torch.Tensor 2469*da0073e9SAndroid Build Coastguard Worker 2470*da0073e9SAndroid Build Coastguard Worker class ModuleReturnDataclass(torch.nn.Module): 2471*da0073e9SAndroid Build Coastguard Worker def forward(self, d : torch.Tensor): 2472*da0073e9SAndroid Build Coastguard Worker return MyOutput(foo=d + d, bar=d * 3) 2473*da0073e9SAndroid Build Coastguard Worker 2474*da0073e9SAndroid Build Coastguard Worker class CallsModule(torch.nn.Module): 2475*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2476*da0073e9SAndroid Build Coastguard Worker super().__init__() 2477*da0073e9SAndroid Build Coastguard Worker self.m = ModuleReturnDataclass() 2478*da0073e9SAndroid Build Coastguard Worker 2479*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2480*da0073e9SAndroid Build Coastguard Worker tmp = self.m(x) 2481*da0073e9SAndroid Build 
Coastguard Worker return MyOutput(foo=tmp.foo, bar=tmp.bar) 2482*da0073e9SAndroid Build Coastguard Worker 2483*da0073e9SAndroid Build Coastguard Worker module = CallsModule() 2484*da0073e9SAndroid Build Coastguard Worker traced_graph = symbolic_trace(module).graph 2485*da0073e9SAndroid Build Coastguard Worker print(traced_graph) 2486*da0073e9SAndroid Build Coastguard Worker 2487*da0073e9SAndroid Build Coastguard Worker gm = GraphModule(module, traced_graph) 2488*da0073e9SAndroid Build Coastguard Worker x = torch.rand(1) 2489*da0073e9SAndroid Build Coastguard Worker 2490*da0073e9SAndroid Build Coastguard Worker self.assertEqual(module(x), gm(x)) 2491*da0073e9SAndroid Build Coastguard Worker 2492*da0073e9SAndroid Build Coastguard Worker def test_trace_return_namedtuple(self): 2493*da0073e9SAndroid Build Coastguard Worker """ 2494*da0073e9SAndroid Build Coastguard Worker Test case for Module that return namedtuple 2495*da0073e9SAndroid Build Coastguard Worker """ 2496*da0073e9SAndroid Build Coastguard Worker class MyOutput(NamedTuple): 2497*da0073e9SAndroid Build Coastguard Worker foo: torch.Tensor 2498*da0073e9SAndroid Build Coastguard Worker bar: torch.Tensor 2499*da0073e9SAndroid Build Coastguard Worker 2500*da0073e9SAndroid Build Coastguard Worker class ModuleReturnNamedTuple(torch.nn.Module): 2501*da0073e9SAndroid Build Coastguard Worker def forward(self, d : torch.Tensor): 2502*da0073e9SAndroid Build Coastguard Worker return MyOutput(foo=d, bar=d) 2503*da0073e9SAndroid Build Coastguard Worker 2504*da0073e9SAndroid Build Coastguard Worker module = ModuleReturnNamedTuple() 2505*da0073e9SAndroid Build Coastguard Worker 2506*da0073e9SAndroid Build Coastguard Worker traced_graph = symbolic_trace(module).graph 2507*da0073e9SAndroid Build Coastguard Worker print(traced_graph) 2508*da0073e9SAndroid Build Coastguard Worker 2509*da0073e9SAndroid Build Coastguard Worker gm = GraphModule(module, traced_graph) 2510*da0073e9SAndroid Build Coastguard Worker x = torch.rand(1) 
2511*da0073e9SAndroid Build Coastguard Worker 2512*da0073e9SAndroid Build Coastguard Worker self.assertEqual(module(x), gm(x)) 2513*da0073e9SAndroid Build Coastguard Worker 2514*da0073e9SAndroid Build Coastguard Worker def test_trace_dict_int_keys(self): 2515*da0073e9SAndroid Build Coastguard Worker class ModWithDictArg(torch.nn.Module): 2516*da0073e9SAndroid Build Coastguard Worker def forward(self, d : Dict[int, torch.Tensor]): 2517*da0073e9SAndroid Build Coastguard Worker return d[42] 2518*da0073e9SAndroid Build Coastguard Worker 2519*da0073e9SAndroid Build Coastguard Worker class CallsModWithDict(torch.nn.Module): 2520*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2521*da0073e9SAndroid Build Coastguard Worker super().__init__() 2522*da0073e9SAndroid Build Coastguard Worker self.m = ModWithDictArg() 2523*da0073e9SAndroid Build Coastguard Worker 2524*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2525*da0073e9SAndroid Build Coastguard Worker return self.m({42: x}) 2526*da0073e9SAndroid Build Coastguard Worker 2527*da0073e9SAndroid Build Coastguard Worker class MyTracer(torch.fx.Tracer): 2528*da0073e9SAndroid Build Coastguard Worker def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool: 2529*da0073e9SAndroid Build Coastguard Worker return isinstance(m, ModWithDictArg) 2530*da0073e9SAndroid Build Coastguard Worker 2531*da0073e9SAndroid Build Coastguard Worker traced_graph = MyTracer().trace(CallsModWithDict()) 2532*da0073e9SAndroid Build Coastguard Worker 2533*da0073e9SAndroid Build Coastguard Worker def test_trace_dict_proxy_keys(self): 2534*da0073e9SAndroid Build Coastguard Worker class ModWithDictArg(torch.nn.Module): 2535*da0073e9SAndroid Build Coastguard Worker def forward(self, d : Dict[torch.Tensor, torch.Tensor]): 2536*da0073e9SAndroid Build Coastguard Worker return d[42] 2537*da0073e9SAndroid Build Coastguard Worker 2538*da0073e9SAndroid Build Coastguard Worker class 
CallsModWithDict(torch.nn.Module): 2539*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2540*da0073e9SAndroid Build Coastguard Worker super().__init__() 2541*da0073e9SAndroid Build Coastguard Worker self.m = ModWithDictArg() 2542*da0073e9SAndroid Build Coastguard Worker 2543*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2544*da0073e9SAndroid Build Coastguard Worker return self.m({x: x}) 2545*da0073e9SAndroid Build Coastguard Worker 2546*da0073e9SAndroid Build Coastguard Worker class MyTracer(torch.fx.Tracer): 2547*da0073e9SAndroid Build Coastguard Worker def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool: 2548*da0073e9SAndroid Build Coastguard Worker return isinstance(m, ModWithDictArg) 2549*da0073e9SAndroid Build Coastguard Worker 2550*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'cannot contain a Node'): 2551*da0073e9SAndroid Build Coastguard Worker traced_graph = MyTracer().trace(CallsModWithDict()) 2552*da0073e9SAndroid Build Coastguard Worker 2553*da0073e9SAndroid Build Coastguard Worker def test_module_deepcopy_edit_nodes(self): 2554*da0073e9SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 2555*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2556*da0073e9SAndroid Build Coastguard Worker return torch.relu(x) 2557*da0073e9SAndroid Build Coastguard Worker 2558*da0073e9SAndroid Build Coastguard Worker traced1 = symbolic_trace(Foo()) 2559*da0073e9SAndroid Build Coastguard Worker copied = copy.deepcopy(traced1) 2560*da0073e9SAndroid Build Coastguard Worker 2561*da0073e9SAndroid Build Coastguard Worker for node in copied.graph.nodes: 2562*da0073e9SAndroid Build Coastguard Worker if node.target == torch.relu: 2563*da0073e9SAndroid Build Coastguard Worker node.target = torch.neg 2564*da0073e9SAndroid Build Coastguard Worker 2565*da0073e9SAndroid Build Coastguard Worker copied.recompile() 2566*da0073e9SAndroid Build Coastguard Worker 
traced1.recompile() 2567*da0073e9SAndroid Build Coastguard Worker 2568*da0073e9SAndroid Build Coastguard Worker x = torch.randn(15, 15) 2569*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(traced1(x), torch.relu(x)) 2570*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(copied(x), torch.neg(x)) 2571*da0073e9SAndroid Build Coastguard Worker 2572*da0073e9SAndroid Build Coastguard Worker def test_direct_param_use(self): 2573*da0073e9SAndroid Build Coastguard Worker class TransposeTest(torch.nn.Module): 2574*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2575*da0073e9SAndroid Build Coastguard Worker super().__init__() 2576*da0073e9SAndroid Build Coastguard Worker self.b = torch.nn.Parameter(torch.rand(4, 3)) 2577*da0073e9SAndroid Build Coastguard Worker 2578*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2579*da0073e9SAndroid Build Coastguard Worker return self.b 2580*da0073e9SAndroid Build Coastguard Worker 2581*da0073e9SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 2582*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2583*da0073e9SAndroid Build Coastguard Worker super().__init__() 2584*da0073e9SAndroid Build Coastguard Worker self.a = TransposeTest() 2585*da0073e9SAndroid Build Coastguard Worker 2586*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2587*da0073e9SAndroid Build Coastguard Worker return self.a.b, self.a.b.t(), self.a.b.view(12) 2588*da0073e9SAndroid Build Coastguard Worker 2589*da0073e9SAndroid Build Coastguard Worker traced = torch.fx.symbolic_trace(Foo()) 2590*da0073e9SAndroid Build Coastguard Worker assert all('constant' not in node.target for node in traced.graph.nodes) 2591*da0073e9SAndroid Build Coastguard Worker 2592*da0073e9SAndroid Build Coastguard Worker def test_single_default_arg(self): 2593*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 2594*da0073e9SAndroid Build Coastguard Worker def 
forward(self, y=1): 2595*da0073e9SAndroid Build Coastguard Worker return y 2596*da0073e9SAndroid Build Coastguard Worker 2597*da0073e9SAndroid Build Coastguard Worker m = M() 2598*da0073e9SAndroid Build Coastguard Worker self.checkGraphModule(m, ()) 2599*da0073e9SAndroid Build Coastguard Worker self.checkGraphModule(m, (3,)) 2600*da0073e9SAndroid Build Coastguard Worker 2601*da0073e9SAndroid Build Coastguard Worker def test_multiple_default_args(self): 2602*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 2603*da0073e9SAndroid Build Coastguard Worker def forward(self, y=1, z=2): 2604*da0073e9SAndroid Build Coastguard Worker return y + z 2605*da0073e9SAndroid Build Coastguard Worker 2606*da0073e9SAndroid Build Coastguard Worker m = M() 2607*da0073e9SAndroid Build Coastguard Worker self.checkGraphModule(m, ()) 2608*da0073e9SAndroid Build Coastguard Worker self.checkGraphModule(m, (3,)) 2609*da0073e9SAndroid Build Coastguard Worker self.checkGraphModule(m, (3, 4)) 2610*da0073e9SAndroid Build Coastguard Worker 2611*da0073e9SAndroid Build Coastguard Worker def test_regular_and_default_args(self): 2612*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 2613*da0073e9SAndroid Build Coastguard Worker def forward(self, x, y=1): 2614*da0073e9SAndroid Build Coastguard Worker return x + y 2615*da0073e9SAndroid Build Coastguard Worker 2616*da0073e9SAndroid Build Coastguard Worker m = M() 2617*da0073e9SAndroid Build Coastguard Worker self.checkGraphModule(m, (2,)) 2618*da0073e9SAndroid Build Coastguard Worker self.checkGraphModule(m, (2, 3)) 2619*da0073e9SAndroid Build Coastguard Worker 2620*da0073e9SAndroid Build Coastguard Worker def test_string_literal_return(self): 2621*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 2622*da0073e9SAndroid Build Coastguard Worker def forward(self): 2623*da0073e9SAndroid Build Coastguard Worker return "foo" 2624*da0073e9SAndroid Build Coastguard Worker 2625*da0073e9SAndroid Build Coastguard 
Worker m = M() 2626*da0073e9SAndroid Build Coastguard Worker self.checkGraphModule(m, ()) 2627*da0073e9SAndroid Build Coastguard Worker 2628*da0073e9SAndroid Build Coastguard Worker def test_namedtuple_return_qualname(self): 2629*da0073e9SAndroid Build Coastguard Worker class NamedTupReturn(torch.nn.Module): 2630*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2631*da0073e9SAndroid Build Coastguard Worker return MyNamedTup(x, x) 2632*da0073e9SAndroid Build Coastguard Worker 2633*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(NamedTupReturn()) 2634*da0073e9SAndroid Build Coastguard Worker input = torch.rand(3, 4) 2635*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced(input), MyNamedTup(input, input)) 2636*da0073e9SAndroid Build Coastguard Worker 2637*da0073e9SAndroid Build Coastguard Worker def test_update_args_kwargs_yells_at_you(self): 2638*da0073e9SAndroid Build Coastguard Worker symtraced = symbolic_trace(SimpleTest()) 2639*da0073e9SAndroid Build Coastguard Worker node = next(iter(symtraced.graph.nodes)) 2640*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AttributeError, '__update_args_kwargs'): 2641*da0073e9SAndroid Build Coastguard Worker node.__update_args_kwargs((), {}) 2642*da0073e9SAndroid Build Coastguard Worker 2643*da0073e9SAndroid Build Coastguard Worker def test_torchbind_class_attribute_in_fx(self): 2644*da0073e9SAndroid Build Coastguard Worker if IS_FBCODE or IS_WINDOWS or IS_MACOS: 2645*da0073e9SAndroid Build Coastguard Worker self.skipTest("torch.classes._TorchScriptTesting._StackString is registered, skipping") 2646*da0073e9SAndroid Build Coastguard Worker 2647*da0073e9SAndroid Build Coastguard Worker class FooBar1234(torch.nn.Module): 2648*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2649*da0073e9SAndroid Build Coastguard Worker super().__init__() 2650*da0073e9SAndroid Build Coastguard Worker self.f = torch.classes._TorchScriptTesting._StackString(["3", 
"4"]) 2651*da0073e9SAndroid Build Coastguard Worker 2652*da0073e9SAndroid Build Coastguard Worker def forward(self): 2653*da0073e9SAndroid Build Coastguard Worker return self.f.top() 2654*da0073e9SAndroid Build Coastguard Worker 2655*da0073e9SAndroid Build Coastguard Worker m = FooBar1234() 2656*da0073e9SAndroid Build Coastguard Worker self.checkGraphModule(m, ()) 2657*da0073e9SAndroid Build Coastguard Worker 2658*da0073e9SAndroid Build Coastguard Worker def test_torchbind_class_attribute_in_fx_tensor_arg(self): 2659*da0073e9SAndroid Build Coastguard Worker if IS_FBCODE or IS_WINDOWS or IS_MACOS: 2660*da0073e9SAndroid Build Coastguard Worker self.skipTest("torch.classes._TorchScriptTesting._ReLUClass is registered, skipping") 2661*da0073e9SAndroid Build Coastguard Worker 2662*da0073e9SAndroid Build Coastguard Worker class FooBar2341(torch.nn.Module): 2663*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2664*da0073e9SAndroid Build Coastguard Worker super().__init__() 2665*da0073e9SAndroid Build Coastguard Worker self.f = torch.classes._TorchScriptTesting._ReLUClass() 2666*da0073e9SAndroid Build Coastguard Worker 2667*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2668*da0073e9SAndroid Build Coastguard Worker return self.f.run(x) 2669*da0073e9SAndroid Build Coastguard Worker 2670*da0073e9SAndroid Build Coastguard Worker m = FooBar2341() 2671*da0073e9SAndroid Build Coastguard Worker 2672*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(m) 2673*da0073e9SAndroid Build Coastguard Worker input = torch.randn(3, 4) 2674*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced(input), m(input)) 2675*da0073e9SAndroid Build Coastguard Worker 2676*da0073e9SAndroid Build Coastguard Worker self.assertTrue(any(n.op == 'call_method' for n in traced.graph.nodes)) 2677*da0073e9SAndroid Build Coastguard Worker 2678*da0073e9SAndroid Build Coastguard Worker def test_script_method_trace(self): 2679*da0073e9SAndroid Build 
Coastguard Worker class Scripted(torch.nn.Module): 2680*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2681*da0073e9SAndroid Build Coastguard Worker return torch.relu(x) 2682*da0073e9SAndroid Build Coastguard Worker 2683*da0073e9SAndroid Build Coastguard Worker class Holder(torch.nn.Module): 2684*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2685*da0073e9SAndroid Build Coastguard Worker super().__init__() 2686*da0073e9SAndroid Build Coastguard Worker self.s = torch.jit.script(Scripted()) 2687*da0073e9SAndroid Build Coastguard Worker 2688*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2689*da0073e9SAndroid Build Coastguard Worker return self.s(x) 2690*da0073e9SAndroid Build Coastguard Worker 2691*da0073e9SAndroid Build Coastguard Worker h = Holder() 2692*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(h) 2693*da0073e9SAndroid Build Coastguard Worker input = torch.randn(3, 4) 2694*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced(input), h(input)) 2695*da0073e9SAndroid Build Coastguard Worker 2696*da0073e9SAndroid Build Coastguard Worker self.assertTrue(any(n.op == 'call_method' for n in traced.graph.nodes)) 2697*da0073e9SAndroid Build Coastguard Worker 2698*da0073e9SAndroid Build Coastguard Worker def test_namedtuple_return_trace(self): 2699*da0073e9SAndroid Build Coastguard Worker class NamedTupReturn(torch.nn.Module): 2700*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2701*da0073e9SAndroid Build Coastguard Worker return Pair(x, x) 2702*da0073e9SAndroid Build Coastguard Worker 2703*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(NamedTupReturn()) 2704*da0073e9SAndroid Build Coastguard Worker input = torch.rand(3, 4) 2705*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced(input), Pair(input, input)) 2706*da0073e9SAndroid Build Coastguard Worker 2707*da0073e9SAndroid Build Coastguard Worker def test_named_tuple_inlined(self): 
2708*da0073e9SAndroid Build Coastguard Worker class NamedTupMod(torch.nn.Module): 2709*da0073e9SAndroid Build Coastguard Worker def forward(self, inp): 2710*da0073e9SAndroid Build Coastguard Worker return wrapped_named_tup(Pair(inp, 1.2), p2=Pair(3.4, inp)) 2711*da0073e9SAndroid Build Coastguard Worker 2712*da0073e9SAndroid Build Coastguard Worker m = NamedTupMod() 2713*da0073e9SAndroid Build Coastguard Worker input = torch.rand(3, 4) 2714*da0073e9SAndroid Build Coastguard Worker ref = m(input) 2715*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(m) 2716*da0073e9SAndroid Build Coastguard Worker 2717*da0073e9SAndroid Build Coastguard Worker res = traced(input) 2718*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 2719*da0073e9SAndroid Build Coastguard Worker 2720*da0073e9SAndroid Build Coastguard Worker # Check Pair NamedTuple works when inlined into the function call. 2721*da0073e9SAndroid Build Coastguard Worker ph = call_func = None 2722*da0073e9SAndroid Build Coastguard Worker for node in traced.graph.nodes: 2723*da0073e9SAndroid Build Coastguard Worker if node.op == "placeholder": 2724*da0073e9SAndroid Build Coastguard Worker ph = node 2725*da0073e9SAndroid Build Coastguard Worker elif node.op == "call_function" and node.target == wrapped_named_tup: 2726*da0073e9SAndroid Build Coastguard Worker node.update_arg(0, Pair(ph, 1.2)) 2727*da0073e9SAndroid Build Coastguard Worker node.update_kwarg("p2", Pair(3.4, ph)) 2728*da0073e9SAndroid Build Coastguard Worker call_func = node 2729*da0073e9SAndroid Build Coastguard Worker break 2730*da0073e9SAndroid Build Coastguard Worker self.assertTrue(call_func is not None) 2731*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(call_func.args[0], Pair)) 2732*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(call_func.kwargs["p2"], Pair)) 2733*da0073e9SAndroid Build Coastguard Worker self.assertEqual(_format_arg(call_func.args[0]), "Pair(x=%inp, y=1.2)") 
2734*da0073e9SAndroid Build Coastguard Worker self.assertEqual(_format_arg(call_func.kwargs["p2"]), "Pair(x=3.4, y=%inp)") 2735*da0073e9SAndroid Build Coastguard Worker 2736*da0073e9SAndroid Build Coastguard Worker traced.graph.eliminate_dead_code() 2737*da0073e9SAndroid Build Coastguard Worker traced.recompile() 2738*da0073e9SAndroid Build Coastguard Worker res = traced(input) 2739*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 2740*da0073e9SAndroid Build Coastguard Worker 2741*da0073e9SAndroid Build Coastguard Worker def test_return_type_exists(self): 2742*da0073e9SAndroid Build Coastguard Worker class ReturnTypeModule(torch.nn.Module): 2743*da0073e9SAndroid Build Coastguard Worker def other(self, x: List[str]) -> List[str]: 2744*da0073e9SAndroid Build Coastguard Worker return x 2745*da0073e9SAndroid Build Coastguard Worker 2746*da0073e9SAndroid Build Coastguard Worker def forward(self, x: List[str]) -> List[str]: 2747*da0073e9SAndroid Build Coastguard Worker return self.other(x) 2748*da0073e9SAndroid Build Coastguard Worker 2749*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(ReturnTypeModule()) 2750*da0073e9SAndroid Build Coastguard Worker self.assertIn("-> typing_List[str]", traced._code) 2751*da0073e9SAndroid Build Coastguard Worker scripted = torch.jit.script(traced) 2752*da0073e9SAndroid Build Coastguard Worker self.assertIn("-> List[str]", scripted.code) 2753*da0073e9SAndroid Build Coastguard Worker 2754*da0073e9SAndroid Build Coastguard Worker def getitem_inner(self): 2755*da0073e9SAndroid Build Coastguard Worker class GetItemBase(torch.nn.Module): 2756*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2757*da0073e9SAndroid Build Coastguard Worker super().__init__() 2758*da0073e9SAndroid Build Coastguard Worker self.pe = torch.nn.Buffer(torch.randn(8, 8)) 2759*da0073e9SAndroid Build Coastguard Worker 2760*da0073e9SAndroid Build Coastguard Worker class GetItem1(GetItemBase): 2761*da0073e9SAndroid 
Build Coastguard Worker def forward(self, x): 2762*da0073e9SAndroid Build Coastguard Worker return self.pe[:, :x.size(0)] 2763*da0073e9SAndroid Build Coastguard Worker 2764*da0073e9SAndroid Build Coastguard Worker class GetItem2(GetItemBase): 2765*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2766*da0073e9SAndroid Build Coastguard Worker return self.pe[x.size(0)] 2767*da0073e9SAndroid Build Coastguard Worker 2768*da0073e9SAndroid Build Coastguard Worker class GetItem3(GetItemBase): 2769*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2770*da0073e9SAndroid Build Coastguard Worker return self.pe[4] # fx creates `self._tensor_constant0` here 2771*da0073e9SAndroid Build Coastguard Worker 2772*da0073e9SAndroid Build Coastguard Worker self.checkGraphModule(GetItem1(), [torch.zeros(4)]) 2773*da0073e9SAndroid Build Coastguard Worker self.checkGraphModule(GetItem2(), [torch.zeros(4)]) 2774*da0073e9SAndroid Build Coastguard Worker self.checkGraphModule(GetItem3(), [torch.zeros(4)]) 2775*da0073e9SAndroid Build Coastguard Worker 2776*da0073e9SAndroid Build Coastguard Worker @unittest.skipUnless(os.environ.get("FX_PATCH_GETITEM") == "1", 2777*da0073e9SAndroid Build Coastguard Worker "Will be checked in test_getitem_subproc") 2778*da0073e9SAndroid Build Coastguard Worker def test_getitem(self): 2779*da0073e9SAndroid Build Coastguard Worker self.getitem_inner() 2780*da0073e9SAndroid Build Coastguard Worker 2781*da0073e9SAndroid Build Coastguard Worker def test_getitem_subproc(self): 2782*da0073e9SAndroid Build Coastguard Worker # need to run this test in a subproc to work around: 2783*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/50710 2784*da0073e9SAndroid Build Coastguard Worker proc = Process(target=run_getitem_target) 2785*da0073e9SAndroid Build Coastguard Worker proc.start() 2786*da0073e9SAndroid Build Coastguard Worker proc.join() 2787*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(proc.exitcode, 0) 2788*da0073e9SAndroid Build Coastguard Worker 2789*da0073e9SAndroid Build Coastguard Worker def test_user_friendly_call_provenance_with_function(self): 2790*da0073e9SAndroid Build Coastguard Worker def fn(x): 2791*da0073e9SAndroid Build Coastguard Worker return wrapper_fn(x) 2792*da0073e9SAndroid Build Coastguard Worker 2793*da0073e9SAndroid Build Coastguard Worker traced = torch.fx.symbolic_trace(fn) 2794*da0073e9SAndroid Build Coastguard Worker 2795*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "'wrapper_fn' is " 2796*da0073e9SAndroid Build Coastguard Worker "being compiled since it was called" 2797*da0073e9SAndroid Build Coastguard Worker " from 'fn.forward'"): 2798*da0073e9SAndroid Build Coastguard Worker scripted = torch.jit.script(traced) 2799*da0073e9SAndroid Build Coastguard Worker 2800*da0073e9SAndroid Build Coastguard Worker def test_user_friendly_call_provenance_with_module(self): 2801*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 2802*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2803*da0073e9SAndroid Build Coastguard Worker return wrapper_fn(x) 2804*da0073e9SAndroid Build Coastguard Worker 2805*da0073e9SAndroid Build Coastguard Worker traced = torch.fx.symbolic_trace(M()) 2806*da0073e9SAndroid Build Coastguard Worker 2807*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "'wrapper_fn' is " 2808*da0073e9SAndroid Build Coastguard Worker "being compiled since it was called" 2809*da0073e9SAndroid Build Coastguard Worker " from 'M.forward'"): 2810*da0073e9SAndroid Build Coastguard Worker scripted = torch.jit.script(traced) 2811*da0073e9SAndroid Build Coastguard Worker 2812*da0073e9SAndroid Build Coastguard Worker def test_snake_case(self): 2813*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 2814*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2815*da0073e9SAndroid Build 
Coastguard Worker super().__init__() 2816*da0073e9SAndroid Build Coastguard Worker self.activations = torch.nn.ModuleDict([ 2817*da0073e9SAndroid Build Coastguard Worker ["snake_case", torch.nn.ReLU()], 2818*da0073e9SAndroid Build Coastguard Worker ["PascalCase", torch.nn.LeakyReLU()], 2819*da0073e9SAndroid Build Coastguard Worker ["ALL_CAPS", torch.nn.PReLU()] 2820*da0073e9SAndroid Build Coastguard Worker ]) 2821*da0073e9SAndroid Build Coastguard Worker 2822*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2823*da0073e9SAndroid Build Coastguard Worker a = self.activations["snake_case"](x) 2824*da0073e9SAndroid Build Coastguard Worker b = self.activations["PascalCase"](x) 2825*da0073e9SAndroid Build Coastguard Worker c = self.activations["ALL_CAPS"](x) 2826*da0073e9SAndroid Build Coastguard Worker return a, b, c 2827*da0073e9SAndroid Build Coastguard Worker 2828*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(M()) 2829*da0073e9SAndroid Build Coastguard Worker 2830*da0073e9SAndroid Build Coastguard Worker check = [ 2831*da0073e9SAndroid Build Coastguard Worker ("activations_snake_case", "activations.snake_case"), 2832*da0073e9SAndroid Build Coastguard Worker ("activations_pascal_case", "activations.PascalCase"), 2833*da0073e9SAndroid Build Coastguard Worker ("activations_all_caps", "activations.ALL_CAPS") 2834*da0073e9SAndroid Build Coastguard Worker ] 2835*da0073e9SAndroid Build Coastguard Worker 2836*da0073e9SAndroid Build Coastguard Worker i = 0 2837*da0073e9SAndroid Build Coastguard Worker for node in traced.graph.nodes: 2838*da0073e9SAndroid Build Coastguard Worker if node.op == "placeholder" or node.op == "output": 2839*da0073e9SAndroid Build Coastguard Worker continue 2840*da0073e9SAndroid Build Coastguard Worker name = check[i][0] 2841*da0073e9SAndroid Build Coastguard Worker target = check[i][1] 2842*da0073e9SAndroid Build Coastguard Worker self.assertEqual(name, node.name) 2843*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(target, node.target) 2844*da0073e9SAndroid Build Coastguard Worker i += 1 2845*da0073e9SAndroid Build Coastguard Worker self.assertEqual(i, 3) 2846*da0073e9SAndroid Build Coastguard Worker 2847*da0073e9SAndroid Build Coastguard Worker def test_no_mutation(self): 2848*da0073e9SAndroid Build Coastguard Worker from torch.fx.immutable_collections import immutable_list 2849*da0073e9SAndroid Build Coastguard Worker x = immutable_list([3, 4]) 2850*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(NotImplementedError, "new_args"): 2851*da0073e9SAndroid Build Coastguard Worker x[0] = 4 2852*da0073e9SAndroid Build Coastguard Worker 2853*da0073e9SAndroid Build Coastguard Worker def test_partial_trace(self): 2854*da0073e9SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 2855*da0073e9SAndroid Build Coastguard Worker def forward(self, x, y): 2856*da0073e9SAndroid Build Coastguard Worker if y: 2857*da0073e9SAndroid Build Coastguard Worker return 2 * x 2858*da0073e9SAndroid Build Coastguard Worker else: 2859*da0073e9SAndroid Build Coastguard Worker return x 2860*da0073e9SAndroid Build Coastguard Worker mod = Foo() 2861*da0073e9SAndroid Build Coastguard Worker mod_true = symbolic_trace(mod, concrete_args={'y': True}) 2862*da0073e9SAndroid Build Coastguard Worker mod_false = symbolic_trace(mod, concrete_args={'y': False}) 2863*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mod_true(3, True), 6) 2864*da0073e9SAndroid Build Coastguard Worker print(mod_true.code) 2865*da0073e9SAndroid Build Coastguard Worker assert any(i.target == torch._assert for i in mod_true.graph.nodes) 2866*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(AssertionError): 2867*da0073e9SAndroid Build Coastguard Worker mod_true(3, False) 2868*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mod_false(3, False), 3) 2869*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(AssertionError): 2870*da0073e9SAndroid Build 
Coastguard Worker mod_false(3, True) 2871*da0073e9SAndroid Build Coastguard Worker 2872*da0073e9SAndroid Build Coastguard Worker def f_higher(a, f): 2873*da0073e9SAndroid Build Coastguard Worker return f(a) 2874*da0073e9SAndroid Build Coastguard Worker 2875*da0073e9SAndroid Build Coastguard Worker nf = symbolic_trace(f_higher, concrete_args={'f': lambda x: x * 2}) 2876*da0073e9SAndroid Build Coastguard Worker self.assertEqual(nf(3, lambda x: x * 2), 6) 2877*da0073e9SAndroid Build Coastguard Worker 2878*da0073e9SAndroid Build Coastguard Worker def test_custom_traceback_raised_when_exception_source_is_graphmodule(self): 2879*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 2880*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2881*da0073e9SAndroid Build Coastguard Worker super().__init__() 2882*da0073e9SAndroid Build Coastguard Worker self.W = torch.nn.Parameter(torch.randn(5)) 2883*da0073e9SAndroid Build Coastguard Worker 2884*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2885*da0073e9SAndroid Build Coastguard Worker return torch.dot(self.W, x) 2886*da0073e9SAndroid Build Coastguard Worker 2887*da0073e9SAndroid Build Coastguard Worker traced = torch.fx.symbolic_trace(M()) 2888*da0073e9SAndroid Build Coastguard Worker 2889*da0073e9SAndroid Build Coastguard Worker out = [n for n in traced.graph.nodes if n.op == "output"][-1] 2890*da0073e9SAndroid Build Coastguard Worker with traced.graph.inserting_before(out): 2891*da0073e9SAndroid Build Coastguard Worker relu_out = traced.graph.call_method(method_name='relu', 2892*da0073e9SAndroid Build Coastguard Worker args=(out.args[0],)) 2893*da0073e9SAndroid Build Coastguard Worker out.args = (relu_out,) 2894*da0073e9SAndroid Build Coastguard Worker 2895*da0073e9SAndroid Build Coastguard Worker traced.recompile() 2896*da0073e9SAndroid Build Coastguard Worker 2897*da0073e9SAndroid Build Coastguard Worker with self.capture_stderr() as captured: 2898*da0073e9SAndroid Build 
Coastguard Worker with self.assertRaises(TypeError): 2899*da0073e9SAndroid Build Coastguard Worker traced(5) 2900*da0073e9SAndroid Build Coastguard Worker 2901*da0073e9SAndroid Build Coastguard Worker self.assertRegex(captured[0], 2902*da0073e9SAndroid Build Coastguard Worker r"Call using an FX-traced Module, line .* of the " 2903*da0073e9SAndroid Build Coastguard Worker r"traced Module's generated forward function:") 2904*da0073e9SAndroid Build Coastguard Worker 2905*da0073e9SAndroid Build Coastguard Worker def test_custom_traceback_not_raised_when_exception_source_is_submodule(self): 2906*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 2907*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2908*da0073e9SAndroid Build Coastguard Worker super().__init__() 2909*da0073e9SAndroid Build Coastguard Worker self.linear = torch.nn.Linear(3, 4) 2910*da0073e9SAndroid Build Coastguard Worker 2911*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2912*da0073e9SAndroid Build Coastguard Worker return self.linear(x) 2913*da0073e9SAndroid Build Coastguard Worker 2914*da0073e9SAndroid Build Coastguard Worker traced = torch.fx.symbolic_trace(M()) 2915*da0073e9SAndroid Build Coastguard Worker 2916*da0073e9SAndroid Build Coastguard Worker # Do not change this to `capture_stderr` or another context 2917*da0073e9SAndroid Build Coastguard Worker # manager without ensuring that the output is as expected 2918*da0073e9SAndroid Build Coastguard Worker try: 2919*da0073e9SAndroid Build Coastguard Worker traced(torch.rand(5, 5)) 2920*da0073e9SAndroid Build Coastguard Worker except RuntimeError: 2921*da0073e9SAndroid Build Coastguard Worker captured = traceback.format_exc() 2922*da0073e9SAndroid Build Coastguard Worker 2923*da0073e9SAndroid Build Coastguard Worker self.assertNotRegex(captured, 2924*da0073e9SAndroid Build Coastguard Worker r"Call using an FX-traced Module, line .* of the " 2925*da0073e9SAndroid Build Coastguard Worker r"traced 
Module's generated forward function:") 2926*da0073e9SAndroid Build Coastguard Worker 2927*da0073e9SAndroid Build Coastguard Worker def test_graph_module_replicate_for_dp(self): 2928*da0073e9SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 2929*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2930*da0073e9SAndroid Build Coastguard Worker return torch.relu(x) 2931*da0073e9SAndroid Build Coastguard Worker 2932*da0073e9SAndroid Build Coastguard Worker gm = torch.fx.symbolic_trace(Foo()) 2933*da0073e9SAndroid Build Coastguard Worker 2934*da0073e9SAndroid Build Coastguard Worker x = torch.randn(5, 3) 2935*da0073e9SAndroid Build Coastguard Worker out = gm(x) 2936*da0073e9SAndroid Build Coastguard Worker 2937*da0073e9SAndroid Build Coastguard Worker replica = gm._replicate_for_data_parallel() 2938*da0073e9SAndroid Build Coastguard Worker out_replica = replica(x) 2939*da0073e9SAndroid Build Coastguard Worker 2940*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(out_replica, out) 2941*da0073e9SAndroid Build Coastguard Worker 2942*da0073e9SAndroid Build Coastguard Worker def test_ast_rewriter_rewrites_assert(self): 2943*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 2944*da0073e9SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor, y: int, z: int): 2945*da0073e9SAndroid Build Coastguard Worker assert y == z 2946*da0073e9SAndroid Build Coastguard Worker return torch.add(x, x) 2947*da0073e9SAndroid Build Coastguard Worker 2948*da0073e9SAndroid Build Coastguard Worker ast_rewriter = RewritingTracer() 2949*da0073e9SAndroid Build Coastguard Worker graph = ast_rewriter.trace(M()) 2950*da0073e9SAndroid Build Coastguard Worker traced = GraphModule(ast_rewriter.root, graph, "gm") 2951*da0073e9SAndroid Build Coastguard Worker 2952*da0073e9SAndroid Build Coastguard Worker traced.graph.lint() 2953*da0073e9SAndroid Build Coastguard Worker 2954*da0073e9SAndroid Build Coastguard Worker def 
test_ast_rewriter_rewrites_assert_with_message(self): 2955*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 2956*da0073e9SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor, y: int, z: int): 2957*da0073e9SAndroid Build Coastguard Worker assert y == z, "msg" 2958*da0073e9SAndroid Build Coastguard Worker return torch.add(x, x) 2959*da0073e9SAndroid Build Coastguard Worker 2960*da0073e9SAndroid Build Coastguard Worker ast_rewriter = RewritingTracer() 2961*da0073e9SAndroid Build Coastguard Worker graph = ast_rewriter.trace(M()) 2962*da0073e9SAndroid Build Coastguard Worker traced = GraphModule(ast_rewriter.root, graph, "gm") 2963*da0073e9SAndroid Build Coastguard Worker 2964*da0073e9SAndroid Build Coastguard Worker traced.graph.lint() 2965*da0073e9SAndroid Build Coastguard Worker 2966*da0073e9SAndroid Build Coastguard Worker def test_throw_out_variant(self): 2967*da0073e9SAndroid Build Coastguard Worker def foo(x): 2968*da0073e9SAndroid Build Coastguard Worker y = torch.rand_like(x) 2969*da0073e9SAndroid Build Coastguard Worker torch.sigmoid(x, out=y) 2970*da0073e9SAndroid Build Coastguard Worker return y 2971*da0073e9SAndroid Build Coastguard Worker 2972*da0073e9SAndroid Build Coastguard Worker class MyTracer(torch.fx.Tracer): 2973*da0073e9SAndroid Build Coastguard Worker check_mutable_operations = True 2974*da0073e9SAndroid Build Coastguard Worker 2975*da0073e9SAndroid Build Coastguard Worker tracer = MyTracer() 2976*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'mutable operation aten::sigmoid.out'): 2977*da0073e9SAndroid Build Coastguard Worker traced_graph = tracer.trace(foo) 2978*da0073e9SAndroid Build Coastguard Worker 2979*da0073e9SAndroid Build Coastguard Worker def test_ast_rewriter_reassigns_submodules(self): 2980*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 2981*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2982*da0073e9SAndroid Build 
Coastguard Worker super().__init__() 2983*da0073e9SAndroid Build Coastguard Worker self.bn = torch.nn.BatchNorm2d(100) 2984*da0073e9SAndroid Build Coastguard Worker 2985*da0073e9SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor): 2986*da0073e9SAndroid Build Coastguard Worker return torch.add(x, x) 2987*da0073e9SAndroid Build Coastguard Worker 2988*da0073e9SAndroid Build Coastguard Worker ast_rewriter = RewritingTracer() 2989*da0073e9SAndroid Build Coastguard Worker graph = ast_rewriter.trace(M()) 2990*da0073e9SAndroid Build Coastguard Worker traced = GraphModule(ast_rewriter.root, graph, "gm") 2991*da0073e9SAndroid Build Coastguard Worker 2992*da0073e9SAndroid Build Coastguard Worker traced.graph.lint() 2993*da0073e9SAndroid Build Coastguard Worker 2994*da0073e9SAndroid Build Coastguard Worker def test_ast_rewriter_wrap(self): 2995*da0073e9SAndroid Build Coastguard Worker self.assertEqual(3 + 4 + 5, a_lifted_leaf((3, 4), 5)) 2996*da0073e9SAndroid Build Coastguard Worker 2997*da0073e9SAndroid Build Coastguard Worker def to_trace(y): 2998*da0073e9SAndroid Build Coastguard Worker return ( 2999*da0073e9SAndroid Build Coastguard Worker a_lifted_leaf((4, y), 3) 3000*da0073e9SAndroid Build Coastguard Worker + a_lifted_leaf((3, 4), 5) 3001*da0073e9SAndroid Build Coastguard Worker + a_lifted_leaf((y, y), y) 3002*da0073e9SAndroid Build Coastguard Worker ) 3003*da0073e9SAndroid Build Coastguard Worker 3004*da0073e9SAndroid Build Coastguard Worker ast_rewriter = RewritingTracer() 3005*da0073e9SAndroid Build Coastguard Worker graph = ast_rewriter.trace(to_trace) 3006*da0073e9SAndroid Build Coastguard Worker traced = GraphModule(ast_rewriter.root, graph, "gm") 3007*da0073e9SAndroid Build Coastguard Worker 3008*da0073e9SAndroid Build Coastguard Worker self.assertIn("a_lifted_leaf", traced.code) 3009*da0073e9SAndroid Build Coastguard Worker self.assertEqual(27, traced(2)) 3010*da0073e9SAndroid Build Coastguard Worker self.assertIs(a_lifted_leaf, real_a_lifed_leaf) 
3011*da0073e9SAndroid Build Coastguard Worker 3012*da0073e9SAndroid Build Coastguard Worker def test_ast_rewriter_wrap_fn_directly(self): 3013*da0073e9SAndroid Build Coastguard Worker self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5)) 3014*da0073e9SAndroid Build Coastguard Worker 3015*da0073e9SAndroid Build Coastguard Worker def to_trace(y): 3016*da0073e9SAndroid Build Coastguard Worker return ( 3017*da0073e9SAndroid Build Coastguard Worker a_lifted_leaf2((4, y), 3) 3018*da0073e9SAndroid Build Coastguard Worker + a_lifted_leaf2((3, 4), 5) 3019*da0073e9SAndroid Build Coastguard Worker + a_lifted_leaf2((y, y), y) 3020*da0073e9SAndroid Build Coastguard Worker ) 3021*da0073e9SAndroid Build Coastguard Worker 3022*da0073e9SAndroid Build Coastguard Worker ast_rewriter = RewritingTracer() 3023*da0073e9SAndroid Build Coastguard Worker graph = ast_rewriter.trace(to_trace) 3024*da0073e9SAndroid Build Coastguard Worker traced = GraphModule(ast_rewriter.root, graph, "gm") 3025*da0073e9SAndroid Build Coastguard Worker 3026*da0073e9SAndroid Build Coastguard Worker self.assertIn("a_lifted_leaf2", traced.code) 3027*da0073e9SAndroid Build Coastguard Worker self.assertEqual(27, traced(2)) 3028*da0073e9SAndroid Build Coastguard Worker self.assertIs(a_lifted_leaf2, real_a_lifed_leaf2) 3029*da0073e9SAndroid Build Coastguard Worker 3030*da0073e9SAndroid Build Coastguard Worker def test_profiler_ranges_side_effect(self): 3031*da0073e9SAndroid Build Coastguard Worker g = torch.fx.Graph() 3032*da0073e9SAndroid Build Coastguard Worker handle = g.call_function(torch.ops.profiler._record_function_enter_new, ('test_range',)) 3033*da0073e9SAndroid Build Coastguard Worker g.call_function(torch.ops.profiler._record_function_exit, (handle,)) 3034*da0073e9SAndroid Build Coastguard Worker g.output(None) 3035*da0073e9SAndroid Build Coastguard Worker 3036*da0073e9SAndroid Build Coastguard Worker found_targets = {} 3037*da0073e9SAndroid Build Coastguard Worker for node in g.nodes: 
3038*da0073e9SAndroid Build Coastguard Worker if node.op == 'call_function': 3039*da0073e9SAndroid Build Coastguard Worker found_targets.setdefault(node.target) 3040*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 3041*da0073e9SAndroid Build Coastguard Worker list(found_targets.keys()), 3042*da0073e9SAndroid Build Coastguard Worker [torch.ops.profiler._record_function_enter_new, torch.ops.profiler._record_function_exit] 3043*da0073e9SAndroid Build Coastguard Worker ) 3044*da0073e9SAndroid Build Coastguard Worker 3045*da0073e9SAndroid Build Coastguard Worker g.eliminate_dead_code() 3046*da0073e9SAndroid Build Coastguard Worker found_targets = {} 3047*da0073e9SAndroid Build Coastguard Worker for node in g.nodes: 3048*da0073e9SAndroid Build Coastguard Worker if node.op == 'call_function': 3049*da0073e9SAndroid Build Coastguard Worker found_targets.setdefault(node.target) 3050*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 3051*da0073e9SAndroid Build Coastguard Worker list(found_targets.keys()), 3052*da0073e9SAndroid Build Coastguard Worker [torch.ops.profiler._record_function_enter_new, torch.ops.profiler._record_function_exit] 3053*da0073e9SAndroid Build Coastguard Worker ) 3054*da0073e9SAndroid Build Coastguard Worker 3055*da0073e9SAndroid Build Coastguard Worker def test_ast_rewriter_wrapped_via_decorator(self): 3056*da0073e9SAndroid Build Coastguard Worker class F(torch.nn.Module): 3057*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 3058*da0073e9SAndroid Build Coastguard Worker return wrapped_via_decorator(x) 3059*da0073e9SAndroid Build Coastguard Worker 3060*da0073e9SAndroid Build Coastguard Worker ast_rewriter = RewritingTracer() 3061*da0073e9SAndroid Build Coastguard Worker graph = ast_rewriter.trace(F()) 3062*da0073e9SAndroid Build Coastguard Worker traced = GraphModule(ast_rewriter.root, graph, "gm") 3063*da0073e9SAndroid Build Coastguard Worker 3064*da0073e9SAndroid Build Coastguard Worker 
self.assertIn("wrapped_via_decorator", traced.code) 3065*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced(0), 1) 3066*da0073e9SAndroid Build Coastguard Worker self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator) 3067*da0073e9SAndroid Build Coastguard Worker self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched")) 3068*da0073e9SAndroid Build Coastguard Worker 3069*da0073e9SAndroid Build Coastguard Worker def test_ast_rewriter_wrapped_via_decorator_and_transformed(self): 3070*da0073e9SAndroid Build Coastguard Worker self.assertEqual(wrapped_via_decorator(0), 1) 3071*da0073e9SAndroid Build Coastguard Worker 3072*da0073e9SAndroid Build Coastguard Worker def to_trace(y): 3073*da0073e9SAndroid Build Coastguard Worker return wrapped_via_decorator(y) 3074*da0073e9SAndroid Build Coastguard Worker 3075*da0073e9SAndroid Build Coastguard Worker ast_rewriter = RewritingTracer() 3076*da0073e9SAndroid Build Coastguard Worker graph = ast_rewriter.trace(to_trace) 3077*da0073e9SAndroid Build Coastguard Worker traced = GraphModule(ast_rewriter.root, graph, "gm") 3078*da0073e9SAndroid Build Coastguard Worker 3079*da0073e9SAndroid Build Coastguard Worker self.assertIn("wrapped_via_decorator", traced.code) 3080*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced(0), 1) 3081*da0073e9SAndroid Build Coastguard Worker self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator) 3082*da0073e9SAndroid Build Coastguard Worker self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched")) 3083*da0073e9SAndroid Build Coastguard Worker 3084*da0073e9SAndroid Build Coastguard Worker transformed = torch.fx.Transformer(traced).transform() 3085*da0073e9SAndroid Build Coastguard Worker self.assertIn("wrapped_via_decorator", transformed.code) 3086*da0073e9SAndroid Build Coastguard Worker self.assertEqual(transformed(0), 1) 3087*da0073e9SAndroid Build Coastguard Worker self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator) 
3088*da0073e9SAndroid Build Coastguard Worker self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched")) 3089*da0073e9SAndroid Build Coastguard Worker 3090*da0073e9SAndroid Build Coastguard Worker def test_ast_rewriter_wrap_with_submodule(self): 3091*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 3092*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 3093*da0073e9SAndroid Build Coastguard Worker super().__init__() 3094*da0073e9SAndroid Build Coastguard Worker self.batchnorm1d = torch.nn.BatchNorm1d(2, affine=False) 3095*da0073e9SAndroid Build Coastguard Worker 3096*da0073e9SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor): 3097*da0073e9SAndroid Build Coastguard Worker return wrapped_with_submodule(x, self.batchnorm1d) 3098*da0073e9SAndroid Build Coastguard Worker 3099*da0073e9SAndroid Build Coastguard Worker ast_rewriter = RewritingTracer() 3100*da0073e9SAndroid Build Coastguard Worker graph = ast_rewriter.trace(M()) 3101*da0073e9SAndroid Build Coastguard Worker traced = GraphModule(ast_rewriter.root, graph, "gm") 3102*da0073e9SAndroid Build Coastguard Worker 3103*da0073e9SAndroid Build Coastguard Worker self.assertIn("wrapped_with_submodule", traced.code) 3104*da0073e9SAndroid Build Coastguard Worker 3105*da0073e9SAndroid Build Coastguard Worker input = torch.rand(3, 2) 3106*da0073e9SAndroid Build Coastguard Worker ref_batchnorm1d = torch.nn.BatchNorm1d(2, affine=False) 3107*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref_batchnorm1d(input), traced(input)) 3108*da0073e9SAndroid Build Coastguard Worker 3109*da0073e9SAndroid Build Coastguard Worker def test_submodule_manipulation_API(self): 3110*da0073e9SAndroid Build Coastguard Worker class C(torch.nn.Module): 3111*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 3112*da0073e9SAndroid Build Coastguard Worker super().__init__() 3113*da0073e9SAndroid Build Coastguard Worker self.conv = torch.nn.Conv2d(16, 33, 3, 
stride=2) 3114*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(2, 3)) 3115*da0073e9SAndroid Build Coastguard Worker 3116*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 3117*da0073e9SAndroid Build Coastguard Worker return self.conv(torch.cat([self.param, x])) 3118*da0073e9SAndroid Build Coastguard Worker 3119*da0073e9SAndroid Build Coastguard Worker class B(torch.nn.Module): 3120*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 3121*da0073e9SAndroid Build Coastguard Worker super().__init__() 3122*da0073e9SAndroid Build Coastguard Worker self.linear = torch.nn.Linear(100, 200) 3123*da0073e9SAndroid Build Coastguard Worker self.buf = torch.nn.Buffer(torch.randn(2, 3)) 3124*da0073e9SAndroid Build Coastguard Worker self.net_c = C() 3125*da0073e9SAndroid Build Coastguard Worker 3126*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 3127*da0073e9SAndroid Build Coastguard Worker return self.linear(torch.cat([self.buf, self.net_c(x)])) 3128*da0073e9SAndroid Build Coastguard Worker 3129*da0073e9SAndroid Build Coastguard Worker class A(torch.nn.Module): 3130*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 3131*da0073e9SAndroid Build Coastguard Worker super().__init__() 3132*da0073e9SAndroid Build Coastguard Worker self.net_b = B() 3133*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(2, 3)) 3134*da0073e9SAndroid Build Coastguard Worker 3135*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 3136*da0073e9SAndroid Build Coastguard Worker return self.net_b(x) + self.param 3137*da0073e9SAndroid Build Coastguard Worker 3138*da0073e9SAndroid Build Coastguard Worker a = symbolic_trace(A()) 3139*da0073e9SAndroid Build Coastguard Worker 3140*da0073e9SAndroid Build Coastguard Worker a.add_submodule("net_b.net_c.dropout", torch.nn.Dropout(p=0.2)) 3141*da0073e9SAndroid Build Coastguard Worker 3142*da0073e9SAndroid Build Coastguard Worker 
conv = [n for n in a.graph.nodes if n.target == "net_b.net_c.conv"][-1] 3143*da0073e9SAndroid Build Coastguard Worker with a.graph.inserting_before(conv): 3144*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 3145*da0073e9SAndroid Build Coastguard Worker dropout = a.graph.call_module(module_name="net_b.net_c.dropout", 3146*da0073e9SAndroid Build Coastguard Worker args=conv.args) 3147*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(w), 0) 3148*da0073e9SAndroid Build Coastguard Worker 3149*da0073e9SAndroid Build Coastguard Worker conv.replace_all_uses_with(dropout) 3150*da0073e9SAndroid Build Coastguard Worker a.graph.erase_node(conv) 3151*da0073e9SAndroid Build Coastguard Worker a.recompile() 3152*da0073e9SAndroid Build Coastguard Worker 3153*da0073e9SAndroid Build Coastguard Worker def module_exists(gm: GraphModule, path: str) -> bool: 3154*da0073e9SAndroid Build Coastguard Worker return any(path == name for name, _ in gm.named_modules()) 3155*da0073e9SAndroid Build Coastguard Worker 3156*da0073e9SAndroid Build Coastguard Worker def parameter_exists(gm: GraphModule, path: str) -> bool: 3157*da0073e9SAndroid Build Coastguard Worker return (any(path == name for name, _ in gm.named_parameters()) 3158*da0073e9SAndroid Build Coastguard Worker and any(path == name for name in gm.state_dict().keys())) 3159*da0073e9SAndroid Build Coastguard Worker 3160*da0073e9SAndroid Build Coastguard Worker def buffer_exists(gm: GraphModule, path: str) -> bool: 3161*da0073e9SAndroid Build Coastguard Worker return (any(path == name for name, _ in gm.named_buffers()) 3162*da0073e9SAndroid Build Coastguard Worker and any(path == name for name in gm.state_dict().keys())) 3163*da0073e9SAndroid Build Coastguard Worker 3164*da0073e9SAndroid Build Coastguard Worker # Test that we added the "dropout" submodule 3165*da0073e9SAndroid Build Coastguard Worker self.assertTrue(module_exists(a, "net_b.net_c.dropout")) 3166*da0073e9SAndroid Build 
Coastguard Worker 3167*da0073e9SAndroid Build Coastguard Worker # Test `get_submodule` with an added submodule 3168*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(a.get_submodule("net_b.net_c.dropout")) 3169*da0073e9SAndroid Build Coastguard Worker 3170*da0073e9SAndroid Build Coastguard Worker # Test that the "conv" submodule is still there 3171*da0073e9SAndroid Build Coastguard Worker self.assertTrue(module_exists(a, "net_b.net_c.conv")) 3172*da0073e9SAndroid Build Coastguard Worker 3173*da0073e9SAndroid Build Coastguard Worker # Test `get_submodule` with an original module 3174*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(a.get_submodule("net_b.net_c.conv")) 3175*da0073e9SAndroid Build Coastguard Worker 3176*da0073e9SAndroid Build Coastguard Worker # Test that the "conv" node is NOT still there 3177*da0073e9SAndroid Build Coastguard Worker conv = [n for n in a.graph.nodes if n.target == "net_b.net_c.conv"] 3178*da0073e9SAndroid Build Coastguard Worker self.assertEqual(conv, []) 3179*da0073e9SAndroid Build Coastguard Worker 3180*da0073e9SAndroid Build Coastguard Worker a.delete_submodule("net_b.net_c.conv") 3181*da0073e9SAndroid Build Coastguard Worker 3182*da0073e9SAndroid Build Coastguard Worker # Test that the "conv" submodule is now gone 3183*da0073e9SAndroid Build Coastguard Worker self.assertFalse(module_exists(a, "net_b.net_c.conv")) 3184*da0073e9SAndroid Build Coastguard Worker 3185*da0073e9SAndroid Build Coastguard Worker # Test `get_submodule` with a deleted submodule 3186*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AttributeError, "has no attribute " 3187*da0073e9SAndroid Build Coastguard Worker "`conv`"): 3188*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(a.get_submodule("net_b.net_c.conv")) 3189*da0073e9SAndroid Build Coastguard Worker 3190*da0073e9SAndroid Build Coastguard Worker # Test `get_attr` warnings 3191*da0073e9SAndroid Build Coastguard Worker cat = [n for n in 
a.graph.nodes if n.target == torch.cat][-1] 3192*da0073e9SAndroid Build Coastguard Worker 3193*da0073e9SAndroid Build Coastguard Worker with a.graph.inserting_before(cat): 3194*da0073e9SAndroid Build Coastguard Worker 3195*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 3196*da0073e9SAndroid Build Coastguard Worker param = a.graph.get_attr(qualified_name="net_b.net_c.param") 3197*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(w), 0) 3198*da0073e9SAndroid Build Coastguard Worker 3199*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex(UserWarning, "Attempted to " 3200*da0073e9SAndroid Build Coastguard Worker "insert a get_attr Node with no " 3201*da0073e9SAndroid Build Coastguard Worker "underlying reference in the " 3202*da0073e9SAndroid Build Coastguard Worker "owning GraphModule"): 3203*da0073e9SAndroid Build Coastguard Worker bad_param = a.graph.get_attr(qualified_name="net_b.param") 3204*da0073e9SAndroid Build Coastguard Worker a.graph.erase_node(bad_param) 3205*da0073e9SAndroid Build Coastguard Worker 3206*da0073e9SAndroid Build Coastguard Worker cat.args = (*cat.args, param) 3207*da0073e9SAndroid Build Coastguard Worker 3208*da0073e9SAndroid Build Coastguard Worker a.recompile() 3209*da0073e9SAndroid Build Coastguard Worker 3210*da0073e9SAndroid Build Coastguard Worker a.graph.lint() 3211*da0073e9SAndroid Build Coastguard Worker 3212*da0073e9SAndroid Build Coastguard Worker # Test `get_parameter` 3213*da0073e9SAndroid Build Coastguard Worker a.get_parameter("net_b.net_c.param") 3214*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AttributeError, "is not an " 3215*da0073e9SAndroid Build Coastguard Worker "nn.Parameter"): 3216*da0073e9SAndroid Build Coastguard Worker a.get_parameter("net_b.buf") 3217*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AttributeError, "has no attribute " 3218*da0073e9SAndroid Build Coastguard Worker "`param`"): 
3219*da0073e9SAndroid Build Coastguard Worker a.get_parameter("net_b.param") 3220*da0073e9SAndroid Build Coastguard Worker 3221*da0073e9SAndroid Build Coastguard Worker # Test `get_buffer` 3222*da0073e9SAndroid Build Coastguard Worker a.get_buffer("net_b.buf") 3223*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AttributeError, "is not a " 3224*da0073e9SAndroid Build Coastguard Worker "buffer"): 3225*da0073e9SAndroid Build Coastguard Worker a.get_buffer("net_b.net_c.param") 3226*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AttributeError, "has no attribute " 3227*da0073e9SAndroid Build Coastguard Worker "`buf`"): 3228*da0073e9SAndroid Build Coastguard Worker a.get_buffer("net_b.net_c.buf") 3229*da0073e9SAndroid Build Coastguard Worker 3230*da0073e9SAndroid Build Coastguard Worker # Test non-nested attributes 3231*da0073e9SAndroid Build Coastguard Worker a.get_submodule("") 3232*da0073e9SAndroid Build Coastguard Worker a.get_parameter("param") 3233*da0073e9SAndroid Build Coastguard Worker 3234*da0073e9SAndroid Build Coastguard Worker # Insert some unused submodules 3235*da0073e9SAndroid Build Coastguard Worker a.add_submodule("net_b.embedding", torch.nn.Embedding(10, 3)) 3236*da0073e9SAndroid Build Coastguard Worker a.add_submodule("net_b.net_c.embedding", torch.nn.Embedding(10, 3)) 3237*da0073e9SAndroid Build Coastguard Worker a.add_submodule("net_b.net_c.rnn", torch.nn.RNN(10, 20, 2)) 3238*da0073e9SAndroid Build Coastguard Worker a.add_submodule("batch_norm_2d", torch.nn.BatchNorm2d(100)) 3239*da0073e9SAndroid Build Coastguard Worker 3240*da0073e9SAndroid Build Coastguard Worker # Garbage collection 3241*da0073e9SAndroid Build Coastguard Worker a.delete_all_unused_submodules() 3242*da0073e9SAndroid Build Coastguard Worker 3243*da0073e9SAndroid Build Coastguard Worker # Test that all the unused submodules are gone 3244*da0073e9SAndroid Build Coastguard Worker self.assertFalse(module_exists(a, "net_b.embedding")) 
3245*da0073e9SAndroid Build Coastguard Worker self.assertFalse(module_exists(a, "net_b.net_c.embedding")) 3246*da0073e9SAndroid Build Coastguard Worker self.assertFalse(module_exists(a, "net_b.net_c.rnn")) 3247*da0073e9SAndroid Build Coastguard Worker self.assertFalse(module_exists(a, "batch_norm_2d")) 3248*da0073e9SAndroid Build Coastguard Worker 3249*da0073e9SAndroid Build Coastguard Worker # Test that we didn't delete any unused Parameters or buffers 3250*da0073e9SAndroid Build Coastguard Worker self.assertTrue(parameter_exists(a, "net_b.net_c.param")) 3251*da0073e9SAndroid Build Coastguard Worker self.assertTrue(buffer_exists(a, "net_b.buf")) 3252*da0073e9SAndroid Build Coastguard Worker 3253*da0073e9SAndroid Build Coastguard Worker a.graph.lint() 3254*da0073e9SAndroid Build Coastguard Worker 3255*da0073e9SAndroid Build Coastguard Worker def test_delete_unused_submodules_leaf(self): 3256*da0073e9SAndroid Build Coastguard Worker class SubModule(torch.nn.Module): 3257*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 3258*da0073e9SAndroid Build Coastguard Worker super().__init__() 3259*da0073e9SAndroid Build Coastguard Worker self.linear = torch.nn.Linear(10, 10) 3260*da0073e9SAndroid Build Coastguard Worker self.relu = torch.nn.ReLU() 3261*da0073e9SAndroid Build Coastguard Worker 3262*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 3263*da0073e9SAndroid Build Coastguard Worker x = self.linear(x) 3264*da0073e9SAndroid Build Coastguard Worker x = self.relu(x) 3265*da0073e9SAndroid Build Coastguard Worker return x 3266*da0073e9SAndroid Build Coastguard Worker 3267*da0073e9SAndroid Build Coastguard Worker class Model(torch.nn.Module): 3268*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 3269*da0073e9SAndroid Build Coastguard Worker super().__init__() 3270*da0073e9SAndroid Build Coastguard Worker self.submod = SubModule() 3271*da0073e9SAndroid Build Coastguard Worker 3272*da0073e9SAndroid Build Coastguard Worker 
def forward(self, x): 3273*da0073e9SAndroid Build Coastguard Worker x = self.submod(x) 3274*da0073e9SAndroid Build Coastguard Worker return x 3275*da0073e9SAndroid Build Coastguard Worker 3276*da0073e9SAndroid Build Coastguard Worker model = Model() 3277*da0073e9SAndroid Build Coastguard Worker 3278*da0073e9SAndroid Build Coastguard Worker class MyCustomTracer(torch.fx.Tracer): 3279*da0073e9SAndroid Build Coastguard Worker def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool: 3280*da0073e9SAndroid Build Coastguard Worker return module_qualified_name == "submod" 3281*da0073e9SAndroid Build Coastguard Worker 3282*da0073e9SAndroid Build Coastguard Worker inputs = torch.randn(1, 10) 3283*da0073e9SAndroid Build Coastguard Worker traced_graph = MyCustomTracer().trace(model) 3284*da0073e9SAndroid Build Coastguard Worker gm2 = torch.fx.GraphModule(model, traced_graph) 3285*da0073e9SAndroid Build Coastguard Worker gm2.delete_all_unused_submodules() 3286*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(gm2(inputs), model(inputs)) 3287*da0073e9SAndroid Build Coastguard Worker 3288*da0073e9SAndroid Build Coastguard Worker def test_fx_stateless(self): 3289*da0073e9SAndroid Build Coastguard Worker class MockModule(torch.nn.Module): 3290*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 3291*da0073e9SAndroid Build Coastguard Worker super().__init__() 3292*da0073e9SAndroid Build Coastguard Worker self.l1 = torch.nn.Linear(1, 1) 3293*da0073e9SAndroid Build Coastguard Worker self.buffer = torch.nn.Buffer(torch.ones(1)) 3294*da0073e9SAndroid Build Coastguard Worker 3295*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 3296*da0073e9SAndroid Build Coastguard Worker return self.l1(x) + self.buffer 3297*da0073e9SAndroid Build Coastguard Worker 3298*da0073e9SAndroid Build Coastguard Worker module = MockModule() 3299*da0073e9SAndroid Build Coastguard Worker x = torch.rand((1, 1)) 3300*da0073e9SAndroid Build 
Coastguard Worker weight = torch.tensor([[1.0]], requires_grad=True) 3301*da0073e9SAndroid Build Coastguard Worker bias = torch.tensor([0.0], requires_grad=True) 3302*da0073e9SAndroid Build Coastguard Worker buffer = torch.tensor([0.0]) 3303*da0073e9SAndroid Build Coastguard Worker parameters = {'l1.weight': weight, 3304*da0073e9SAndroid Build Coastguard Worker 'l1.bias': bias, 3305*da0073e9SAndroid Build Coastguard Worker 'buffer': buffer} 3306*da0073e9SAndroid Build Coastguard Worker fx_module = torch.fx.symbolic_trace(module) 3307*da0073e9SAndroid Build Coastguard Worker res = torch.func.functional_call(fx_module, parameters, x) 3308*da0073e9SAndroid Build Coastguard Worker res.backward() 3309*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(weight.grad) 3310*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(bias.grad) 3311*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(buffer.grad) 3312*da0073e9SAndroid Build Coastguard Worker # Gradient was not calculated for the module stated and buffers 3313*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(module.l1.weight.grad) 3314*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(module.l1.bias.grad) 3315*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(module.buffer.grad) 3316*da0073e9SAndroid Build Coastguard Worker 3317*da0073e9SAndroid Build Coastguard Worker def test_tracing_graphmodules_as_leaf_submodules(self): 3318*da0073e9SAndroid Build Coastguard Worker class A(torch.nn.Module): 3319*da0073e9SAndroid Build Coastguard Worker def forward(self, t): 3320*da0073e9SAndroid Build Coastguard Worker return t + t 3321*da0073e9SAndroid Build Coastguard Worker 3322*da0073e9SAndroid Build Coastguard Worker class B(torch.nn.Module): 3323*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 3324*da0073e9SAndroid Build Coastguard Worker super(type(self), self).__init__() 3325*da0073e9SAndroid Build Coastguard Worker self.calling = False 
3326*da0073e9SAndroid Build Coastguard Worker self.called = False 3327*da0073e9SAndroid Build Coastguard Worker 3328*da0073e9SAndroid Build Coastguard Worker def forward(self, t): 3329*da0073e9SAndroid Build Coastguard Worker if self.calling: 3330*da0073e9SAndroid Build Coastguard Worker return t - t 3331*da0073e9SAndroid Build Coastguard Worker else: 3332*da0073e9SAndroid Build Coastguard Worker return t + t 3333*da0073e9SAndroid Build Coastguard Worker 3334*da0073e9SAndroid Build Coastguard Worker def __call__(self, *args): 3335*da0073e9SAndroid Build Coastguard Worker self.called = True 3336*da0073e9SAndroid Build Coastguard Worker self.calling = True 3337*da0073e9SAndroid Build Coastguard Worker return super(type(self), self).__call__(*args) 3338*da0073e9SAndroid Build Coastguard Worker self.calling = False 3339*da0073e9SAndroid Build Coastguard Worker 3340*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 3341*da0073e9SAndroid Build Coastguard Worker def __init__(self, a, b): 3342*da0073e9SAndroid Build Coastguard Worker super().__init__() 3343*da0073e9SAndroid Build Coastguard Worker self.a = a 3344*da0073e9SAndroid Build Coastguard Worker self.b = b 3345*da0073e9SAndroid Build Coastguard Worker 3346*da0073e9SAndroid Build Coastguard Worker def forward(self, t): 3347*da0073e9SAndroid Build Coastguard Worker x = self.a(t) 3348*da0073e9SAndroid Build Coastguard Worker y = self.b(t) 3349*da0073e9SAndroid Build Coastguard Worker return x + y 3350*da0073e9SAndroid Build Coastguard Worker 3351*da0073e9SAndroid Build Coastguard Worker class LeafTracer(Tracer): 3352*da0073e9SAndroid Build Coastguard Worker def is_leaf_module(self, module, name): 3353*da0073e9SAndroid Build Coastguard Worker return True 3354*da0073e9SAndroid Build Coastguard Worker 3355*da0073e9SAndroid Build Coastguard Worker class LeafTracerNotB(Tracer): 3356*da0073e9SAndroid Build Coastguard Worker def is_leaf_module(self, module, name): 3357*da0073e9SAndroid Build Coastguard 
Worker return False if "b" in name else True 3358*da0073e9SAndroid Build Coastguard Worker 3359*da0073e9SAndroid Build Coastguard Worker # Recompile calls added "for fun", since they 3360*da0073e9SAndroid Build Coastguard Worker # chain __call__ wrappers. 3361*da0073e9SAndroid Build Coastguard Worker 3362*da0073e9SAndroid Build Coastguard Worker # 3363*da0073e9SAndroid Build Coastguard Worker # Test: B as a regular, non-leaf module 3364*da0073e9SAndroid Build Coastguard Worker # 3365*da0073e9SAndroid Build Coastguard Worker a = symbolic_trace(A()) 3366*da0073e9SAndroid Build Coastguard Worker a.recompile() 3367*da0073e9SAndroid Build Coastguard Worker m = M(a, B()) 3368*da0073e9SAndroid Build Coastguard Worker graph = LeafTracerNotB().trace(m) 3369*da0073e9SAndroid Build Coastguard Worker gm = GraphModule(m, graph) 3370*da0073e9SAndroid Build Coastguard Worker gm.recompile() 3371*da0073e9SAndroid Build Coastguard Worker 3372*da0073e9SAndroid Build Coastguard Worker # Test graphmodule/submodule a is not inlined. 3373*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule)) 3374*da0073e9SAndroid Build Coastguard Worker match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"] 3375*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(match) == 1) 3376*da0073e9SAndroid Build Coastguard Worker 3377*da0073e9SAndroid Build Coastguard Worker # Test submodule b is not treated as leaf. 3378*da0073e9SAndroid Build Coastguard Worker self.assertFalse(hasattr(gm, "b")) 3379*da0073e9SAndroid Build Coastguard Worker 3380*da0073e9SAndroid Build Coastguard Worker # Test assert custom __call__ on submodule b was honored. 
3381*da0073e9SAndroid Build Coastguard Worker match = [ 3382*da0073e9SAndroid Build Coastguard Worker n 3383*da0073e9SAndroid Build Coastguard Worker for n in gm.graph.nodes 3384*da0073e9SAndroid Build Coastguard Worker if n.op == "call_function" and n.target == operator.sub 3385*da0073e9SAndroid Build Coastguard Worker ] 3386*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(match) == 1) 3387*da0073e9SAndroid Build Coastguard Worker 3388*da0073e9SAndroid Build Coastguard Worker # 3389*da0073e9SAndroid Build Coastguard Worker # Test: B as a regular, leaf module 3390*da0073e9SAndroid Build Coastguard Worker # symbolic_trace should only patch torch.nn.Module.__call__, 3391*da0073e9SAndroid Build Coastguard Worker # which means B.__call__ should still execute 3392*da0073e9SAndroid Build Coastguard Worker # 3393*da0073e9SAndroid Build Coastguard Worker a = symbolic_trace(A()) 3394*da0073e9SAndroid Build Coastguard Worker a.recompile() 3395*da0073e9SAndroid Build Coastguard Worker b = B() 3396*da0073e9SAndroid Build Coastguard Worker m = M(a, b) 3397*da0073e9SAndroid Build Coastguard Worker graph = LeafTracer().trace(m) 3398*da0073e9SAndroid Build Coastguard Worker gm = GraphModule(m, graph) 3399*da0073e9SAndroid Build Coastguard Worker gm.recompile() 3400*da0073e9SAndroid Build Coastguard Worker 3401*da0073e9SAndroid Build Coastguard Worker # Test graphmodule/submodule a is not inlined. 
3402*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule)) 3403*da0073e9SAndroid Build Coastguard Worker match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"] 3404*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(match) == 1) 3405*da0073e9SAndroid Build Coastguard Worker 3406*da0073e9SAndroid Build Coastguard Worker # Test submodule b is leaf: 3407*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(gm.get_submodule("b"), torch.nn.Module)) 3408*da0073e9SAndroid Build Coastguard Worker match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "b"] 3409*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(match) == 1) 3410*da0073e9SAndroid Build Coastguard Worker 3411*da0073e9SAndroid Build Coastguard Worker # Test b.__call__ was run 3412*da0073e9SAndroid Build Coastguard Worker self.assertTrue(b.called) 3413*da0073e9SAndroid Build Coastguard Worker self.assertTrue(gm.get_submodule("b").called) 3414*da0073e9SAndroid Build Coastguard Worker 3415*da0073e9SAndroid Build Coastguard Worker # 3416*da0073e9SAndroid Build Coastguard Worker # Test: B as GraphModule leaf 3417*da0073e9SAndroid Build Coastguard Worker # __call__ not honored since symbolic_trace directly invokes forward() 3418*da0073e9SAndroid Build Coastguard Worker # 3419*da0073e9SAndroid Build Coastguard Worker a = symbolic_trace(A()) 3420*da0073e9SAndroid Build Coastguard Worker a.recompile() 3421*da0073e9SAndroid Build Coastguard Worker b = symbolic_trace(B()) 3422*da0073e9SAndroid Build Coastguard Worker b.recompile() 3423*da0073e9SAndroid Build Coastguard Worker m = M(a, b) 3424*da0073e9SAndroid Build Coastguard Worker graph = LeafTracer().trace(m) 3425*da0073e9SAndroid Build Coastguard Worker gm = GraphModule(m, graph) 3426*da0073e9SAndroid Build Coastguard Worker gm.recompile() 3427*da0073e9SAndroid Build Coastguard Worker 3428*da0073e9SAndroid Build Coastguard Worker 
self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule)) 3429*da0073e9SAndroid Build Coastguard Worker match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"] 3430*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(match) == 1) 3431*da0073e9SAndroid Build Coastguard Worker 3432*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(gm.get_submodule("b"), torch.nn.Module)) 3433*da0073e9SAndroid Build Coastguard Worker match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "b"] 3434*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(match) == 1) 3435*da0073e9SAndroid Build Coastguard Worker 3436*da0073e9SAndroid Build Coastguard Worker def _test_graph_module_init_buffer_param_copied(self, use_dict_init: bool): 3437*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.nn.Module): 3438*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 3439*da0073e9SAndroid Build Coastguard Worker super().__init__() 3440*da0073e9SAndroid Build Coastguard Worker self.my_buff = torch.nn.Buffer(torch.rand(3, 4)) 3441*da0073e9SAndroid Build Coastguard Worker self.register_parameter( 3442*da0073e9SAndroid Build Coastguard Worker "my_param", torch.nn.Parameter(torch.rand(3, 4)) 3443*da0073e9SAndroid Build Coastguard Worker ) 3444*da0073e9SAndroid Build Coastguard Worker 3445*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 3446*da0073e9SAndroid Build Coastguard Worker return x + self.my_buff + self.my_param 3447*da0073e9SAndroid Build Coastguard Worker 3448*da0073e9SAndroid Build Coastguard Worker mod = MyModule() 3449*da0073e9SAndroid Build Coastguard Worker mod_traced = symbolic_trace(mod) 3450*da0073e9SAndroid Build Coastguard Worker 3451*da0073e9SAndroid Build Coastguard Worker # Create new GraphModule based on original, either w/ dict or root module. 
3452*da0073e9SAndroid Build Coastguard Worker orig_buff = mod_traced.get_buffer("my_buff") 3453*da0073e9SAndroid Build Coastguard Worker orig_param = mod_traced.get_parameter("my_param") 3454*da0073e9SAndroid Build Coastguard Worker mod_traced_new = GraphModule( 3455*da0073e9SAndroid Build Coastguard Worker {"my_buff": orig_buff, "my_param": orig_param} if use_dict_init else mod, 3456*da0073e9SAndroid Build Coastguard Worker mod_traced.graph, 3457*da0073e9SAndroid Build Coastguard Worker ) 3458*da0073e9SAndroid Build Coastguard Worker 3459*da0073e9SAndroid Build Coastguard Worker # Check that both my_buff and my_param are found and the same. 3460*da0073e9SAndroid Build Coastguard Worker try: 3461*da0073e9SAndroid Build Coastguard Worker new_buff = mod_traced_new.get_buffer("my_buff") 3462*da0073e9SAndroid Build Coastguard Worker except Exception: 3463*da0073e9SAndroid Build Coastguard Worker self.fail("Did not find my_buff") 3464*da0073e9SAndroid Build Coastguard Worker self.assertEqual(orig_buff, new_buff) 3465*da0073e9SAndroid Build Coastguard Worker 3466*da0073e9SAndroid Build Coastguard Worker try: 3467*da0073e9SAndroid Build Coastguard Worker new_param = mod_traced_new.get_parameter("my_param") 3468*da0073e9SAndroid Build Coastguard Worker except Exception: 3469*da0073e9SAndroid Build Coastguard Worker self.fail("Did not find my_param") 3470*da0073e9SAndroid Build Coastguard Worker self.assertEqual(orig_param, new_param) 3471*da0073e9SAndroid Build Coastguard Worker 3472*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4) 3473*da0073e9SAndroid Build Coastguard Worker orig_out = mod_traced(x) 3474*da0073e9SAndroid Build Coastguard Worker submodules_out = mod_traced_new(x) 3475*da0073e9SAndroid Build Coastguard Worker 3476*da0073e9SAndroid Build Coastguard Worker self.assertEqual(orig_out, submodules_out) 3477*da0073e9SAndroid Build Coastguard Worker 3478*da0073e9SAndroid Build Coastguard Worker def 
test_graph_module_init_buffer_param_copied_dict_init(self): 3479*da0073e9SAndroid Build Coastguard Worker self._test_graph_module_init_buffer_param_copied(use_dict_init=True) 3480*da0073e9SAndroid Build Coastguard Worker 3481*da0073e9SAndroid Build Coastguard Worker def test_graph_module_init_buffer_param_copied_mod_init(self): 3482*da0073e9SAndroid Build Coastguard Worker self._test_graph_module_init_buffer_param_copied(use_dict_init=False) 3483*da0073e9SAndroid Build Coastguard Worker 3484*da0073e9SAndroid Build Coastguard Worker def test_annotations_with_no_forward_references(self): 3485*da0073e9SAndroid Build Coastguard Worker class A: 3486*da0073e9SAndroid Build Coastguard Worker def __call__(self, x: torch.Tensor): 3487*da0073e9SAndroid Build Coastguard Worker return torch.add(x, x) 3488*da0073e9SAndroid Build Coastguard Worker 3489*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 3490*da0073e9SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor, a: A) -> torch.Tensor: 3491*da0073e9SAndroid Build Coastguard Worker return a(x) 3492*da0073e9SAndroid Build Coastguard Worker 3493*da0073e9SAndroid Build Coastguard Worker self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None) 3494*da0073e9SAndroid Build Coastguard Worker 3495*da0073e9SAndroid Build Coastguard Worker def test_annotations_with_forward_references(self): 3496*da0073e9SAndroid Build Coastguard Worker class A: 3497*da0073e9SAndroid Build Coastguard Worker def __call__(self, x: torch.Tensor): 3498*da0073e9SAndroid Build Coastguard Worker return torch.add(x, x) 3499*da0073e9SAndroid Build Coastguard Worker 3500*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 3501*da0073e9SAndroid Build Coastguard Worker def forward(self, x: 'torch.Tensor', a: 'A') -> 'torch.Tensor': 3502*da0073e9SAndroid Build Coastguard Worker return a(x) 3503*da0073e9SAndroid Build Coastguard Worker 3504*da0073e9SAndroid Build Coastguard Worker self.checkGraphModule(M(), 
(torch.rand(2, 3), A()), kwargs=None) 3505*da0073e9SAndroid Build Coastguard Worker 3506*da0073e9SAndroid Build Coastguard Worker def test_annotations_with_non_torch_reference_and_no_internal_forward_references(self): 3507*da0073e9SAndroid Build Coastguard Worker class A: 3508*da0073e9SAndroid Build Coastguard Worker def __call__(self, x: torch.Tensor): 3509*da0073e9SAndroid Build Coastguard Worker return torch.add(x, x) 3510*da0073e9SAndroid Build Coastguard Worker 3511*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 3512*da0073e9SAndroid Build Coastguard Worker def forward(self, x: List[torch.Tensor], a: A) -> torch.Tensor: 3513*da0073e9SAndroid Build Coastguard Worker return a(x[0]) 3514*da0073e9SAndroid Build Coastguard Worker 3515*da0073e9SAndroid Build Coastguard Worker self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None) 3516*da0073e9SAndroid Build Coastguard Worker 3517*da0073e9SAndroid Build Coastguard Worker def test_annotations_with_non_torch_reference_and_internal_forward_references(self): 3518*da0073e9SAndroid Build Coastguard Worker class A: 3519*da0073e9SAndroid Build Coastguard Worker def __call__(self, x: torch.Tensor): 3520*da0073e9SAndroid Build Coastguard Worker return torch.add(x, x) 3521*da0073e9SAndroid Build Coastguard Worker 3522*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 3523*da0073e9SAndroid Build Coastguard Worker def forward(self, x: List['torch.Tensor'], a: A) -> 'torch.Tensor': 3524*da0073e9SAndroid Build Coastguard Worker return a(x)[0] 3525*da0073e9SAndroid Build Coastguard Worker 3526*da0073e9SAndroid Build Coastguard Worker self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None) 3527*da0073e9SAndroid Build Coastguard Worker 3528*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(sys.version_info < (3, 7), "`__future__` feature " 3529*da0073e9SAndroid Build Coastguard Worker "`annotations` is not defined in Python <3.7") 3530*da0073e9SAndroid Build Coastguard 
Worker def test_annotation_with_future(self): 3531*da0073e9SAndroid Build Coastguard Worker try: 3532*da0073e9SAndroid Build Coastguard Worker import fx.test_future # noqa: F401 3533*da0073e9SAndroid Build Coastguard Worker finally: 3534*da0073e9SAndroid Build Coastguard Worker del sys.modules["__future__"] 3535*da0073e9SAndroid Build Coastguard Worker 3536*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(sys.version_info > (3, 11), "Does not work in 3.11") 3537*da0073e9SAndroid Build Coastguard Worker def test_annotations_empty_tuple(self): 3538*da0073e9SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 3539*da0073e9SAndroid Build Coastguard Worker def forward(self, x: Tuple[()], y: Tuple[str, Tuple[()]]): 3540*da0073e9SAndroid Build Coastguard Worker return "foo" 3541*da0073e9SAndroid Build Coastguard Worker 3542*da0073e9SAndroid Build Coastguard Worker traced = torch.fx.symbolic_trace(Foo()) 3543*da0073e9SAndroid Build Coastguard Worker 3544*da0073e9SAndroid Build Coastguard Worker x = () 3545*da0073e9SAndroid Build Coastguard Worker y = ("bar", ()) 3546*da0073e9SAndroid Build Coastguard Worker 3547*da0073e9SAndroid Build Coastguard Worker traced(x, y) 3548*da0073e9SAndroid Build Coastguard Worker 3549*da0073e9SAndroid Build Coastguard Worker FileCheck().check("_Tuple[()]") \ 3550*da0073e9SAndroid Build Coastguard Worker .check("typing_Tuple[str,typing_Tuple[()]]") \ 3551*da0073e9SAndroid Build Coastguard Worker .run(traced.code) 3552*da0073e9SAndroid Build Coastguard Worker 3553*da0073e9SAndroid Build Coastguard Worker scripted = torch.jit.script(traced) 3554*da0073e9SAndroid Build Coastguard Worker 3555*da0073e9SAndroid Build Coastguard Worker scripted(x, y) 3556*da0073e9SAndroid Build Coastguard Worker 3557*da0073e9SAndroid Build Coastguard Worker FileCheck().check("Tuple[()]") \ 3558*da0073e9SAndroid Build Coastguard Worker .check("Tuple[str, Tuple[()]]") \ 3559*da0073e9SAndroid Build Coastguard Worker .run(scripted.code) 
@unittest.skipIf(IS_WINDOWS, "Python Windows bug? https://bugs.python.org/issue45108")
@unittest.skipIf(sys.version_info >= (3, 10), "Does not work on Python-3.10")
def test_assert(self):
    # With TracerBase.trace_asserts enabled, the `assert` in f is kept in the
    # traced module, so calling it with a failing input still raises.
    def f(x):
        assert x > 1
        return x + 1

    # trace_asserts is set on the TracerBase class itself. Restore whatever
    # value it had rather than assuming the default was False.
    prev_trace_asserts = torch.fx.proxy.TracerBase.trace_asserts
    try:
        torch.fx.proxy.TracerBase.trace_asserts = True
        traced = symbolic_trace(f)
    finally:
        torch.fx.proxy.TracerBase.trace_asserts = prev_trace_asserts

    self.assertEqual(f(2), traced(2))
    with self.assertRaises(AssertionError):
        traced(0)

def test_pytree(self):
    # symbolic_trace with pytree-structured concrete_args (lists, dicts,
    # the custom class Foo registered below, and Point, presumably a
    # module-level namedtuple) must produce one placeholder per leaf and
    # round-trip through re-tracing, a bare CodeGen, and pickling.

    # Used to test that you can use your own placeholder class
    class PHTest(PHBase):
        pass

    def f_sum(x):
        return sum(x)

    def f_sum_dict(x):
        out = 0
        for v in x.values():
            out += v
        return out

    def f_dict_list_map(x):
        new_dict = {}
        for k, v in x.items():
            new_dict[k] = [i + 1 for i in v]
        return new_dict

    def f_dict_add(x):
        return x['a'] + sum(x['z'])

    def f_namedtuple_add(x):
        return x.x + x.y

    pytree.register_pytree_node(
        Foo,
        lambda x: ([x.a, x.b], None),
        lambda x, _: Foo(x[0], x[1]),
    )
    fx_pytree.register_pytree_flatten_spec(Foo, lambda x, _: [x.a, x.b])

    def f_custom(x):
        return x.a + x.b

    def f_custom_dict(x):
        return f_sum_dict(x.a) + x.b

    # Only referenced by the disabled case in `tests` below.
    def f_return_custom(x):
        return Foo(x.b, x.a)

    tests = [
        (f_sum, [PH, PH, PH]),
        (f_sum, []),
        (f_sum, [PHTest(), PHTest(), PHTest()]),
        (f_sum_dict, {'a': PH, 'b': PH, 'c': PH}),
        (f_dict_list_map, {'a': (PH, PH), 'b': [PH], 'c': []}),
        (f_dict_list_map, {5: (PH, PH, PH)}),
        (f_dict_add, {'a': PH, 'z': (PH, PH, PH)}),
        (f_dict_add, {'a': PH, 'z': []}),
        (f_custom, Foo(PH, PH)),
        (f_custom, Foo(PH, 3)),
        (f_custom_dict, Foo({'a': PH, 'b': PH}, PH)),
        # (f_return_custom, Foo(PH, PH)),  # Don't currently support output pytrees
        (f_namedtuple_add, Point(PH, PH)),
    ]

    def verify_pytree(f, inp):
        # Replace every placeholder leaf with a random tensor to get a real input.
        val = pytree.tree_map(lambda x: torch.randn(3) if isinstance(x, PHBase) else x, inp)
        num_flat_args = len(pytree.tree_leaves(inp))
        orig_out = f(val)
        nf = symbolic_trace(f, concrete_args={'x': inp})
        self.assertEqual(nf(val), orig_out)

        # The same graph under the default CodeGen takes already-flattened inputs.
        bare_fx = GraphModule({}, copy.deepcopy(nf.graph))
        bare_fx.graph.set_codegen(CodeGen())
        bare_fx.recompile()
        self.assertEqual(nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(val))), orig_out)

        # self.assert* instead of bare `assert`: bare asserts vanish under `python -O`.
        self.assertTrue(num_flat_args == 0 or "tree_flatten_spec" in nf.code)
        self.assertEqual(sum(i.op == 'placeholder' for i in nf.graph.nodes), num_flat_args)

        # Re-tracing without concrete_args collapses the input to one placeholder.
        nf = symbolic_trace(nf)
        self.assertEqual(nf(val), orig_out)
        self.assertNotIn("tree_flatten_spec", nf.code)
        self.assertEqual(sum(i.op == 'placeholder' for i in nf.graph.nodes), 1)

        nf = symbolic_trace(nf, concrete_args={'x': inp})
        self.assertEqual(nf(val), orig_out)
        self.assertTrue(num_flat_args == 0 or "tree_flatten_spec" in nf.code)
        self.assertEqual(sum(i.op == 'placeholder' for i in nf.graph.nodes), num_flat_args)

        pickled = pickle.dumps(nf)
        nf = pickle.loads(pickled)
        self.assertEqual(nf(val), orig_out)

    for f, inp in tests:
        verify_pytree(f, inp)

def test_pytree_concrete(self):
    def f(b, a):
        if b:
            return a['a']
Coastguard Worker else: 3672*da0073e9SAndroid Build Coastguard Worker return a['z'] 3673*da0073e9SAndroid Build Coastguard Worker 3674*da0073e9SAndroid Build Coastguard Worker inp = {'a': {'a': PH, 'z': PH}, 'b': True} 3675*da0073e9SAndroid Build Coastguard Worker nf = symbolic_trace(f, concrete_args=inp) 3676*da0073e9SAndroid Build Coastguard Worker val = pytree.tree_map(lambda x: torch.randn(3) if x == PH else x, inp) 3677*da0073e9SAndroid Build Coastguard Worker self.assertEqual(nf(**val), f(**val)) 3678*da0073e9SAndroid Build Coastguard Worker 3679*da0073e9SAndroid Build Coastguard Worker nf = symbolic_trace(nf) 3680*da0073e9SAndroid Build Coastguard Worker self.assertEqual(nf(**val), f(**val)) 3681*da0073e9SAndroid Build Coastguard Worker 3682*da0073e9SAndroid Build Coastguard Worker def test_metadata_on_ph(self): 3683*da0073e9SAndroid Build Coastguard Worker def f_sum(a: int, b: int) -> int: 3684*da0073e9SAndroid Build Coastguard Worker return a + b 3685*da0073e9SAndroid Build Coastguard Worker 3686*da0073e9SAndroid Build Coastguard Worker # Due to unflattening of dict, the batch argument 3687*da0073e9SAndroid Build Coastguard Worker # will be split into two separate nodes with the names 3688*da0073e9SAndroid Build Coastguard Worker # "batch_1" and "batch_2", referring to the keys 3689*da0073e9SAndroid Build Coastguard Worker # "f1" and "f2" respectively in the dict. 
3690*da0073e9SAndroid Build Coastguard Worker def f_dict(a: Dict[str, str]) -> bool: 3691*da0073e9SAndroid Build Coastguard Worker return a["f1"] == a["f2"] 3692*da0073e9SAndroid Build Coastguard Worker 3693*da0073e9SAndroid Build Coastguard Worker def verify_metadata(gm: GraphModule, arg_names: List[str], metadata: List[str]): 3694*da0073e9SAndroid Build Coastguard Worker for node in gm.graph.nodes: 3695*da0073e9SAndroid Build Coastguard Worker if node.op == "placeholder": 3696*da0073e9SAndroid Build Coastguard Worker self.assertTrue(node.name in arg_names) 3697*da0073e9SAndroid Build Coastguard Worker self.assertTrue(node.ph_key in metadata) 3698*da0073e9SAndroid Build Coastguard Worker 3699*da0073e9SAndroid Build Coastguard Worker verify_metadata( 3700*da0073e9SAndroid Build Coastguard Worker gm=symbolic_trace( 3701*da0073e9SAndroid Build Coastguard Worker f_sum, 3702*da0073e9SAndroid Build Coastguard Worker concrete_args={"a": PHWithMeta(ph_key="a"), "b": PHWithMeta(ph_key="b")} 3703*da0073e9SAndroid Build Coastguard Worker ), 3704*da0073e9SAndroid Build Coastguard Worker arg_names=["a_1", "b_1"], 3705*da0073e9SAndroid Build Coastguard Worker metadata=["a", "b"] 3706*da0073e9SAndroid Build Coastguard Worker ) 3707*da0073e9SAndroid Build Coastguard Worker verify_metadata( 3708*da0073e9SAndroid Build Coastguard Worker gm=symbolic_trace( 3709*da0073e9SAndroid Build Coastguard Worker f_dict, 3710*da0073e9SAndroid Build Coastguard Worker concrete_args={"a": {"f1": PHWithMeta(ph_key="f1"), "f2": PHWithMeta(ph_key="f2")}} 3711*da0073e9SAndroid Build Coastguard Worker ), 3712*da0073e9SAndroid Build Coastguard Worker arg_names=["a_1", "a_2"], 3713*da0073e9SAndroid Build Coastguard Worker metadata=["f1", "f2"] 3714*da0073e9SAndroid Build Coastguard Worker ) 3715*da0073e9SAndroid Build Coastguard Worker 3716*da0073e9SAndroid Build Coastguard Worker # Ensures that tags on nodes are NOT overwritten by PH attributes with same attr name (tag) 3717*da0073e9SAndroid Build 
Coastguard Worker class TaggingTracer(Tracer): 3718*da0073e9SAndroid Build Coastguard Worker def create_node(self, kind : str, target : Union[str, Callable], 3719*da0073e9SAndroid Build Coastguard Worker args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None, 3720*da0073e9SAndroid Build Coastguard Worker type_expr : Optional[Any] = None) -> Node: 3721*da0073e9SAndroid Build Coastguard Worker n = super().create_node(kind, target, args, kwargs, name) 3722*da0073e9SAndroid Build Coastguard Worker n.tag = "foo" 3723*da0073e9SAndroid Build Coastguard Worker return n 3724*da0073e9SAndroid Build Coastguard Worker 3725*da0073e9SAndroid Build Coastguard Worker class PHWithTag(PHBase): 3726*da0073e9SAndroid Build Coastguard Worker def __init__(self, tag: str): 3727*da0073e9SAndroid Build Coastguard Worker super().__init__() 3728*da0073e9SAndroid Build Coastguard Worker 3729*da0073e9SAndroid Build Coastguard Worker self.tag = tag 3730*da0073e9SAndroid Build Coastguard Worker 3731*da0073e9SAndroid Build Coastguard Worker g = TaggingTracer().trace(f_sum, concrete_args={"a": PHWithTag(tag="bar"), "b": PHWithTag(tag="bar")}) 3732*da0073e9SAndroid Build Coastguard Worker for n in g.nodes: 3733*da0073e9SAndroid Build Coastguard Worker self.assertTrue(hasattr(n, "tag")) 3734*da0073e9SAndroid Build Coastguard Worker # Ensure that tag is still "foo" and not "bar" (from PHWithTag) 3735*da0073e9SAndroid Build Coastguard Worker self.assertEqual(n.tag, "foo") 3736*da0073e9SAndroid Build Coastguard Worker 3737*da0073e9SAndroid Build Coastguard Worker def test_custom_codegen(self): 3738*da0073e9SAndroid Build Coastguard Worker class ListCodeGen(CodeGen): 3739*da0073e9SAndroid Build Coastguard Worker def gen_fn_def(self, free_vars, maybe_return_annotation): 3740*da0073e9SAndroid Build Coastguard Worker lst_unpack = f""" 3741*da0073e9SAndroid Build Coastguard Workerdef forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}: 3742*da0073e9SAndroid Build 
Coastguard Worker {', '.join(free_vars)} = args_list""" 3743*da0073e9SAndroid Build Coastguard Worker return lst_unpack 3744*da0073e9SAndroid Build Coastguard Worker 3745*da0073e9SAndroid Build Coastguard Worker def additional_globals(self): 3746*da0073e9SAndroid Build Coastguard Worker return [('List', typing.List)] 3747*da0073e9SAndroid Build Coastguard Worker 3748*da0073e9SAndroid Build Coastguard Worker def process_inputs(self, *inputs): 3749*da0073e9SAndroid Build Coastguard Worker assert len(inputs) == 1 3750*da0073e9SAndroid Build Coastguard Worker return inputs[0] 3751*da0073e9SAndroid Build Coastguard Worker 3752*da0073e9SAndroid Build Coastguard Worker def f(a, b): 3753*da0073e9SAndroid Build Coastguard Worker return a + b 3754*da0073e9SAndroid Build Coastguard Worker 3755*da0073e9SAndroid Build Coastguard Worker nf = symbolic_trace(f) 3756*da0073e9SAndroid Build Coastguard Worker vals = [torch.randn(3), torch.randn(3)] 3757*da0073e9SAndroid Build Coastguard Worker self.assertEqual(nf(*vals), f(*vals)) 3758*da0073e9SAndroid Build Coastguard Worker 3759*da0073e9SAndroid Build Coastguard Worker nf.graph.set_codegen(ListCodeGen()) 3760*da0073e9SAndroid Build Coastguard Worker nf.recompile() 3761*da0073e9SAndroid Build Coastguard Worker 3762*da0073e9SAndroid Build Coastguard Worker bare_fx = GraphModule({}, copy.deepcopy(nf.graph)) 3763*da0073e9SAndroid Build Coastguard Worker bare_fx.graph.set_codegen(CodeGen()) 3764*da0073e9SAndroid Build Coastguard Worker bare_fx.recompile() 3765*da0073e9SAndroid Build Coastguard Worker 3766*da0073e9SAndroid Build Coastguard Worker self.assertEqual(nf(vals), f(*vals)) 3767*da0073e9SAndroid Build Coastguard Worker self.assertEqual(nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(vals))), f(*vals)) 3768*da0073e9SAndroid Build Coastguard Worker 3769*da0073e9SAndroid Build Coastguard Worker ts_f = torch.jit.script(nf) 3770*da0073e9SAndroid Build Coastguard Worker self.assertEqual(nf(vals), ts_f(vals)) 
3771*da0073e9SAndroid Build Coastguard Worker 3772*da0073e9SAndroid Build Coastguard Worker def test_custom_codegen_with_transformer(self): 3773*da0073e9SAndroid Build Coastguard Worker class ListCodeGen(CodeGen): 3774*da0073e9SAndroid Build Coastguard Worker def gen_fn_def(self, free_vars, maybe_return_annotation): 3775*da0073e9SAndroid Build Coastguard Worker lst_unpack = f""" 3776*da0073e9SAndroid Build Coastguard Workerdef forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}: 3777*da0073e9SAndroid Build Coastguard Worker {', '.join(free_vars)} = args_list""" 3778*da0073e9SAndroid Build Coastguard Worker return lst_unpack 3779*da0073e9SAndroid Build Coastguard Worker 3780*da0073e9SAndroid Build Coastguard Worker def additional_globals(self): 3781*da0073e9SAndroid Build Coastguard Worker return [('List', typing.List)] 3782*da0073e9SAndroid Build Coastguard Worker 3783*da0073e9SAndroid Build Coastguard Worker def process_inputs(self, *inputs): 3784*da0073e9SAndroid Build Coastguard Worker assert len(inputs) == 1 3785*da0073e9SAndroid Build Coastguard Worker return inputs[0] 3786*da0073e9SAndroid Build Coastguard Worker 3787*da0073e9SAndroid Build Coastguard Worker def f(a, b): 3788*da0073e9SAndroid Build Coastguard Worker return a + b 3789*da0073e9SAndroid Build Coastguard Worker 3790*da0073e9SAndroid Build Coastguard Worker nf = symbolic_trace(f) 3791*da0073e9SAndroid Build Coastguard Worker vals = [torch.randn(3), torch.randn(3)] 3792*da0073e9SAndroid Build Coastguard Worker self.assertEqual(nf(*vals), f(*vals)) 3793*da0073e9SAndroid Build Coastguard Worker 3794*da0073e9SAndroid Build Coastguard Worker nf.graph.set_codegen(ListCodeGen()) 3795*da0073e9SAndroid Build Coastguard Worker nf.recompile() 3796*da0073e9SAndroid Build Coastguard Worker self.assertEqual(nf(vals), f(*vals)) 3797*da0073e9SAndroid Build Coastguard Worker 3798*da0073e9SAndroid Build Coastguard Worker transformed_gm = Transformer(nf).transform() 3799*da0073e9SAndroid Build 
Coastguard Worker self.assertEqual(nf(vals), transformed_gm(vals)) 3800*da0073e9SAndroid Build Coastguard Worker 3801*da0073e9SAndroid Build Coastguard Worker def test_interpreter_with_codegen(self): 3802*da0073e9SAndroid Build Coastguard Worker class ListCodeGen(CodeGen): 3803*da0073e9SAndroid Build Coastguard Worker def gen_fn_def(self, free_vars, maybe_return_annotation): 3804*da0073e9SAndroid Build Coastguard Worker lst_unpack = f""" 3805*da0073e9SAndroid Build Coastguard Workerdef forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}: 3806*da0073e9SAndroid Build Coastguard Worker {', '.join(free_vars)} = args_list""" 3807*da0073e9SAndroid Build Coastguard Worker return lst_unpack 3808*da0073e9SAndroid Build Coastguard Worker 3809*da0073e9SAndroid Build Coastguard Worker def additional_globals(self): 3810*da0073e9SAndroid Build Coastguard Worker return [('List', typing.List)] 3811*da0073e9SAndroid Build Coastguard Worker 3812*da0073e9SAndroid Build Coastguard Worker def process_inputs(self, *inputs): 3813*da0073e9SAndroid Build Coastguard Worker assert len(inputs) == 1 3814*da0073e9SAndroid Build Coastguard Worker return inputs[0] 3815*da0073e9SAndroid Build Coastguard Worker 3816*da0073e9SAndroid Build Coastguard Worker def generate_output(self, output_args): 3817*da0073e9SAndroid Build Coastguard Worker return f'return list({repr(output_args)})' 3818*da0073e9SAndroid Build Coastguard Worker 3819*da0073e9SAndroid Build Coastguard Worker def process_outputs(self, outputs): 3820*da0073e9SAndroid Build Coastguard Worker return list(outputs) 3821*da0073e9SAndroid Build Coastguard Worker 3822*da0073e9SAndroid Build Coastguard Worker def f(a, b): 3823*da0073e9SAndroid Build Coastguard Worker a = a + b 3824*da0073e9SAndroid Build Coastguard Worker b = a + b 3825*da0073e9SAndroid Build Coastguard Worker return a, b 3826*da0073e9SAndroid Build Coastguard Worker 3827*da0073e9SAndroid Build Coastguard Worker nf = symbolic_trace(f) 3828*da0073e9SAndroid 
Build Coastguard Worker vals = [torch.randn(3), torch.randn(3)] 3829*da0073e9SAndroid Build Coastguard Worker nf.graph.set_codegen(ListCodeGen()) 3830*da0073e9SAndroid Build Coastguard Worker nf.recompile() 3831*da0073e9SAndroid Build Coastguard Worker self.assertEqual(Interpreter(nf).run(vals), nf(vals)) 3832*da0073e9SAndroid Build Coastguard Worker 3833*da0073e9SAndroid Build Coastguard Worker def test_imul_code_print(self): 3834*da0073e9SAndroid Build Coastguard Worker graph = torch.fx.Graph() 3835*da0073e9SAndroid Build Coastguard Worker a = graph.placeholder("a") 3836*da0073e9SAndroid Build Coastguard Worker b = graph.placeholder("b") 3837*da0073e9SAndroid Build Coastguard Worker graph.call_function(operator.imul, (a, b), {}) 3838*da0073e9SAndroid Build Coastguard Worker graph.output(a) 3839*da0073e9SAndroid Build Coastguard Worker gm = torch.fx.GraphModule({}, graph) 3840*da0073e9SAndroid Build Coastguard Worker gm.recompile() 3841*da0073e9SAndroid Build Coastguard Worker self.assertEqual(gm(2, 3), 6) 3842*da0073e9SAndroid Build Coastguard Worker self.assertIn("a *= b", gm.code) 3843*da0073e9SAndroid Build Coastguard Worker 3844*da0073e9SAndroid Build Coastguard Worker def test_deepcopy_tracer(self): 3845*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 3846*da0073e9SAndroid Build Coastguard Worker return (x + y).relu().sin() 3847*da0073e9SAndroid Build Coastguard Worker 3848*da0073e9SAndroid Build Coastguard Worker tracer = Tracer() 3849*da0073e9SAndroid Build Coastguard Worker tracer_before = copy.deepcopy(tracer) 3850*da0073e9SAndroid Build Coastguard Worker tracer.trace(fn) 3851*da0073e9SAndroid Build Coastguard Worker tracer_after = copy.deepcopy(tracer) 3852*da0073e9SAndroid Build Coastguard Worker 3853*da0073e9SAndroid Build Coastguard Worker self.assertEqual(str(tracer.graph), str(tracer_after.graph)) 3854*da0073e9SAndroid Build Coastguard Worker self.assertTrue(not hasattr(tracer_before, 'graph') or str(tracer.graph) != 
str(tracer_before.graph)) 3855*da0073e9SAndroid Build Coastguard Worker 3856*da0073e9SAndroid Build Coastguard Worker def test_deepcopy_graphmodule(self): 3857*da0073e9SAndroid Build Coastguard Worker m = symbolic_trace(SimpleTest()) 3858*da0073e9SAndroid Build Coastguard Worker m.meta['hello'] = 'world' 3859*da0073e9SAndroid Build Coastguard Worker copy_m = copy.deepcopy(m) 3860*da0073e9SAndroid Build Coastguard Worker self.assertEqual(copy_m.meta['hello'], 'world') 3861*da0073e9SAndroid Build Coastguard Worker 3862*da0073e9SAndroid Build Coastguard Worker def test_deepcopy_no_recursion(self): 3863*da0073e9SAndroid Build Coastguard Worker m = symbolic_trace(SimpleTest()) 3864*da0073e9SAndroid Build Coastguard Worker m.meta['hello'] = m # circular reference 3865*da0073e9SAndroid Build Coastguard Worker copy_m = copy.deepcopy(m) # finishes 3866*da0073e9SAndroid Build Coastguard Worker self.assertEqual(id(copy_m), id(copy_m.meta['hello'])) 3867*da0073e9SAndroid Build Coastguard Worker 3868*da0073e9SAndroid Build Coastguard Worker def test_enum(self): 3869*da0073e9SAndroid Build Coastguard Worker from enum import Enum 3870*da0073e9SAndroid Build Coastguard Worker 3871*da0073e9SAndroid Build Coastguard Worker class Foo(Enum): 3872*da0073e9SAndroid Build Coastguard Worker A = 1 3873*da0073e9SAndroid Build Coastguard Worker B = 2 3874*da0073e9SAndroid Build Coastguard Worker 3875*da0073e9SAndroid Build Coastguard Worker def leaf_fn(arr, enum_val): 3876*da0073e9SAndroid Build Coastguard Worker # Use the raw enum. 3877*da0073e9SAndroid Build Coastguard Worker arr.append(enum_val) 3878*da0073e9SAndroid Build Coastguard Worker return arr[-1].value 3879*da0073e9SAndroid Build Coastguard Worker 3880*da0073e9SAndroid Build Coastguard Worker def foo(x): 3881*da0073e9SAndroid Build Coastguard Worker # Pass the enum as argument. 
3882*da0073e9SAndroid Build Coastguard Worker return leaf_fn(x, Foo.A) 3883*da0073e9SAndroid Build Coastguard Worker 3884*da0073e9SAndroid Build Coastguard Worker traced = torch.fx.symbolic_trace(foo) 3885*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo([]), traced([])) 3886*da0073e9SAndroid Build Coastguard Worker 3887*da0073e9SAndroid Build Coastguard Worker def test_insert_arg(self): 3888*da0073e9SAndroid Build Coastguard Worker m = symbolic_trace(SimpleTest()) 3889*da0073e9SAndroid Build Coastguard Worker m.buf = torch.nn.Buffer(torch.tensor(0)) 3890*da0073e9SAndroid Build Coastguard Worker output_node = next(iter(reversed(m.graph.nodes))) 3891*da0073e9SAndroid Build Coastguard Worker with m.graph.inserting_before(output_node): 3892*da0073e9SAndroid Build Coastguard Worker a = m.graph.get_attr("buf") 3893*da0073e9SAndroid Build Coastguard Worker r = len(output_node.args) 3894*da0073e9SAndroid Build Coastguard Worker output_node.insert_arg(0, a) 3895*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(output_node.args), r + 1) 3896*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(a.users), 1) 3897*da0073e9SAndroid Build Coastguard Worker self.assertIs(output_node.args[0], a) 3898*da0073e9SAndroid Build Coastguard Worker self.assertIs(next(iter(a.users.keys())), output_node) 3899*da0073e9SAndroid Build Coastguard Worker output_node.insert_arg(2, a) 3900*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(output_node.args), r + 2) 3901*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(a.users), 1) 3902*da0073e9SAndroid Build Coastguard Worker self.assertIs(output_node.args[2], a) 3903*da0073e9SAndroid Build Coastguard Worker self.assertIs(next(iter(a.users.keys())), output_node) 3904*da0073e9SAndroid Build Coastguard Worker m.graph.lint() 3905*da0073e9SAndroid Build Coastguard Worker 3906*da0073e9SAndroid Build Coastguard Worker def test_delete_unused_values(self): 3907*da0073e9SAndroid Build Coastguard 
Worker from torch.fx.experimental.proxy_tensor import make_fx 3908*da0073e9SAndroid Build Coastguard Worker 3909*da0073e9SAndroid Build Coastguard Worker # disable mutable checking temporarily 3910*da0073e9SAndroid Build Coastguard Worker orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations 3911*da0073e9SAndroid Build Coastguard Worker torch.fx.proxy.TracerBase.check_mutable_operations = False 3912*da0073e9SAndroid Build Coastguard Worker 3913*da0073e9SAndroid Build Coastguard Worker def fn(a, b, c, d): 3914*da0073e9SAndroid Build Coastguard Worker x = a + b 3915*da0073e9SAndroid Build Coastguard Worker y = c + d 3916*da0073e9SAndroid Build Coastguard Worker y.copy_(x) 3917*da0073e9SAndroid Build Coastguard Worker x = torch.relu(x) 3918*da0073e9SAndroid Build Coastguard Worker return x 3919*da0073e9SAndroid Build Coastguard Worker 3920*da0073e9SAndroid Build Coastguard Worker a, b, c, d = (torch.randn(2, 4, requires_grad=False) for _ in range(4)) 3921*da0073e9SAndroid Build Coastguard Worker fx_fn = make_fx(fn)(a, b, c, d) 3922*da0073e9SAndroid Build Coastguard Worker print(fx_fn) 3923*da0073e9SAndroid Build Coastguard Worker 3924*da0073e9SAndroid Build Coastguard Worker fx_fn.graph.eliminate_dead_code() 3925*da0073e9SAndroid Build Coastguard Worker py_code = fx_fn.recompile() 3926*da0073e9SAndroid Build Coastguard Worker self.assertTrue("copy_ = torch.ops.aten.copy_.default" in py_code.src) 3927*da0073e9SAndroid Build Coastguard Worker self.assertTrue("copy_ = None" in py_code.src) 3928*da0073e9SAndroid Build Coastguard Worker 3929*da0073e9SAndroid Build Coastguard Worker # recorver mutable checking flag 3930*da0073e9SAndroid Build Coastguard Worker torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag 3931*da0073e9SAndroid Build Coastguard Worker 3932*da0073e9SAndroid Build Coastguard Workerdef run_getitem_target(): 3933*da0073e9SAndroid Build Coastguard Worker from torch.fx._symbolic_trace import 
_wrapped_methods_to_patch 3934*da0073e9SAndroid Build Coastguard Worker _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__")) 3935*da0073e9SAndroid Build Coastguard Worker try: 3936*da0073e9SAndroid Build Coastguard Worker TestFX().getitem_inner() 3937*da0073e9SAndroid Build Coastguard Worker finally: 3938*da0073e9SAndroid Build Coastguard Worker _wrapped_methods_to_patch.pop() 3939*da0073e9SAndroid Build Coastguard Worker 3940*da0073e9SAndroid Build Coastguard Worker 3941*da0073e9SAndroid Build Coastguard Workerclass TestOperatorSignatures(JitTestCase): 3942*da0073e9SAndroid Build Coastguard Worker def setUp(self): 3943*da0073e9SAndroid Build Coastguard Worker # Checking for mutable operations whil tracing is feature flagged 3944*da0073e9SAndroid Build Coastguard Worker # Enable it in testing but not by default 3945*da0073e9SAndroid Build Coastguard Worker self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations 3946*da0073e9SAndroid Build Coastguard Worker torch.fx.proxy.TracerBase.check_mutable_operations = True 3947*da0073e9SAndroid Build Coastguard Worker 3948*da0073e9SAndroid Build Coastguard Worker def tearDown(self): 3949*da0073e9SAndroid Build Coastguard Worker torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag 3950*da0073e9SAndroid Build Coastguard Worker 3951*da0073e9SAndroid Build Coastguard Worker @onlyCPU 3952*da0073e9SAndroid Build Coastguard Worker @ops(op_db, allowed_dtypes=(torch.float,)) 3953*da0073e9SAndroid Build Coastguard Worker def test_get_torch_func_signature_exhaustive(self, device, dtype, op): 3954*da0073e9SAndroid Build Coastguard Worker if not isinstance(op.op, types.BuiltinFunctionType): 3955*da0073e9SAndroid Build Coastguard Worker raise unittest.SkipTest("This path doesn't work on Python functions") 3956*da0073e9SAndroid Build Coastguard Worker sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False) 3957*da0073e9SAndroid Build Coastguard Worker 
schemas = get_signature_for_torch_op(op.op) 3958*da0073e9SAndroid Build Coastguard Worker if not schemas: 3959*da0073e9SAndroid Build Coastguard Worker raise RuntimeError('No Schemas Returned') 3960*da0073e9SAndroid Build Coastguard Worker for sample_input in sample_inputs_itr: 3961*da0073e9SAndroid Build Coastguard Worker # Iterate through overloads until we hit a match. If we exit this 3962*da0073e9SAndroid Build Coastguard Worker # loop via `else`, we haven't found a match 3963*da0073e9SAndroid Build Coastguard Worker for schema in schemas: 3964*da0073e9SAndroid Build Coastguard Worker try: 3965*da0073e9SAndroid Build Coastguard Worker bound_args = schema.bind(sample_input.input, *sample_input.args, **sample_input.kwargs) 3966*da0073e9SAndroid Build Coastguard Worker bound_args.apply_defaults() 3967*da0073e9SAndroid Build Coastguard Worker op(*bound_args.args, **bound_args.kwargs) 3968*da0073e9SAndroid Build Coastguard Worker break 3969*da0073e9SAndroid Build Coastguard Worker except TypeError as e: 3970*da0073e9SAndroid Build Coastguard Worker pass 3971*da0073e9SAndroid Build Coastguard Worker else: 3972*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f'Did not match any schemas for op {op.name}!') 3973*da0073e9SAndroid Build Coastguard Worker 3974*da0073e9SAndroid Build Coastguard Worker 3975*da0073e9SAndroid Build Coastguard Workerclass TestFXAPIBackwardCompatibility(JitTestCase): 3976*da0073e9SAndroid Build Coastguard Worker def setUp(self): 3977*da0073e9SAndroid Build Coastguard Worker super().setUp() 3978*da0073e9SAndroid Build Coastguard Worker self.maxDiff = None 3979*da0073e9SAndroid Build Coastguard Worker 3980*da0073e9SAndroid Build Coastguard Worker # Checking for mutable operations whil tracing is feature flagged 3981*da0073e9SAndroid Build Coastguard Worker # Enable it in testing but not by default 3982*da0073e9SAndroid Build Coastguard Worker self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations 
3983*da0073e9SAndroid Build Coastguard Worker torch.fx.proxy.TracerBase.check_mutable_operations = True 3984*da0073e9SAndroid Build Coastguard Worker 3985*da0073e9SAndroid Build Coastguard Worker def tearDown(self): 3986*da0073e9SAndroid Build Coastguard Worker super().tearDown() 3987*da0073e9SAndroid Build Coastguard Worker torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag 3988*da0073e9SAndroid Build Coastguard Worker 3989*da0073e9SAndroid Build Coastguard Worker 3990*da0073e9SAndroid Build Coastguard Worker def _fn_to_stable_annotation_str(self, obj): 3991*da0073e9SAndroid Build Coastguard Worker """ 3992*da0073e9SAndroid Build Coastguard Worker Unfortunately we have to serialize function signatures manually since 3993*da0073e9SAndroid Build Coastguard Worker serialization for `inspect.Signature` objects is not stable across 3994*da0073e9SAndroid Build Coastguard Worker python versions 3995*da0073e9SAndroid Build Coastguard Worker """ 3996*da0073e9SAndroid Build Coastguard Worker fn_name = torch.typename(obj) 3997*da0073e9SAndroid Build Coastguard Worker 3998*da0073e9SAndroid Build Coastguard Worker signature = inspect.signature(obj) 3999*da0073e9SAndroid Build Coastguard Worker 4000*da0073e9SAndroid Build Coastguard Worker sig_str = f'{fn_name}{signature}' 4001*da0073e9SAndroid Build Coastguard Worker 4002*da0073e9SAndroid Build Coastguard Worker arg_strs = [] 4003*da0073e9SAndroid Build Coastguard Worker for k, v in signature.parameters.items(): 4004*da0073e9SAndroid Build Coastguard Worker maybe_type_annotation = f': {self._annotation_type_to_stable_str(v.annotation, sig_str)}'\ 4005*da0073e9SAndroid Build Coastguard Worker if v.annotation is not inspect.Signature.empty else '' 4006*da0073e9SAndroid Build Coastguard Worker 4007*da0073e9SAndroid Build Coastguard Worker def default_val_str(val): 4008*da0073e9SAndroid Build Coastguard Worker if isinstance(val, (tuple, list)): 4009*da0073e9SAndroid Build Coastguard Worker 
str_pieces = ['(' if isinstance(val, tuple) else '['] 4010*da0073e9SAndroid Build Coastguard Worker str_pieces.append(', '.join(default_val_str(v) for v in val)) 4011*da0073e9SAndroid Build Coastguard Worker if isinstance(val, tuple) and len(str_pieces) == 2: 4012*da0073e9SAndroid Build Coastguard Worker str_pieces.append(',') 4013*da0073e9SAndroid Build Coastguard Worker str_pieces.append(')' if isinstance(val, tuple) else ']') 4014*da0073e9SAndroid Build Coastguard Worker return ''.join(str_pieces) 4015*da0073e9SAndroid Build Coastguard Worker 4016*da0073e9SAndroid Build Coastguard Worker # Need to fix up some default value strings. 4017*da0073e9SAndroid Build Coastguard Worker # First case: modules. Default module `repr` contains the FS path of the module. 4018*da0073e9SAndroid Build Coastguard Worker # Don't leak that 4019*da0073e9SAndroid Build Coastguard Worker if isinstance(val, types.ModuleType): 4020*da0073e9SAndroid Build Coastguard Worker return f'<module {val.__name__}>' 4021*da0073e9SAndroid Build Coastguard Worker 4022*da0073e9SAndroid Build Coastguard Worker # Second case: callables. Callables (such as lambdas) encode their address in 4023*da0073e9SAndroid Build Coastguard Worker # their string repr. 
Don't do that 4024*da0073e9SAndroid Build Coastguard Worker if callable(val): 4025*da0073e9SAndroid Build Coastguard Worker return f'<function {val.__name__}>' 4026*da0073e9SAndroid Build Coastguard Worker 4027*da0073e9SAndroid Build Coastguard Worker return str(val) 4028*da0073e9SAndroid Build Coastguard Worker 4029*da0073e9SAndroid Build Coastguard Worker if v.default is not inspect.Signature.empty: 4030*da0073e9SAndroid Build Coastguard Worker default_val_str = default_val_str(v.default) if not isinstance(v.default, str) else f"'{v.default}'" 4031*da0073e9SAndroid Build Coastguard Worker maybe_default = f' = {default_val_str}' 4032*da0073e9SAndroid Build Coastguard Worker else: 4033*da0073e9SAndroid Build Coastguard Worker maybe_default = '' 4034*da0073e9SAndroid Build Coastguard Worker maybe_stars = '' 4035*da0073e9SAndroid Build Coastguard Worker if v.kind == inspect.Parameter.VAR_POSITIONAL: 4036*da0073e9SAndroid Build Coastguard Worker maybe_stars = '*' 4037*da0073e9SAndroid Build Coastguard Worker elif v.kind == inspect.Parameter.VAR_KEYWORD: 4038*da0073e9SAndroid Build Coastguard Worker maybe_stars = '**' 4039*da0073e9SAndroid Build Coastguard Worker arg_strs.append(f'{maybe_stars}{k}{maybe_type_annotation}{maybe_default}') 4040*da0073e9SAndroid Build Coastguard Worker 4041*da0073e9SAndroid Build Coastguard Worker return_annot = f' -> {self._annotation_type_to_stable_str(signature.return_annotation, sig_str)}'\ 4042*da0073e9SAndroid Build Coastguard Worker if signature.return_annotation is not inspect.Signature.empty else '' 4043*da0073e9SAndroid Build Coastguard Worker 4044*da0073e9SAndroid Build Coastguard Worker return f'{fn_name}({", ".join(arg_strs)}){return_annot}' 4045*da0073e9SAndroid Build Coastguard Worker 4046*da0073e9SAndroid Build Coastguard Worker def _annotation_type_to_stable_str(self, t, sig_str): 4047*da0073e9SAndroid Build Coastguard Worker if t is inspect.Signature.empty: 4048*da0073e9SAndroid Build Coastguard Worker return '' 
4049*da0073e9SAndroid Build Coastguard Worker 4050*da0073e9SAndroid Build Coastguard Worker # Forward ref 4051*da0073e9SAndroid Build Coastguard Worker if isinstance(t, str): 4052*da0073e9SAndroid Build Coastguard Worker return f"'{t}'" 4053*da0073e9SAndroid Build Coastguard Worker if hasattr(typing, 'ForwardRef') and isinstance(t, typing.ForwardRef): 4054*da0073e9SAndroid Build Coastguard Worker return t.__forward_arg__ 4055*da0073e9SAndroid Build Coastguard Worker if hasattr(typing, '_ForwardRef') and isinstance(t, typing._ForwardRef): 4056*da0073e9SAndroid Build Coastguard Worker return t.__forward_arg__ 4057*da0073e9SAndroid Build Coastguard Worker 4058*da0073e9SAndroid Build Coastguard Worker trivial_mappings = { 4059*da0073e9SAndroid Build Coastguard Worker str : 'str', 4060*da0073e9SAndroid Build Coastguard Worker int : 'int', 4061*da0073e9SAndroid Build Coastguard Worker float: 'float', 4062*da0073e9SAndroid Build Coastguard Worker bool: 'bool', 4063*da0073e9SAndroid Build Coastguard Worker torch.dtype: 'torch.dtype', 4064*da0073e9SAndroid Build Coastguard Worker torch.Tensor: 'torch.Tensor', 4065*da0073e9SAndroid Build Coastguard Worker torch.device: 'torch.device', 4066*da0073e9SAndroid Build Coastguard Worker torch.memory_format: 'torch.memory_format', 4067*da0073e9SAndroid Build Coastguard Worker slice: 'slice', 4068*da0073e9SAndroid Build Coastguard Worker torch.nn.Module: 'torch.nn.modules.module.Module', 4069*da0073e9SAndroid Build Coastguard Worker torch.fx.Graph : 'torch.fx.graph.Graph', 4070*da0073e9SAndroid Build Coastguard Worker torch.fx.Node : 'torch.fx.node.Node', 4071*da0073e9SAndroid Build Coastguard Worker torch.fx.Proxy : 'torch.fx.proxy.Proxy', 4072*da0073e9SAndroid Build Coastguard Worker torch.fx.node.Target : 'torch.fx.node.Target', 4073*da0073e9SAndroid Build Coastguard Worker torch.fx.node.Argument : 'torch.fx.node.Argument', 4074*da0073e9SAndroid Build Coastguard Worker torch.fx.graph.PythonCode : 'torch.fx.graph.PythonCode', 
4075*da0073e9SAndroid Build Coastguard Worker torch.fx.graph_module.GraphModule: 'torch.fx.graph_module.GraphModule', 4076*da0073e9SAndroid Build Coastguard Worker torch.fx.subgraph_rewriter.Match: 'torch.fx.subgraph_rewriter.Match', 4077*da0073e9SAndroid Build Coastguard Worker Ellipsis : '...', 4078*da0073e9SAndroid Build Coastguard Worker typing.Any: 'Any', 4079*da0073e9SAndroid Build Coastguard Worker type(None): 'NoneType', 4080*da0073e9SAndroid Build Coastguard Worker None: 'None', 4081*da0073e9SAndroid Build Coastguard Worker typing.Iterator: 'Iterator', 4082*da0073e9SAndroid Build Coastguard Worker } 4083*da0073e9SAndroid Build Coastguard Worker 4084*da0073e9SAndroid Build Coastguard Worker mapping = trivial_mappings.get(t, None) 4085*da0073e9SAndroid Build Coastguard Worker if mapping: 4086*da0073e9SAndroid Build Coastguard Worker return mapping 4087*da0073e9SAndroid Build Coastguard Worker 4088*da0073e9SAndroid Build Coastguard Worker # Handle types with contained types 4089*da0073e9SAndroid Build Coastguard Worker contained = getattr(t, '__args__', None) or [] 4090*da0073e9SAndroid Build Coastguard Worker 4091*da0073e9SAndroid Build Coastguard Worker # Callables contain a bare List for arguments 4092*da0073e9SAndroid Build Coastguard Worker contained = t if isinstance(t, list) else contained 4093*da0073e9SAndroid Build Coastguard Worker 4094*da0073e9SAndroid Build Coastguard Worker # Python 3.8 puts type vars into __args__ for unbound types such as Dict 4095*da0073e9SAndroid Build Coastguard Worker if all(isinstance(ct, typing.TypeVar) for ct in contained): 4096*da0073e9SAndroid Build Coastguard Worker contained = [] 4097*da0073e9SAndroid Build Coastguard Worker 4098*da0073e9SAndroid Build Coastguard Worker contained_type_annots = [self._annotation_type_to_stable_str(ct, sig_str) for ct in contained] 4099*da0073e9SAndroid Build Coastguard Worker contained_type_str = f'[{", ".join(contained_type_annots)}]' if len(contained_type_annots) > 0 else '' 
4100*da0073e9SAndroid Build Coastguard Worker 4101*da0073e9SAndroid Build Coastguard Worker 4102*da0073e9SAndroid Build Coastguard Worker origin = getattr(t, '__origin__', None) 4103*da0073e9SAndroid Build Coastguard Worker if origin is None: 4104*da0073e9SAndroid Build Coastguard Worker # Unbound types don't have `__origin__` in some Python versions, so fix that up here. 4105*da0073e9SAndroid Build Coastguard Worker origin = t if t in {typing.Tuple, typing.Union, typing.Dict, typing.List, typing.Type, typing.Callable} else origin 4106*da0073e9SAndroid Build Coastguard Worker 4107*da0073e9SAndroid Build Coastguard Worker if origin in {tuple, typing.Tuple}: 4108*da0073e9SAndroid Build Coastguard Worker return f'Tuple{contained_type_str}' 4109*da0073e9SAndroid Build Coastguard Worker if origin in {typing.Union}: 4110*da0073e9SAndroid Build Coastguard Worker # Annoying hack to detect Optional 4111*da0073e9SAndroid Build Coastguard Worker if len(contained) == 2 and (contained[0] is type(None)) ^ (contained[1] is type(None)): 4112*da0073e9SAndroid Build Coastguard Worker not_none_param = contained[0] if contained[0] is not type(None) else contained[1] 4113*da0073e9SAndroid Build Coastguard Worker return f'Optional[{self._annotation_type_to_stable_str(not_none_param, sig_str)}]' 4114*da0073e9SAndroid Build Coastguard Worker return f'Union{contained_type_str}' 4115*da0073e9SAndroid Build Coastguard Worker if origin in {dict, typing.Dict}: 4116*da0073e9SAndroid Build Coastguard Worker return f'Dict{contained_type_str}' 4117*da0073e9SAndroid Build Coastguard Worker if origin in {list, typing.List}: 4118*da0073e9SAndroid Build Coastguard Worker return f'List{contained_type_str}' 4119*da0073e9SAndroid Build Coastguard Worker if origin in {type, typing.Type}: 4120*da0073e9SAndroid Build Coastguard Worker return f'Type{contained_type_str}' 4121*da0073e9SAndroid Build Coastguard Worker if isinstance(t, typing.Callable): 4122*da0073e9SAndroid Build Coastguard Worker if 
len(contained) > 0 and contained[0] is not Ellipsis: 4123*da0073e9SAndroid Build Coastguard Worker return f'Callable[[{", ".join(contained_type_annots[:-1])}], {contained_type_annots[-1]}]' 4124*da0073e9SAndroid Build Coastguard Worker else: 4125*da0073e9SAndroid Build Coastguard Worker return f'Callable{contained_type_str}' 4126*da0073e9SAndroid Build Coastguard Worker 4127*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f'Unrecognized type {t} used in BC-compatible type signature {sig_str}.' 4128*da0073e9SAndroid Build Coastguard Worker f'Please add support for this type and confirm with the ' 4129*da0073e9SAndroid Build Coastguard Worker f'FX team that your signature change is valid.') 4130*da0073e9SAndroid Build Coastguard Worker 4131*da0073e9SAndroid Build Coastguard Worker 4132*da0073e9SAndroid Build Coastguard Worker def test_function_back_compat(self): 4133*da0073e9SAndroid Build Coastguard Worker """ 4134*da0073e9SAndroid Build Coastguard Worker Test backward compatibility for function signatures with 4135*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True). Currently this checks for 4136*da0073e9SAndroid Build Coastguard Worker exact signature matches, which may lead to false positives. If this 4137*da0073e9SAndroid Build Coastguard Worker becomes too annoying, we can refine this check to actually parse out 4138*da0073e9SAndroid Build Coastguard Worker the saved schema strings and check if the change is truly backward- 4139*da0073e9SAndroid Build Coastguard Worker incompatible. 
4140*da0073e9SAndroid Build Coastguard Worker """ 4141*da0073e9SAndroid Build Coastguard Worker signature_strs = [] 4142*da0073e9SAndroid Build Coastguard Worker 4143*da0073e9SAndroid Build Coastguard Worker for obj in _BACK_COMPAT_OBJECTS: 4144*da0073e9SAndroid Build Coastguard Worker if not isinstance(obj, type): 4145*da0073e9SAndroid Build Coastguard Worker signature_strs.append(self._fn_to_stable_annotation_str(obj)) 4146*da0073e9SAndroid Build Coastguard Worker 4147*da0073e9SAndroid Build Coastguard Worker signature_strs.sort() 4148*da0073e9SAndroid Build Coastguard Worker 4149*da0073e9SAndroid Build Coastguard Worker try: 4150*da0073e9SAndroid Build Coastguard Worker self.assertExpected('\n'.join(signature_strs) + '\n', 'fx_backcompat_function_signatures') 4151*da0073e9SAndroid Build Coastguard Worker except AssertionError as e: 4152*da0073e9SAndroid Build Coastguard Worker msg = f"{e}\n****** ERROR ******\nAn FX function that has been marked " \ 4153*da0073e9SAndroid Build Coastguard Worker f"as backwards-compatible has experienced a signature change. See the " \ 4154*da0073e9SAndroid Build Coastguard Worker f"above exception context for more information. If this change was " \ 4155*da0073e9SAndroid Build Coastguard Worker f"unintended, please revert it. If it was intended, check with the FX " \ 4156*da0073e9SAndroid Build Coastguard Worker f"team to ensure that the proper deprecation protocols have been followed " \ 4157*da0073e9SAndroid Build Coastguard Worker f"and subsequently --accept the change." 4158*da0073e9SAndroid Build Coastguard Worker raise AssertionError(msg) # noqa: B904 4159*da0073e9SAndroid Build Coastguard Worker 4160*da0073e9SAndroid Build Coastguard Worker def test_class_member_back_compat(self): 4161*da0073e9SAndroid Build Coastguard Worker """ 4162*da0073e9SAndroid Build Coastguard Worker Test backward compatibility for members of classes with 4163*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True). 
Currently this checks for 4164*da0073e9SAndroid Build Coastguard Worker exact matches on the publicly visible members of the class. 4165*da0073e9SAndroid Build Coastguard Worker """ 4166*da0073e9SAndroid Build Coastguard Worker class_method_strs = [] 4167*da0073e9SAndroid Build Coastguard Worker 4168*da0073e9SAndroid Build Coastguard Worker for obj in _BACK_COMPAT_OBJECTS: 4169*da0073e9SAndroid Build Coastguard Worker if isinstance(obj, type): 4170*da0073e9SAndroid Build Coastguard Worker public_members = [name for name in obj.__dict__ if not name.startswith('_')] 4171*da0073e9SAndroid Build Coastguard Worker class_method_strs.append(f'{torch.typename(obj)} {sorted(public_members)}') 4172*da0073e9SAndroid Build Coastguard Worker 4173*da0073e9SAndroid Build Coastguard Worker class_method_strs.sort() 4174*da0073e9SAndroid Build Coastguard Worker 4175*da0073e9SAndroid Build Coastguard Worker try: 4176*da0073e9SAndroid Build Coastguard Worker self.assertExpected('\n'.join(class_method_strs), 'fx_backcompat_class_members') 4177*da0073e9SAndroid Build Coastguard Worker except AssertionError as e: 4178*da0073e9SAndroid Build Coastguard Worker msg = f"{e}\n****** ERROR ******\nAn FX class that has been marked " \ 4179*da0073e9SAndroid Build Coastguard Worker f"as backwards-compatible has experienced change in its public members. See the " \ 4180*da0073e9SAndroid Build Coastguard Worker f"above exception context for more information. If this change was " \ 4181*da0073e9SAndroid Build Coastguard Worker f"unintended, please revert it. If it was intended, check with the FX " \ 4182*da0073e9SAndroid Build Coastguard Worker f"team to ensure that the proper deprecation protocols have been followed " \ 4183*da0073e9SAndroid Build Coastguard Worker f"and subsequently --accept the change." 
4184*da0073e9SAndroid Build Coastguard Worker raise AssertionError(msg) from e 4185*da0073e9SAndroid Build Coastguard Worker 4186*da0073e9SAndroid Build Coastguard Worker def test_public_api_surface(self): 4187*da0073e9SAndroid Build Coastguard Worker non_back_compat_objects = {} 4188*da0073e9SAndroid Build Coastguard Worker 4189*da0073e9SAndroid Build Coastguard Worker def check_symbols_have_bc_designation(m, seen): 4190*da0073e9SAndroid Build Coastguard Worker if not m.__name__.startswith('torch.fx'): 4191*da0073e9SAndroid Build Coastguard Worker return 4192*da0073e9SAndroid Build Coastguard Worker if m.__name__.startswith('torch.fx.experimental'): 4193*da0073e9SAndroid Build Coastguard Worker return 4194*da0073e9SAndroid Build Coastguard Worker # It's really common for inner functions to point to random modules 4195*da0073e9SAndroid Build Coastguard Worker # - make sure we don't recurse into modules we've already checked. 4196*da0073e9SAndroid Build Coastguard Worker seen.add(m.__name__) 4197*da0073e9SAndroid Build Coastguard Worker for k, v in m.__dict__.items(): 4198*da0073e9SAndroid Build Coastguard Worker if hasattr(v, '__name__') and v.__name__ in seen: 4199*da0073e9SAndroid Build Coastguard Worker continue 4200*da0073e9SAndroid Build Coastguard Worker if v is m: 4201*da0073e9SAndroid Build Coastguard Worker continue 4202*da0073e9SAndroid Build Coastguard Worker if k.startswith('_'): 4203*da0073e9SAndroid Build Coastguard Worker continue 4204*da0073e9SAndroid Build Coastguard Worker if isinstance(v, types.ModuleType): 4205*da0073e9SAndroid Build Coastguard Worker check_symbols_have_bc_designation(v, seen) 4206*da0073e9SAndroid Build Coastguard Worker elif isinstance(v, (type, types.FunctionType)): 4207*da0073e9SAndroid Build Coastguard Worker if v not in _MARKED_WITH_COMPATIBILITY: 4208*da0073e9SAndroid Build Coastguard Worker non_back_compat_objects.setdefault(v) 4209*da0073e9SAndroid Build Coastguard Worker 4210*da0073e9SAndroid Build Coastguard Worker 
check_symbols_have_bc_designation(torch.fx, set()) 4211*da0073e9SAndroid Build Coastguard Worker check_symbols_have_bc_designation(torch.fx.passes, set()) 4212*da0073e9SAndroid Build Coastguard Worker 4213*da0073e9SAndroid Build Coastguard Worker non_back_compat_strs = [torch.typename(obj) for obj in non_back_compat_objects.keys()] 4214*da0073e9SAndroid Build Coastguard Worker # Only want objects in torch.fx 4215*da0073e9SAndroid Build Coastguard Worker non_back_compat_strs = [ 4216*da0073e9SAndroid Build Coastguard Worker s for s in non_back_compat_strs if s.startswith('torch.fx') and not s.startswith('torch.fx.experimental')] 4217*da0073e9SAndroid Build Coastguard Worker # Only want objects in public namespaces 4218*da0073e9SAndroid Build Coastguard Worker non_back_compat_strs = [ 4219*da0073e9SAndroid Build Coastguard Worker s for s in non_back_compat_strs if all(not atom.startswith('_') for atom in s.split('.'))] 4220*da0073e9SAndroid Build Coastguard Worker non_back_compat_strs.sort() 4221*da0073e9SAndroid Build Coastguard Worker 4222*da0073e9SAndroid Build Coastguard Worker if len(non_back_compat_strs) != 0: 4223*da0073e9SAndroid Build Coastguard Worker raise AssertionError(f"Public FX API(s) {non_back_compat_strs} introduced but not given a " 4224*da0073e9SAndroid Build Coastguard Worker f"backwards-compatibility classification! 
Please decorate these " 4225*da0073e9SAndroid Build Coastguard Worker f"API(s) with `@torch.fx._compatibility.compatibility` to specify " 4226*da0073e9SAndroid Build Coastguard Worker f"BC guarantees.") 4227*da0073e9SAndroid Build Coastguard Worker 4228*da0073e9SAndroid Build Coastguard Worker def test_adding_side_effect_function(self): 4229*da0073e9SAndroid Build Coastguard Worker class TestModule(torch.nn.Module): 4230*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 4231*da0073e9SAndroid Build Coastguard Worker side_effect_func(x) 4232*da0073e9SAndroid Build Coastguard Worker return x 4233*da0073e9SAndroid Build Coastguard Worker 4234*da0073e9SAndroid Build Coastguard Worker gm = torch.fx.symbolic_trace(TestModule()) 4235*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(gm.graph.nodes), 3) 4236*da0073e9SAndroid Build Coastguard Worker gm.graph.eliminate_dead_code() 4237*da0073e9SAndroid Build Coastguard Worker gm.recompile() 4238*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(gm.graph.nodes), 3) 4239*da0073e9SAndroid Build Coastguard Worker found = False 4240*da0073e9SAndroid Build Coastguard Worker for node in gm.graph.nodes: 4241*da0073e9SAndroid Build Coastguard Worker if node.op == 'call_function' and node.target == side_effect_func: 4242*da0073e9SAndroid Build Coastguard Worker found = True 4243*da0073e9SAndroid Build Coastguard Worker self.assertTrue(found) 4244*da0073e9SAndroid Build Coastguard Worker 4245*da0073e9SAndroid Build Coastguard Worker def test_preserve_unused_attr_after_unpickle(self): 4246*da0073e9SAndroid Build Coastguard Worker gm = torch.fx.symbolic_trace(Add()) 4247*da0073e9SAndroid Build Coastguard Worker gm.add_submodule("foo", Add()) 4248*da0073e9SAndroid Build Coastguard Worker gm.dummy_buffer = torch.nn.Buffer(torch.empty(1)) 4249*da0073e9SAndroid Build Coastguard Worker gm.register_parameter("dummy_parameter", torch.nn.Parameter(torch.empty(1))) 4250*da0073e9SAndroid Build Coastguard Worker 
b = io.BytesIO() 4251*da0073e9SAndroid Build Coastguard Worker torch.save(gm, b) 4252*da0073e9SAndroid Build Coastguard Worker b.seek(0) 4253*da0073e9SAndroid Build Coastguard Worker # weights_only=False as this loads a GraphModule 4254*da0073e9SAndroid Build Coastguard Worker # GLOBAL torch.fx.graph_module.reduce_graph_module was not an allowed global by default 4255*da0073e9SAndroid Build Coastguard Worker reload_gm = torch.load(b, weights_only=False) 4256*da0073e9SAndroid Build Coastguard Worker self.assertTrue(hasattr(reload_gm, "foo")) 4257*da0073e9SAndroid Build Coastguard Worker self.assertTrue(hasattr(reload_gm, "dummy_buffer")) 4258*da0073e9SAndroid Build Coastguard Worker self.assertTrue(hasattr(reload_gm, "dummy_parameter")) 4259*da0073e9SAndroid Build Coastguard Worker 4260*da0073e9SAndroid Build Coastguard Worker# This is failing on Python 3.12 : https://github.com/pytorch/pytorch/issues/119454 4261*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf( 4262*da0073e9SAndroid Build Coastguard Worker sys.version_info >= (3, 12), "Failing on python 3.12+" 4263*da0073e9SAndroid Build Coastguard Worker) 4264*da0073e9SAndroid Build Coastguard Workerclass TestFunctionalTracing(JitTestCase): 4265*da0073e9SAndroid Build Coastguard Worker def setUp(self): 4266*da0073e9SAndroid Build Coastguard Worker super().setUp() 4267*da0073e9SAndroid Build Coastguard Worker # Checking for mutable operations whil tracing is feature flagged 4268*da0073e9SAndroid Build Coastguard Worker # Enable it in testing but not by default 4269*da0073e9SAndroid Build Coastguard Worker self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations 4270*da0073e9SAndroid Build Coastguard Worker torch.fx.proxy.TracerBase.check_mutable_operations = True 4271*da0073e9SAndroid Build Coastguard Worker 4272*da0073e9SAndroid Build Coastguard Worker def tearDown(self): 4273*da0073e9SAndroid Build Coastguard Worker super().tearDown() 4274*da0073e9SAndroid Build Coastguard Worker 
torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag 4275*da0073e9SAndroid Build Coastguard Worker 4276*da0073e9SAndroid Build Coastguard Worker IGNORE_FUNCS = ("has_torch_function", "has_torch_function_unary", 4277*da0073e9SAndroid Build Coastguard Worker "has_torch_function_variadic", "handle_torch_function", 4278*da0073e9SAndroid Build Coastguard Worker "boolean_dispatch") 4279*da0073e9SAndroid Build Coastguard Worker TO_PATCH = {"has_torch_function": None, 4280*da0073e9SAndroid Build Coastguard Worker "has_torch_function_unary": None, 4281*da0073e9SAndroid Build Coastguard Worker "has_torch_function_variadic": None} 4282*da0073e9SAndroid Build Coastguard Worker 4283*da0073e9SAndroid Build Coastguard Worker BUILT_IN_FUNC = (AssertionError, "") 4284*da0073e9SAndroid Build Coastguard Worker PROXY_ITERABLE = (TypeError, r"argument of type 'Proxy' is not iterable") 4285*da0073e9SAndroid Build Coastguard Worker PROXY_ITERATED = (TraceError, r"Proxy object cannot be iterated") 4286*da0073e9SAndroid Build Coastguard Worker LEN_ERROR = (RuntimeError, r"'len' is not supported in symbolic tracing by default") 4287*da0073e9SAndroid Build Coastguard Worker ARG_TYPE_MISMATCH = (TypeError, r", not Proxy$") 4288*da0073e9SAndroid Build Coastguard Worker CONTROL_FLOW = (TraceError, r"symbolically traced variables cannot be used as inputs to control flow") 4289*da0073e9SAndroid Build Coastguard Worker INTERPOLATE_ARGS_CONFLICT = (ValueError, r"only one of size or scale_factor should be defined") 4290*da0073e9SAndroid Build Coastguard Worker MUTABLE = (RuntimeError, r"Tried to trace mutable operation") 4291*da0073e9SAndroid Build Coastguard Worker 4292*da0073e9SAndroid Build Coastguard Worker UNTRACEABLE_FUNCTIONALS = { 4293*da0073e9SAndroid Build Coastguard Worker "adaptive_avg_pool1d": BUILT_IN_FUNC, 4294*da0073e9SAndroid Build Coastguard Worker "avg_pool1d": BUILT_IN_FUNC, 4295*da0073e9SAndroid Build Coastguard Worker "avg_pool2d": BUILT_IN_FUNC, 
4296*da0073e9SAndroid Build Coastguard Worker "avg_pool3d": BUILT_IN_FUNC, 4297*da0073e9SAndroid Build Coastguard Worker "bilinear": BUILT_IN_FUNC, 4298*da0073e9SAndroid Build Coastguard Worker "celu_": BUILT_IN_FUNC, 4299*da0073e9SAndroid Build Coastguard Worker "channel_shuffle": BUILT_IN_FUNC, 4300*da0073e9SAndroid Build Coastguard Worker "native_channel_shuffle": BUILT_IN_FUNC, 4301*da0073e9SAndroid Build Coastguard Worker "conv1d": BUILT_IN_FUNC, 4302*da0073e9SAndroid Build Coastguard Worker "conv2d": BUILT_IN_FUNC, 4303*da0073e9SAndroid Build Coastguard Worker "conv3d": BUILT_IN_FUNC, 4304*da0073e9SAndroid Build Coastguard Worker "conv_tbc": BUILT_IN_FUNC, 4305*da0073e9SAndroid Build Coastguard Worker "conv_transpose1d": BUILT_IN_FUNC, 4306*da0073e9SAndroid Build Coastguard Worker "conv_transpose2d": BUILT_IN_FUNC, 4307*da0073e9SAndroid Build Coastguard Worker "conv_transpose3d": BUILT_IN_FUNC, 4308*da0073e9SAndroid Build Coastguard Worker "cosine_similarity": BUILT_IN_FUNC, 4309*da0073e9SAndroid Build Coastguard Worker "elu_": BUILT_IN_FUNC, 4310*da0073e9SAndroid Build Coastguard Worker "gelu": BUILT_IN_FUNC, 4311*da0073e9SAndroid Build Coastguard Worker "hardshrink": BUILT_IN_FUNC, 4312*da0073e9SAndroid Build Coastguard Worker "hardtanh_": BUILT_IN_FUNC, 4313*da0073e9SAndroid Build Coastguard Worker "leaky_relu_": BUILT_IN_FUNC, 4314*da0073e9SAndroid Build Coastguard Worker "linear": BUILT_IN_FUNC, 4315*da0073e9SAndroid Build Coastguard Worker "logsigmoid": BUILT_IN_FUNC, 4316*da0073e9SAndroid Build Coastguard Worker "one_hot": BUILT_IN_FUNC, 4317*da0073e9SAndroid Build Coastguard Worker "pad": ARG_TYPE_MISMATCH, 4318*da0073e9SAndroid Build Coastguard Worker "pairwise_distance": BUILT_IN_FUNC, 4319*da0073e9SAndroid Build Coastguard Worker "pdist": BUILT_IN_FUNC, 4320*da0073e9SAndroid Build Coastguard Worker "pixel_shuffle": BUILT_IN_FUNC, 4321*da0073e9SAndroid Build Coastguard Worker "pixel_unshuffle": BUILT_IN_FUNC, 4322*da0073e9SAndroid Build Coastguard 
Worker "prelu": BUILT_IN_FUNC, 4323*da0073e9SAndroid Build Coastguard Worker "relu_": BUILT_IN_FUNC, 4324*da0073e9SAndroid Build Coastguard Worker "rrelu_": BUILT_IN_FUNC, 4325*da0073e9SAndroid Build Coastguard Worker "selu_": BUILT_IN_FUNC, 4326*da0073e9SAndroid Build Coastguard Worker "scaled_dot_product_attention": BUILT_IN_FUNC, 4327*da0073e9SAndroid Build Coastguard Worker "softplus": BUILT_IN_FUNC, 4328*da0073e9SAndroid Build Coastguard Worker "softshrink": BUILT_IN_FUNC, 4329*da0073e9SAndroid Build Coastguard Worker "threshold_": BUILT_IN_FUNC, 4330*da0073e9SAndroid Build Coastguard Worker 4331*da0073e9SAndroid Build Coastguard Worker "adaptive_avg_pool2d": LEN_ERROR, 4332*da0073e9SAndroid Build Coastguard Worker "adaptive_avg_pool3d": LEN_ERROR, 4333*da0073e9SAndroid Build Coastguard Worker "adaptive_max_pool2d_with_indices": LEN_ERROR, 4334*da0073e9SAndroid Build Coastguard Worker "adaptive_max_pool3d_with_indices": LEN_ERROR, 4335*da0073e9SAndroid Build Coastguard Worker "instance_norm": CONTROL_FLOW, 4336*da0073e9SAndroid Build Coastguard Worker 4337*da0073e9SAndroid Build Coastguard Worker "adaptive_max_pool1d": PROXY_ITERABLE, 4338*da0073e9SAndroid Build Coastguard Worker "adaptive_max_pool2d": PROXY_ITERABLE, 4339*da0073e9SAndroid Build Coastguard Worker "adaptive_max_pool3d": PROXY_ITERABLE, 4340*da0073e9SAndroid Build Coastguard Worker "fractional_max_pool2d": PROXY_ITERABLE, 4341*da0073e9SAndroid Build Coastguard Worker "fractional_max_pool3d": PROXY_ITERABLE, 4342*da0073e9SAndroid Build Coastguard Worker "max_pool1d": PROXY_ITERABLE, 4343*da0073e9SAndroid Build Coastguard Worker "max_pool2d": PROXY_ITERABLE, 4344*da0073e9SAndroid Build Coastguard Worker "max_pool3d": PROXY_ITERABLE, 4345*da0073e9SAndroid Build Coastguard Worker 4346*da0073e9SAndroid Build Coastguard Worker "lp_pool2d": PROXY_ITERATED, 4347*da0073e9SAndroid Build Coastguard Worker "lp_pool3d": PROXY_ITERATED, 4348*da0073e9SAndroid Build Coastguard Worker "max_unpool1d": 
PROXY_ITERATED, 4349*da0073e9SAndroid Build Coastguard Worker "max_unpool2d": PROXY_ITERATED, 4350*da0073e9SAndroid Build Coastguard Worker "max_unpool3d": PROXY_ITERATED, 4351*da0073e9SAndroid Build Coastguard Worker "fold": PROXY_ITERATED, 4352*da0073e9SAndroid Build Coastguard Worker "unfold": PROXY_ITERATED, 4353*da0073e9SAndroid Build Coastguard Worker 4354*da0073e9SAndroid Build Coastguard Worker "adaptive_max_pool1d_with_indices": ARG_TYPE_MISMATCH, 4355*da0073e9SAndroid Build Coastguard Worker "fractional_max_pool2d_with_indices": ARG_TYPE_MISMATCH, 4356*da0073e9SAndroid Build Coastguard Worker "fractional_max_pool3d_with_indices": ARG_TYPE_MISMATCH, 4357*da0073e9SAndroid Build Coastguard Worker "layer_norm": ARG_TYPE_MISMATCH, 4358*da0073e9SAndroid Build Coastguard Worker "rms_norm": ARG_TYPE_MISMATCH, 4359*da0073e9SAndroid Build Coastguard Worker "lp_pool1d": ARG_TYPE_MISMATCH, 4360*da0073e9SAndroid Build Coastguard Worker 4361*da0073e9SAndroid Build Coastguard Worker "affine_grid": CONTROL_FLOW, 4362*da0073e9SAndroid Build Coastguard Worker "alpha_dropout": CONTROL_FLOW, 4363*da0073e9SAndroid Build Coastguard Worker "batch_norm": CONTROL_FLOW, 4364*da0073e9SAndroid Build Coastguard Worker "binary_cross_entropy": CONTROL_FLOW, 4365*da0073e9SAndroid Build Coastguard Worker "binary_cross_entropy_with_logits": CONTROL_FLOW, 4366*da0073e9SAndroid Build Coastguard Worker "celu": CONTROL_FLOW, 4367*da0073e9SAndroid Build Coastguard Worker "cosine_embedding_loss": CONTROL_FLOW, 4368*da0073e9SAndroid Build Coastguard Worker "cross_entropy": CONTROL_FLOW, 4369*da0073e9SAndroid Build Coastguard Worker "ctc_loss": CONTROL_FLOW, 4370*da0073e9SAndroid Build Coastguard Worker "dropout": CONTROL_FLOW, 4371*da0073e9SAndroid Build Coastguard Worker "dropout1d": CONTROL_FLOW, 4372*da0073e9SAndroid Build Coastguard Worker "dropout2d": CONTROL_FLOW, 4373*da0073e9SAndroid Build Coastguard Worker "dropout3d": CONTROL_FLOW, 4374*da0073e9SAndroid Build Coastguard Worker "elu": 
CONTROL_FLOW, 4375*da0073e9SAndroid Build Coastguard Worker "embedding": CONTROL_FLOW, 4376*da0073e9SAndroid Build Coastguard Worker "embedding_bag": CONTROL_FLOW, 4377*da0073e9SAndroid Build Coastguard Worker "feature_alpha_dropout": CONTROL_FLOW, 4378*da0073e9SAndroid Build Coastguard Worker "gaussian_nll_loss": CONTROL_FLOW, 4379*da0073e9SAndroid Build Coastguard Worker "glu": CONTROL_FLOW, 4380*da0073e9SAndroid Build Coastguard Worker "grid_sample": CONTROL_FLOW, 4381*da0073e9SAndroid Build Coastguard Worker "group_norm": CONTROL_FLOW, 4382*da0073e9SAndroid Build Coastguard Worker "gumbel_softmax": CONTROL_FLOW, 4383*da0073e9SAndroid Build Coastguard Worker "hardsigmoid": CONTROL_FLOW, 4384*da0073e9SAndroid Build Coastguard Worker "hardswish": CONTROL_FLOW, 4385*da0073e9SAndroid Build Coastguard Worker "hardtanh": CONTROL_FLOW, 4386*da0073e9SAndroid Build Coastguard Worker "hinge_embedding_loss": CONTROL_FLOW, 4387*da0073e9SAndroid Build Coastguard Worker "huber_loss": CONTROL_FLOW, 4388*da0073e9SAndroid Build Coastguard Worker "interpolate": CONTROL_FLOW, 4389*da0073e9SAndroid Build Coastguard Worker "kl_div": CONTROL_FLOW, 4390*da0073e9SAndroid Build Coastguard Worker "l1_loss": CONTROL_FLOW, 4391*da0073e9SAndroid Build Coastguard Worker "leaky_relu": CONTROL_FLOW, 4392*da0073e9SAndroid Build Coastguard Worker "local_response_norm": CONTROL_FLOW, 4393*da0073e9SAndroid Build Coastguard Worker "margin_ranking_loss": CONTROL_FLOW, 4394*da0073e9SAndroid Build Coastguard Worker "max_pool1d_with_indices": ARG_TYPE_MISMATCH, 4395*da0073e9SAndroid Build Coastguard Worker "max_pool2d_with_indices": ARG_TYPE_MISMATCH, 4396*da0073e9SAndroid Build Coastguard Worker "max_pool3d_with_indices": ARG_TYPE_MISMATCH, 4397*da0073e9SAndroid Build Coastguard Worker "mse_loss": CONTROL_FLOW, 4398*da0073e9SAndroid Build Coastguard Worker "multi_head_attention_forward": CONTROL_FLOW, 4399*da0073e9SAndroid Build Coastguard Worker "multi_margin_loss": CONTROL_FLOW, 
4400*da0073e9SAndroid Build Coastguard Worker "multilabel_margin_loss": CONTROL_FLOW, 4401*da0073e9SAndroid Build Coastguard Worker "multilabel_soft_margin_loss": CONTROL_FLOW, 4402*da0073e9SAndroid Build Coastguard Worker "nll_loss": CONTROL_FLOW, 4403*da0073e9SAndroid Build Coastguard Worker "poisson_nll_loss": CONTROL_FLOW, 4404*da0073e9SAndroid Build Coastguard Worker "relu": CONTROL_FLOW, 4405*da0073e9SAndroid Build Coastguard Worker "relu6": CONTROL_FLOW, 4406*da0073e9SAndroid Build Coastguard Worker "rrelu": CONTROL_FLOW, 4407*da0073e9SAndroid Build Coastguard Worker "selu": CONTROL_FLOW, 4408*da0073e9SAndroid Build Coastguard Worker "silu": CONTROL_FLOW, 4409*da0073e9SAndroid Build Coastguard Worker "mish": CONTROL_FLOW, 4410*da0073e9SAndroid Build Coastguard Worker "smooth_l1_loss": CONTROL_FLOW, 4411*da0073e9SAndroid Build Coastguard Worker "soft_margin_loss": CONTROL_FLOW, 4412*da0073e9SAndroid Build Coastguard Worker "threshold": CONTROL_FLOW, 4413*da0073e9SAndroid Build Coastguard Worker "triplet_margin_loss": CONTROL_FLOW, 4414*da0073e9SAndroid Build Coastguard Worker "triplet_margin_with_distance_loss": CONTROL_FLOW, 4415*da0073e9SAndroid Build Coastguard Worker "upsample": CONTROL_FLOW, 4416*da0073e9SAndroid Build Coastguard Worker 4417*da0073e9SAndroid Build Coastguard Worker "upsample_bilinear": INTERPOLATE_ARGS_CONFLICT, 4418*da0073e9SAndroid Build Coastguard Worker "upsample_nearest": INTERPOLATE_ARGS_CONFLICT, 4419*da0073e9SAndroid Build Coastguard Worker } 4420*da0073e9SAndroid Build Coastguard Worker 4421*da0073e9SAndroid Build Coastguard Worker # List of nn.functionals with Tensor inputs but not with type annotation 4422*da0073e9SAndroid Build Coastguard Worker FUNCTIONALS_WITHOUT_ANNOTATION = ( 4423*da0073e9SAndroid Build Coastguard Worker "adaptive_max_pool1d", 4424*da0073e9SAndroid Build Coastguard Worker "adaptive_max_pool2d", 4425*da0073e9SAndroid Build Coastguard Worker "adaptive_max_pool3d", 4426*da0073e9SAndroid Build Coastguard 
Worker "fractional_max_pool2d", 4427*da0073e9SAndroid Build Coastguard Worker "fractional_max_pool3d", 4428*da0073e9SAndroid Build Coastguard Worker "max_pool1d", 4429*da0073e9SAndroid Build Coastguard Worker "max_pool2d", 4430*da0073e9SAndroid Build Coastguard Worker "max_pool3d", 4431*da0073e9SAndroid Build Coastguard Worker "gaussian_nll_loss", 4432*da0073e9SAndroid Build Coastguard Worker "upsample", 4433*da0073e9SAndroid Build Coastguard Worker "upsample_bilinear", 4434*da0073e9SAndroid Build Coastguard Worker "upsample_nearest", 4435*da0073e9SAndroid Build Coastguard Worker ) 4436*da0073e9SAndroid Build Coastguard Worker 4437*da0073e9SAndroid Build Coastguard Worker # Inconsistent behavior between Python 3.8 and other Python versions: 4438*da0073e9SAndroid Build Coastguard Worker # - Python 3.8+: Re-raise internal exception like `PROXY_ITERATED` 4439*da0073e9SAndroid Build Coastguard Worker # - Other Python: Raise `argument of type 'Proxy' is not iterable` due to the same 4440*da0073e9SAndroid Build Coastguard Worker # internal exception above 4441*da0073e9SAndroid Build Coastguard Worker # Use the following map to override the expected exception for Python 3.8 4442*da0073e9SAndroid Build Coastguard Worker UNTRACEABLE_FUNCTIONALS_PY38 = { 4443*da0073e9SAndroid Build Coastguard Worker "adaptive_max_pool1d": PROXY_ITERATED, 4444*da0073e9SAndroid Build Coastguard Worker "adaptive_max_pool2d": PROXY_ITERATED, 4445*da0073e9SAndroid Build Coastguard Worker "adaptive_max_pool3d": PROXY_ITERATED, 4446*da0073e9SAndroid Build Coastguard Worker "fractional_max_pool2d": PROXY_ITERATED, 4447*da0073e9SAndroid Build Coastguard Worker "fractional_max_pool3d": PROXY_ITERATED, 4448*da0073e9SAndroid Build Coastguard Worker "max_pool1d": PROXY_ITERATED, 4449*da0073e9SAndroid Build Coastguard Worker "max_pool2d": PROXY_ITERATED, 4450*da0073e9SAndroid Build Coastguard Worker "max_pool3d": PROXY_ITERATED, 4451*da0073e9SAndroid Build Coastguard Worker 4452*da0073e9SAndroid Build 
Coastguard Worker "group_norm": CONTROL_FLOW 4453*da0073e9SAndroid Build Coastguard Worker } 4454*da0073e9SAndroid Build Coastguard Worker 4455*da0073e9SAndroid Build Coastguard Worker @classmethod 4456*da0073e9SAndroid Build Coastguard Worker def _get_functional(cls): 4457*da0073e9SAndroid Build Coastguard Worker functional_list = [] 4458*da0073e9SAndroid Build Coastguard Worker for f in dir(torch.nn.functional): 4459*da0073e9SAndroid Build Coastguard Worker if not f.islower(): 4460*da0073e9SAndroid Build Coastguard Worker continue 4461*da0073e9SAndroid Build Coastguard Worker # Ignore internal functions 4462*da0073e9SAndroid Build Coastguard Worker if f.startswith('_'): 4463*da0073e9SAndroid Build Coastguard Worker continue 4464*da0073e9SAndroid Build Coastguard Worker # Ignore supporting functions 4465*da0073e9SAndroid Build Coastguard Worker if f in cls.IGNORE_FUNCS: 4466*da0073e9SAndroid Build Coastguard Worker continue 4467*da0073e9SAndroid Build Coastguard Worker fn = getattr(torch.nn.functional, f) 4468*da0073e9SAndroid Build Coastguard Worker # Ignore non-callable object like modules 4469*da0073e9SAndroid Build Coastguard Worker if not isinstance(fn, Callable): 4470*da0073e9SAndroid Build Coastguard Worker continue 4471*da0073e9SAndroid Build Coastguard Worker if f not in cls.FUNCTIONALS_WITHOUT_ANNOTATION: 4472*da0073e9SAndroid Build Coastguard Worker try: 4473*da0073e9SAndroid Build Coastguard Worker sig = inspect.signature(fn) 4474*da0073e9SAndroid Build Coastguard Worker has_tensor_arg = False 4475*da0073e9SAndroid Build Coastguard Worker for param in sig.parameters.values(): 4476*da0073e9SAndroid Build Coastguard Worker if isinstance(param.annotation, type) and issubclass(param.annotation, torch.Tensor): 4477*da0073e9SAndroid Build Coastguard Worker has_tensor_arg = True 4478*da0073e9SAndroid Build Coastguard Worker if not has_tensor_arg: 4479*da0073e9SAndroid Build Coastguard Worker continue 4480*da0073e9SAndroid Build Coastguard Worker # No 
signature or Object is not supported 4481*da0073e9SAndroid Build Coastguard Worker except ValueError: 4482*da0073e9SAndroid Build Coastguard Worker pass 4483*da0073e9SAndroid Build Coastguard Worker functional_list.append((f, fn)) 4484*da0073e9SAndroid Build Coastguard Worker return functional_list 4485*da0073e9SAndroid Build Coastguard Worker 4486*da0073e9SAndroid Build Coastguard Worker @classmethod 4487*da0073e9SAndroid Build Coastguard Worker def generate_test_func(cls, func_name, fn): 4488*da0073e9SAndroid Build Coastguard Worker 4489*da0073e9SAndroid Build Coastguard Worker def functional_test(self): 4490*da0073e9SAndroid Build Coastguard Worker if func_name in self.UNTRACEABLE_FUNCTIONALS_PY38 and \ 4491*da0073e9SAndroid Build Coastguard Worker sys.version_info >= (3, 8) and sys.version_info < (3, 12): 4492*da0073e9SAndroid Build Coastguard Worker exc, err = self.UNTRACEABLE_FUNCTIONALS_PY38[func_name] 4493*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(exc, err): 4494*da0073e9SAndroid Build Coastguard Worker symbolic_trace(fn) 4495*da0073e9SAndroid Build Coastguard Worker elif func_name in self.UNTRACEABLE_FUNCTIONALS: 4496*da0073e9SAndroid Build Coastguard Worker exc, err = self.UNTRACEABLE_FUNCTIONALS[func_name] 4497*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(exc, err): 4498*da0073e9SAndroid Build Coastguard Worker symbolic_trace(fn) 4499*da0073e9SAndroid Build Coastguard Worker else: 4500*da0073e9SAndroid Build Coastguard Worker symbolic_trace(fn) 4501*da0073e9SAndroid Build Coastguard Worker return functional_test 4502*da0073e9SAndroid Build Coastguard Worker 4503*da0073e9SAndroid Build Coastguard Worker @classmethod 4504*da0073e9SAndroid Build Coastguard Worker def generate_tests(cls): 4505*da0073e9SAndroid Build Coastguard Worker functional_list = cls._get_functional() 4506*da0073e9SAndroid Build Coastguard Worker for func_name, fn in functional_list: 4507*da0073e9SAndroid Build Coastguard Worker 
test_name = "test_nn_functional_" + func_name 4508*da0073e9SAndroid Build Coastguard Worker functional_test = cls.generate_test_func(func_name, fn) 4509*da0073e9SAndroid Build Coastguard Worker setattr(cls, test_name, functional_test) 4510*da0073e9SAndroid Build Coastguard Worker 4511*da0073e9SAndroid Build Coastguard Worker @classmethod 4512*da0073e9SAndroid Build Coastguard Worker def setUpClass(cls): 4513*da0073e9SAndroid Build Coastguard Worker 4514*da0073e9SAndroid Build Coastguard Worker def no(*args, **kwargs): 4515*da0073e9SAndroid Build Coastguard Worker return False 4516*da0073e9SAndroid Build Coastguard Worker 4517*da0073e9SAndroid Build Coastguard Worker for name in cls.TO_PATCH.keys(): 4518*da0073e9SAndroid Build Coastguard Worker cls.TO_PATCH[name] = getattr(torch.nn.functional, name) 4519*da0073e9SAndroid Build Coastguard Worker setattr(torch.nn.functional, name, no) 4520*da0073e9SAndroid Build Coastguard Worker 4521*da0073e9SAndroid Build Coastguard Worker @classmethod 4522*da0073e9SAndroid Build Coastguard Worker def tearDownClass(cls): 4523*da0073e9SAndroid Build Coastguard Worker for name in cls.TO_PATCH.keys(): 4524*da0073e9SAndroid Build Coastguard Worker setattr(torch.nn.functional, name, cls.TO_PATCH[name]) 4525*da0073e9SAndroid Build Coastguard Worker 4526*da0073e9SAndroid Build Coastguard WorkerTestFunctionalTracing.generate_tests() 4527*da0073e9SAndroid Build Coastguard Worker 4528*da0073e9SAndroid Build Coastguard Worker 4529*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestOperatorSignatures, globals()) 4530*da0073e9SAndroid Build Coastguard Worker 4531*da0073e9SAndroid Build Coastguard Worker@skipIfTorchDynamo("too slow") 4532*da0073e9SAndroid Build Coastguard Worker@skipIfNoTorchVision 4533*da0073e9SAndroid Build Coastguard Workerclass TestVisionTracing(JitTestCase): 4534*da0073e9SAndroid Build Coastguard Worker def setUp(self): 4535*da0073e9SAndroid Build Coastguard Worker # Checking for mutable operations 
while tracing is feature flagged 4536*da0073e9SAndroid Build Coastguard Worker # Enable it in testing but not by default 4537*da0073e9SAndroid Build Coastguard Worker self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations 4538*da0073e9SAndroid Build Coastguard Worker torch.fx.proxy.TracerBase.check_mutable_operations = True 4539*da0073e9SAndroid Build Coastguard Worker 4540*da0073e9SAndroid Build Coastguard Worker def tearDown(self): 4541*da0073e9SAndroid Build Coastguard Worker torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag 4542*da0073e9SAndroid Build Coastguard Worker 4543*da0073e9SAndroid Build Coastguard Worker PROXY_ITERATED = (TraceError, r"Proxy object cannot be iterated") 4544*da0073e9SAndroid Build Coastguard Worker INCONSISTENT_TYPE = ( 4545*da0073e9SAndroid Build Coastguard Worker RuntimeError, 4546*da0073e9SAndroid Build Coastguard Worker r"Return value was annotated as having type __torch__.torchvision.models[.\w]+ but is actually of type Tensor" 4547*da0073e9SAndroid Build Coastguard Worker ) 4548*da0073e9SAndroid Build Coastguard Worker 4549*da0073e9SAndroid Build Coastguard Worker UNTRACEABLE_MODELS = { 4550*da0073e9SAndroid Build Coastguard Worker "fasterrcnn_resnet50_fpn": PROXY_ITERATED, 4551*da0073e9SAndroid Build Coastguard Worker "fasterrcnn_resnet50_fpn_v2": PROXY_ITERATED, 4552*da0073e9SAndroid Build Coastguard Worker "fasterrcnn_mobilenet_v3_large_320_fpn": PROXY_ITERATED, 4553*da0073e9SAndroid Build Coastguard Worker "fasterrcnn_mobilenet_v3_large_fpn": PROXY_ITERATED, 4554*da0073e9SAndroid Build Coastguard Worker "maskrcnn_resnet50_fpn": PROXY_ITERATED, 4555*da0073e9SAndroid Build Coastguard Worker "maskrcnn_resnet50_fpn_v2": PROXY_ITERATED, 4556*da0073e9SAndroid Build Coastguard Worker "keypointrcnn_resnet50_fpn": PROXY_ITERATED, 4557*da0073e9SAndroid Build Coastguard Worker "retinanet_resnet50_fpn": PROXY_ITERATED, 4558*da0073e9SAndroid Build Coastguard Worker 
"retinanet_resnet50_fpn_v2": PROXY_ITERATED, 4559*da0073e9SAndroid Build Coastguard Worker "ssd300_vgg16": PROXY_ITERATED, 4560*da0073e9SAndroid Build Coastguard Worker "fcos_resnet50_fpn": PROXY_ITERATED, 4561*da0073e9SAndroid Build Coastguard Worker "ssdlite320_mobilenet_v3_large": PROXY_ITERATED, 4562*da0073e9SAndroid Build Coastguard Worker } 4563*da0073e9SAndroid Build Coastguard Worker UNSCRIPTABLE_MODELS = { 4564*da0073e9SAndroid Build Coastguard Worker "googlenet": INCONSISTENT_TYPE, 4565*da0073e9SAndroid Build Coastguard Worker "inception_v3": INCONSISTENT_TYPE, 4566*da0073e9SAndroid Build Coastguard Worker } 4567*da0073e9SAndroid Build Coastguard Worker 4568*da0073e9SAndroid Build Coastguard Worker output_transform = { 4569*da0073e9SAndroid Build Coastguard Worker "fcn_resnet50": lambda x: x["out"], 4570*da0073e9SAndroid Build Coastguard Worker "fcn_resnet101": lambda x: x["out"], 4571*da0073e9SAndroid Build Coastguard Worker "deeplabv3_resnet50": lambda x: x["out"], 4572*da0073e9SAndroid Build Coastguard Worker "deeplabv3_resnet101": lambda x: x["out"], 4573*da0073e9SAndroid Build Coastguard Worker "deeplabv3_mobilenet_v3_large": lambda x: x["out"], 4574*da0073e9SAndroid Build Coastguard Worker "lraspp_mobilenet_v3_large": lambda x: x["out"], 4575*da0073e9SAndroid Build Coastguard Worker "fasterrcnn_resnet50_fpn": lambda x: x[1], 4576*da0073e9SAndroid Build Coastguard Worker "fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1], 4577*da0073e9SAndroid Build Coastguard Worker "fasterrcnn_mobilenet_v3_large_320_fpn": lambda x: x[1], 4578*da0073e9SAndroid Build Coastguard Worker "maskrcnn_resnet50_fpn": lambda x: x[1], 4579*da0073e9SAndroid Build Coastguard Worker "keypointrcnn_resnet50_fpn": lambda x: x[1], 4580*da0073e9SAndroid Build Coastguard Worker "retinanet_resnet50_fpn": lambda x: x[1], 4581*da0073e9SAndroid Build Coastguard Worker } 4582*da0073e9SAndroid Build Coastguard Worker 4583*da0073e9SAndroid Build Coastguard Worker @classmethod 
4584*da0073e9SAndroid Build Coastguard Worker def generate_test_fn(cls, name, x, kwargs): 4585*da0073e9SAndroid Build Coastguard Worker def run_test(self): 4586*da0073e9SAndroid Build Coastguard Worker model = torchvision_models.get_model(name, **kwargs) 4587*da0073e9SAndroid Build Coastguard Worker model = model.eval() 4588*da0073e9SAndroid Build Coastguard Worker if name in self.UNTRACEABLE_MODELS: 4589*da0073e9SAndroid Build Coastguard Worker err, exc = self.UNTRACEABLE_MODELS[name] 4590*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(err, exc): 4591*da0073e9SAndroid Build Coastguard Worker graph = symbolic_trace(model) 4592*da0073e9SAndroid Build Coastguard Worker else: 4593*da0073e9SAndroid Build Coastguard Worker out_transform = self.output_transform.get(name, lambda x: x) 4594*da0073e9SAndroid Build Coastguard Worker graph : torch.fx.GraphModule = symbolic_trace(model) 4595*da0073e9SAndroid Build Coastguard Worker a = out_transform(model(x)) 4596*da0073e9SAndroid Build Coastguard Worker b = out_transform(graph(x)) 4597*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, b) 4598*da0073e9SAndroid Build Coastguard Worker 4599*da0073e9SAndroid Build Coastguard Worker if name in self.UNSCRIPTABLE_MODELS: 4600*da0073e9SAndroid Build Coastguard Worker err, exc = self.UNSCRIPTABLE_MODELS[name] 4601*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(err, exc): 4602*da0073e9SAndroid Build Coastguard Worker script = torch.jit.script(graph) 4603*da0073e9SAndroid Build Coastguard Worker else: 4604*da0073e9SAndroid Build Coastguard Worker script = torch.jit.script(graph) 4605*da0073e9SAndroid Build Coastguard Worker c = out_transform(script(x)) 4606*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, c) 4607*da0073e9SAndroid Build Coastguard Worker 4608*da0073e9SAndroid Build Coastguard Worker return run_test 4609*da0073e9SAndroid Build Coastguard Worker 4610*da0073e9SAndroid Build Coastguard Worker @classmethod 
4611*da0073e9SAndroid Build Coastguard Worker def generate_classification_tests(cls): 4612*da0073e9SAndroid Build Coastguard Worker for k in torchvision_models.list_models(module=torchvision_models): 4613*da0073e9SAndroid Build Coastguard Worker test_name = 'test_torchvision_models_' + k 4614*da0073e9SAndroid Build Coastguard Worker x = torch.rand(1, 3, 299, 299) if k in ['inception_v3'] else torch.rand(1, 3, 224, 224) 4615*da0073e9SAndroid Build Coastguard Worker kwargs = dict(num_classes=50) 4616*da0073e9SAndroid Build Coastguard Worker model_test = cls.generate_test_fn(k, x, kwargs) 4617*da0073e9SAndroid Build Coastguard Worker setattr(cls, test_name, model_test) 4618*da0073e9SAndroid Build Coastguard Worker 4619*da0073e9SAndroid Build Coastguard Worker @classmethod 4620*da0073e9SAndroid Build Coastguard Worker def generate_segmentation_tests(cls): 4621*da0073e9SAndroid Build Coastguard Worker for k in torchvision_models.list_models(module=torchvision_models.segmentation): 4622*da0073e9SAndroid Build Coastguard Worker test_name = 'test_torchvision_models_segmentation_' + k 4623*da0073e9SAndroid Build Coastguard Worker x = torch.rand(1, 3, 32, 32) 4624*da0073e9SAndroid Build Coastguard Worker kwargs = dict(num_classes=10, pretrained_backbone=False) 4625*da0073e9SAndroid Build Coastguard Worker model_test = cls.generate_test_fn(k, x, kwargs) 4626*da0073e9SAndroid Build Coastguard Worker setattr(cls, test_name, model_test) 4627*da0073e9SAndroid Build Coastguard Worker 4628*da0073e9SAndroid Build Coastguard Worker @classmethod 4629*da0073e9SAndroid Build Coastguard Worker def generate_detection_tests(cls): 4630*da0073e9SAndroid Build Coastguard Worker for k in torchvision_models.list_models(module=torchvision_models.detection): 4631*da0073e9SAndroid Build Coastguard Worker test_name = 'test_torchvision_models_detection_' + k 4632*da0073e9SAndroid Build Coastguard Worker x = [torch.rand(3, 300, 300)] 4633*da0073e9SAndroid Build Coastguard Worker kwargs = 
dict(num_classes=10, pretrained_backbone=False) 4634*da0073e9SAndroid Build Coastguard Worker model_test = cls.generate_test_fn(k, x, kwargs) 4635*da0073e9SAndroid Build Coastguard Worker setattr(cls, test_name, model_test) 4636*da0073e9SAndroid Build Coastguard Worker 4637*da0073e9SAndroid Build Coastguard Worker @classmethod 4638*da0073e9SAndroid Build Coastguard Worker def generate_video_tests(cls): 4639*da0073e9SAndroid Build Coastguard Worker for k in torchvision_models.list_models(module=torchvision_models.video): 4640*da0073e9SAndroid Build Coastguard Worker test_name = 'test_torchvision_models_video_' + k 4641*da0073e9SAndroid Build Coastguard Worker x = ( 4642*da0073e9SAndroid Build Coastguard Worker torch.rand(1, 3, 4, 112, 112) 4643*da0073e9SAndroid Build Coastguard Worker if k not in {"mvit_v1_b", "mvit_v2_s", "s3d"} 4644*da0073e9SAndroid Build Coastguard Worker else torch.rand(1, 3, 16, 224, 224) 4645*da0073e9SAndroid Build Coastguard Worker ) 4646*da0073e9SAndroid Build Coastguard Worker kwargs = dict(num_classes=50) 4647*da0073e9SAndroid Build Coastguard Worker model_test = cls.generate_test_fn(k, x, kwargs) 4648*da0073e9SAndroid Build Coastguard Worker setattr(cls, test_name, model_test) 4649*da0073e9SAndroid Build Coastguard Worker 4650*da0073e9SAndroid Build Coastguard Worker @classmethod 4651*da0073e9SAndroid Build Coastguard Worker def generate_tests(cls): 4652*da0073e9SAndroid Build Coastguard Worker cls.generate_classification_tests() 4653*da0073e9SAndroid Build Coastguard Worker cls.generate_detection_tests() 4654*da0073e9SAndroid Build Coastguard Worker cls.generate_segmentation_tests() 4655*da0073e9SAndroid Build Coastguard Worker cls.generate_video_tests() 4656*da0073e9SAndroid Build Coastguard Worker 4657*da0073e9SAndroid Build Coastguard Workerif HAS_TORCHVISION: 4658*da0073e9SAndroid Build Coastguard Worker TestVisionTracing.generate_tests() 4659*da0073e9SAndroid Build Coastguard Worker 4660*da0073e9SAndroid Build Coastguard Workerif 
__name__ == '__main__': 4661*da0073e9SAndroid Build Coastguard Worker run_tests() 4662