# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""
This module contains tools for rewriting a dynamic PyTorch program such
that the dynamic part (e.g. control flow) can be properly captured by
DispatchTracer.
The core idea is to annotate all branches in the graph with unique keys
and to use a dictionary of supplemental inputs as arguments to these
local branches, so that every path gets a canonical input during tracing.

For example, consider the following usage of a Python if statement:

.. code-block:: python

    if pred:
        ...
        ret = a
    else:
        ...
        ret = b

To rewrite the code so that it is traceable, users may use the
tracing_context decorator and the cond operator:

.. code-block:: python

    @control_flow.tracing_context(inputs)
    def branch_true(args):
        ...
        return a

    @control_flow.tracing_context(inputs)
    def branch_false(args):
        ...
        return b

    ret = control_flow.cond(pred, branch_true, branch_false, args)

Then we can use the usual exir.capture() function:

.. code-block:: python

    exir.capture(module, args)

"""

from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.utils._pytree as pytree
from executorch.exir.error import ExportError, ExportErrorType, internal_assert
from executorch.exir.tracer import (
    DispatchTracer,
    flattened_dispatch_trace,
    PythonTensor,
    tree_return,
    unwrap_functional,
    unwrap_proxy,
    using_tracer,
    Value,
)
from executorch.exir.wrap import update_with_proxy


def shape(x: torch.Tensor) -> Union[torch._C.Size, torch.Tensor]:
    """
    A helper function for capturing the shape of a tensor value: in eager
    mode it returns the regular shape, and under the EXIR DispatchTracer it
    returns the shape as a tensor.
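
    Illustrative sketch (t is a hypothetical input tensor):

    .. code-block:: python

        t = torch.ones(2, 3)
        # torch.Size([2, 3]) in eager mode; a rank-1 int64 tensor under tracing
        s = shape(t)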
    """
    tracer = DispatchTracer.get()
    if tracer is None:
        return x.shape
    x = unwrap_functional(x)
    if not isinstance(x, PythonTensor):
        raise ExportError(
            ExportErrorType.INVALID_INPUT_TYPE,
            f"exir custom shape function only takes an EXIR dispatch tensor, but got: {type(x)}",
        )
    # TODO: _shape_as_tensor should work with functional tensors, but currently does not.
    # TODO: torch.tensor() should succeed under functionalization, but currently does not.
    #       see: https://github.com/pytorch/pytorch/pull/76319
    tmp = torch.empty(len(x.shape), dtype=torch.int64)
    for i, s in enumerate(x.shape):
        tmp[i] = s
    proxy = torch.ops.aten._shape_as_tensor.default(x.proxy)
    return PythonTensor(unwrap_functional(tmp), proxy)


def _make_submodule(
    fn: Callable[..., Union[torch.Tensor, Tuple[torch.Tensor]]],
    example_returns: Optional[List[torch.Tensor]] = None,
    single_return: bool = False,
) -> torch.fx.GraphModule:
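    """
    Trace fn (which must be decorated with tracing_context) into a
    torch.fx.GraphModule using its recorded __tracing_inputs__, optionally
    check the number of returned values, and attach the tracing inputs to
    the resulting GraphModule.
    """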
    if not hasattr(fn, "__tracing_inputs__"):
        raise ExportError(
            ExportErrorType.MISSING_PROPERTY,
            f"Expected function '{fn.__name__}' to be decorated with tracing_context.",
        )
    # pyre-ignore
    args = fn.__tracing_inputs__
    # TODO(yidi): We don't want to enable functionalization here because this code path
    # will not be used in the future anyway.
    gm, _ = flattened_dispatch_trace(fn, args, set(), enable_functionalization=False)
    output = next(iter(reversed(gm.graph.nodes)))
    if example_returns:
        internal_assert(
            len(example_returns) == len(output.args[0]),
            f"Eager mode of {gm} returns {len(example_returns)} elements, but this graph returns {len(output.args[0])} elements",
        )

    if single_return:
        # Force the number of returned values to be 1.
        internal_assert(
            len(output.args[0]) == 1,
            f"Graph {gm} should return just one element, but got {len(output.args[0])}",
        )
        output.args = tuple(output.args[0])
        gm.recompile()
    # pyre-fixme[16]: `GraphModule` has no attribute `__tracing_inputs__`.
    gm.__tracing_inputs__ = args
    return gm


def while_loop(
    cond_fn: Callable[..., torch.Tensor],
    body_fn: Callable[..., Tuple[torch.Tensor]],
    init_val: pytree.PyTree,
) -> Union[Tuple[torch.Tensor], Value]:
    """
    A higher-order function that repeatedly applies body_fn to the values in
    init_val until cond_fn returns False, and returns the final result.
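
    Illustrative sketch (cond_fn, body_fn, and the input tensors are
    hypothetical; when captured by exir.capture(), both callables must be
    decorated with tracing_context):

    .. code-block:: python

        inputs = (torch.tensor(3), torch.zeros(2))

        @tracing_context(inputs)
        def cond_fn(iters, x):
            return iters > 0

        @tracing_context(inputs)
        def body_fn(iters, x):
            return iters - 1, x + 1

        result = while_loop(cond_fn, body_fn, inputs)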
    """
    flattened_inputs, _ = pytree.tree_flatten(init_val)
    if not all(isinstance(i, torch.Tensor) for i in flattened_inputs):
        raise ExportError(
            ExportErrorType.INVALID_INPUT_TYPE,
            f"control_flow.while_loop() expects all input values to be tensors, actual inputs: {init_val}",
        )

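    # Run the loop eagerly with the dispatch tracer disabled so that concrete
    # result values are always computed, even while tracing.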
    with using_tracer(None):
        val = init_val
        while cond_fn(*val):
            val = body_fn(*val)

    flattened_outputs, _ = pytree.tree_flatten(val)
    if not all(isinstance(o, torch.Tensor) for o in flattened_outputs):
        raise ExportError(
            ExportErrorType.INVALID_OUTPUT_TYPE,
            f"control_flow.while_loop() expects all returned values to be tensors, actual outputs: {val}",
        )

    tracer = DispatchTracer.get()

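    # Not under tracing: simply return the eagerly computed result.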
    if tracer is None:
        return val

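    # Under tracing: lower cond_fn and body_fn into fx.GraphModules using their
    # recorded tracing inputs, then emit a call_function node for while_loop.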
    gm_cond = _make_submodule(cond_fn, single_return=True)
    gm_body = _make_submodule(body_fn)

    proxies = tuple([unwrap_proxy(v) for v in flattened_inputs])

    proxy = tracer.create_proxy(
        "call_function",
        while_loop,
        (gm_cond, gm_body, proxies),
        {},
    )

    return tree_return(val, proxy, update_with_proxy)


def tracing_context(
    inputs: Tuple[torch.Tensor, ...],
) -> Callable[..., Callable[..., Union[torch.Tensor, Tuple[torch.Tensor]]]]:
    """
    A decorator that annotates code paths we conditionally run during
    tracing. For now these paths need to be annotated because, during
    exir.capture(), the tracer does not know the proper local inputs to
    pass to the untaken path.
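
    Illustrative sketch (branch_true and inputs are hypothetical names):

    .. code-block:: python

        inputs = (torch.randn(2),)

        @tracing_context(inputs)
        def branch_true(x):
            return x + 1

    The recorded inputs are later used (e.g. by cond and while_loop) to trace
    the decorated function even when its path is not taken eagerly.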
    """

    def decorator(
        f: Callable[..., Tuple[torch.Tensor]]
    ) -> Callable[..., Union[torch.Tensor, Tuple[torch.Tensor]]]:
        def wrapper(
            *args: torch.Tensor, **kwargs: Tuple[torch.Tensor]
        ) -> Tuple[torch.Tensor]:
            if kwargs:
                raise ExportError(
                    ExportErrorType.NOT_SUPPORTED,
                    "kwargs are not supported for @tracing_context decorated functions.",
                )

            return f(*args)

        wrapper.__tracing_inputs__ = inputs  # pyre-ignore
        return wrapper

    return decorator