# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""
This module contains tools for rewriting a dynamic PyTorch program so
that its dynamic parts (e.g. control flow) can be properly captured by
DispatchTracer.

The core idea is to annotate every branch in the graph with a unique key
and to pass a dictionary of supplemental inputs as arguments to these
local branches, so that every path gets a canonical input during tracing.

For example, consider the following Python ``if`` statement:

.. code-block:: python

    if pred:
        ...
        ret = a
    else:
        ...
        ret = b

To rewrite this code so that it is traceable, users can use the
tracing_context decorator and the cond operator:

.. code-block:: python

    @control_flow.tracing_context(inputs)
    def branch_true(args):
        ...
        return a

    @control_flow.tracing_context(inputs)
    def branch_false(args):
        ...
        return b

    ret = control_flow.cond(pred, branch_true, branch_false, args)

The program can then be captured with the usual exir.capture() function:

.. code-block:: python

    exir.capture(module, args)

"""

from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.utils._pytree as pytree
from executorch.exir.error import ExportError, ExportErrorType, internal_assert
from executorch.exir.tracer import (
    DispatchTracer,
    flattened_dispatch_trace,
    PythonTensor,
    tree_return,
    unwrap_functional,
    unwrap_proxy,
    using_tracer,
    Value,
)
from executorch.exir.wrap import update_with_proxy


def shape(x: torch.Tensor) -> Union[torch._C.Size, torch.Tensor]:
    """
    A helper function that captures the shape of a tensor value as a tensor.
    When no DispatchTracer is active, it simply returns ``x.shape``; during
    tracing, it returns a PythonTensor holding the shape, backed by an
    ``aten._shape_as_tensor`` node in the graph.
    """
    tracer = DispatchTracer.get()
    if tracer is None:
        return x.shape
    x = unwrap_functional(x)
    if not isinstance(x, PythonTensor):
        raise ExportError(
            ExportErrorType.INVALID_INPUT_TYPE,
            f"exir custom shape function only takes EXIR dispatch tensor, but got: {type(x)}",
        )
    # TODO: _shape_as_tensor should work with functional tensors, but currently does not.
    # TODO: torch.tensor() should succeed under functionalization, but currently does not.
    # See: https://github.com/pytorch/pytorch/pull/76319
    # Workaround: fill the concrete shape tensor element by element.
    tmp = torch.empty(len(x.shape), dtype=torch.int64)
    for i, s in enumerate(x.shape):
        tmp[i] = s
    proxy = torch.ops.aten._shape_as_tensor.default(x.proxy)
    return PythonTensor(unwrap_functional(tmp), proxy)


def _make_submodule(
    fn: Callable[..., Union[torch.Tensor, Tuple[torch.Tensor]]],
    example_returns: Optional[List[torch.Tensor]] = None,
    single_return: bool = False,
) -> torch.fx.GraphModule:
    """
    Trace ``fn``, which must be decorated with tracing_context, into a
    GraphModule using its recorded ``__tracing_inputs__``. If
    ``example_returns`` is non-empty, assert that the graph returns the same
    number of values. If ``single_return`` is set, the graph's output is
    unwrapped so that it returns its single value directly.
    """
    if not hasattr(fn, "__tracing_inputs__"):
        raise ExportError(
            ExportErrorType.MISSING_PROPERTY,
            f"Expect function '{fn.__name__}' to be decorated with tracing_context.",
        )
    # pyre-ignore
    args = fn.__tracing_inputs__
    # TODO(yidi): functionalization is intentionally not enabled here because
    # this code path is not going to be used in the future anyway.
    gm, _ = flattened_dispatch_trace(fn, args, set(), enable_functionalization=False)
    output = next(iter(reversed(gm.graph.nodes)))
    if example_returns:
        internal_assert(
            len(example_returns) == len(output.args[0]),
            f"Eager mode of this {gm} returns {len(example_returns)} elements, but this graph returns {len(output.args[0])} elements",
        )

    if single_return:
        # Force the number of returned values to be 1.
        internal_assert(
            len(output.args[0]) == 1,
            f"Graph {gm} should return just one element, but got {len(output.args[0])}",
        )
        output.args = tuple(output.args[0])
        gm.recompile()
    # pyre-fixme[16]: `GraphModule` has no attribute `__tracing_inputs__`.
    gm.__tracing_inputs__ = args
    return gm


def while_loop(
    cond_fn: Callable[..., torch.Tensor],
    body_fn: Callable[..., Tuple[torch.Tensor]],
    init_val: pytree.PyTree,
) -> Union[Tuple[torch.Tensor], Value]:
    """
    A higher-order function that repeatedly calls body_fn on the
    loop-carried values, starting from init_val, until cond_fn returns
    False, and then returns the final values. Both functions receive the
    values unpacked as positional arguments, and all input and output
    values must be tensors. When a DispatchTracer is active, cond_fn and
    body_fn must be decorated with tracing_context so that they can be
    traced into the submodules of a single while_loop node.
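
    The eager part of this function is an ordinary Python loop. The
    following is a minimal, dependency-free sketch of just that part;
    ``eager_while_loop`` is a hypothetical name that is not part of this
    module, and the sketch skips the tensor checks and the tracing step:

    .. code-block:: python

        from typing import Any, Callable, Tuple

        def eager_while_loop(
            cond_fn: Callable[..., bool],
            body_fn: Callable[..., Tuple[Any, ...]],
            init_val: Tuple[Any, ...],
        ) -> Tuple[Any, ...]:
            # Loop-carried values are unpacked as positional arguments,
            # mirroring how while_loop calls cond_fn and body_fn.
            val = init_val
            while cond_fn(*val):
                val = body_fn(*val)
            return val

        # Count i up to 5 while accumulating the running sum of i.
        result = eager_while_loop(
            lambda i, acc: i < 5,
            lambda i, acc: (i + 1, acc + i),
            (0, 0),
        )
        # result == (5, 10)
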
    """
    flattened_inputs, _ = pytree.tree_flatten(init_val)
    if not all(isinstance(i, torch.Tensor) for i in flattened_inputs):
        raise ExportError(
            ExportErrorType.INVALID_INPUT_TYPE,
            f"control_flow.while_loop() expects all input values to be tensors, actual inputs: {init_val}",
        )

    # Run the loop eagerly with tracing disabled to compute the concrete result.
    with using_tracer(None):
        val = init_val
        while cond_fn(*val):
            val = body_fn(*val)

    flattened_outputs, _ = pytree.tree_flatten(val)
    if not all(isinstance(o, torch.Tensor) for o in flattened_outputs):
        raise ExportError(
            ExportErrorType.INVALID_OUTPUT_TYPE,
            f"control_flow.while_loop() expects all returned values to be tensors, actual outputs: {val}",
        )

    tracer = DispatchTracer.get()

    if tracer is None:
        return val

    # Trace cond_fn and body_fn into submodules for the while_loop node.
    gm_cond = _make_submodule(cond_fn, single_return=True)
    gm_body = _make_submodule(body_fn)

    proxies = tuple(unwrap_proxy(v) for v in flattened_inputs)

    proxy = tracer.create_proxy(
        "call_function",
        while_loop,
        (gm_cond, gm_body, proxies),
        {},
    )

    return tree_return(val, proxy, update_with_proxy)


def tracing_context(
    inputs: Tuple[torch.Tensor, ...],
) -> Callable[..., Callable[..., Union[torch.Tensor, Tuple[torch.Tensor]]]]:
    """
    A decorator that annotates a code path that is only conditionally run
    during tracing, recording ``inputs`` as the inputs used to trace it.
    These paths need to be annotated for now because, during
    exir.capture(), the tracer does not know which local inputs to pass
    to a path that is not taken.
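
    The decorator itself only attaches ``inputs`` to the wrapper as
    ``__tracing_inputs__`` (which _make_submodule later reads to trace the
    path) and rejects keyword arguments. A hypothetical, dependency-free
    sketch of the same pattern follows; ``record_inputs`` is illustrative
    only and raises ``TypeError`` where this module raises ``ExportError``:

    .. code-block:: python

        from typing import Any, Callable, Tuple

        def record_inputs(
            inputs: Tuple[Any, ...],
        ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
            def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
                def wrapper(*args: Any, **kwargs: Any) -> Any:
                    # Keyword arguments are rejected, as in tracing_context.
                    if kwargs:
                        raise TypeError("kwargs are not supported")
                    return f(*args)

                # Stash the example inputs on the wrapper for a tracer to find.
                wrapper.__tracing_inputs__ = inputs
                return wrapper

            return decorator

        @record_inputs((1, 2))
        def add(a, b):
            return a + b

        # add.__tracing_inputs__ == (1, 2); add(3, 4) == 7
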
    """

    def decorator(
        f: Callable[..., Tuple[torch.Tensor]]
    ) -> Callable[..., Union[torch.Tensor, Tuple[torch.Tensor]]]:
        def wrapper(
            *args: torch.Tensor, **kwargs: Tuple[torch.Tensor]
        ) -> Tuple[torch.Tensor]:
            if kwargs:
                raise ExportError(
                    ExportErrorType.NOT_SUPPORTED,
                    "kwargs are not supported for @tracing_context decorated functions.",
                )

            return f(*args)

        wrapper.__tracing_inputs__ = inputs  # pyre-ignore
        return wrapper

    return decorator