xref: /aosp_15_r20/external/pytorch/torch/_higher_order_ops/utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import functools
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable

import torch
import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch._ops import OperatorBase
from torch.fx.experimental.proxy_tensor import make_fx
from torch.multiprocessing.reductions import StorageWeakRef


@dataclass
class UnsupportedAliasMutationException(RuntimeError):
    reason: str


def autograd_not_implemented_inner(
    operator: OperatorBase, delayed_error: bool, *args: Any, **kwargs: Any
) -> Any:
    """If autograd is enabled and any of the arguments require grad, this will either
    raise an error or return a DelayedError, depending on the value of delayed_error.

    Args:
        operator: The Operator to call with *args and **kwargs
        delayed_error: If True, return a DelayedError instead of raising an error
        args: The flattened operands to the Operator
        kwargs: The keyword arguments to the Operator

    Raises:
        RuntimeError: If autograd is enabled, any of the arguments to the Operator
            require grad, and delayed_error is False
    """
    with torch._C._AutoDispatchBelowAutograd():
        result = operator(*args, **kwargs)
        flat_operands = pytree.arg_tree_leaves(*args)
        if torch.is_grad_enabled() and any(
            f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor)
        ):
            if delayed_error:
                err_fn = torch._C._functions.DelayedError(
                    f"Autograd not implemented for {str(operator)}",
                    1,
                )

                def fake_requires_grad(tensor):
                    if torch.is_floating_point(tensor) or torch.is_complex(tensor):
                        tensor = tensor.detach()
                        tensor.requires_grad = True
                    return tensor

                return pytree.tree_map_only(
                    torch.Tensor, lambda x: err_fn(fake_requires_grad(x)), result
                )
            else:
                raise RuntimeError(f"Autograd not implemented for {str(operator)}")
        return result


def autograd_not_implemented(op: OperatorBase, deferred_error: bool) -> Callable:
    def inner(*args, **kwargs):
        return autograd_not_implemented_inner(op, deferred_error, *args, **kwargs)

    return inner
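

# Illustrative usage sketch (the `_example_*` helper below is hypothetical, not part
# of the original API): with grad enabled and a grad-requiring input, the wrapper built
# by autograd_not_implemented either raises eagerly or defers the error, depending on
# deferred_error. aten.sin.default is just a convenient OperatorBase instance.
def _example_autograd_not_implemented():
    x = torch.randn(3, requires_grad=True)
    try:
        autograd_not_implemented(torch.ops.aten.sin.default, deferred_error=False)(x)
    except RuntimeError:
        pass  # raised eagerly because deferred_error=False
    # With deferred_error=True the call succeeds; the error only surfaces if the
    # result participates in a later backward pass.
    out = autograd_not_implemented(torch.ops.aten.sin.default, deferred_error=True)(x)
    assert isinstance(out, torch.Tensor)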


def _maybe_run_with_interpreter(fn):
    maybe_interpreted_fn = fn
    if isinstance(fn, torch.fx.GraphModule) and fx_traceback.has_preserved_node_meta():
        # Running the graph with the interpreter is needed to propagate the stack_trace
        def graph_with_interpreter(*args):
            with fx_traceback.preserve_node_meta():
                return torch.fx.Interpreter(fn).run(*args)

        maybe_interpreted_fn = graph_with_interpreter
    return maybe_interpreted_fn


def reenter_make_fx(fn):
    from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER

    @functools.wraps(fn)
    def wrapped(*args):
        assert (
            _CURRENT_MAKE_FX_TRACER is not None
        ), "Cannot reenter make_fx when we're not under a make_fx tracing session"
        return _CURRENT_MAKE_FX_TRACER.trace_subgraph(
            _maybe_run_with_interpreter(fn), *args
        )

    return wrapped


def _maybe_reenter_make_fx(fn):
    from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER

    if _CURRENT_MAKE_FX_TRACER is not None:
        return reenter_make_fx(fn)
    else:
        return make_fx(fn)
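

# Illustrative sketch (hypothetical `_example_*` helper): outside of an active
# make_fx tracing session, _maybe_reenter_make_fx falls back to plain make_fx,
# so calling the returned wrapper on example inputs yields a GraphModule.
def _example_maybe_reenter_make_fx():
    gm = _maybe_reenter_make_fx(lambda x: x.sin() + 1)(torch.randn(3))
    assert isinstance(gm, torch.fx.GraphModule)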


@contextmanager
def _set_compilation_env():
    _old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag
    try:
        # We need to turn off the is_fx_tracing_flag. Remove this flag check from dynamo
        # once we are confident fx tracing works with dynamo.
        torch.fx._symbolic_trace._is_fx_tracing_flag = False
        yield
    finally:
        torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing
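

# Illustrative sketch (hypothetical `_example_*` helper): the context manager
# temporarily clears fx's symbolic-trace flag for the wrapped compilation region
# and restores the previous value on exit.
def _example_set_compilation_env():
    with _set_compilation_env():
        assert torch.fx._symbolic_trace._is_fx_tracing_flag is False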


def _has_potential_branch_input_mutation(branch, inputs, pre_dispatch=False):
    """
    Dispatch-trace the branch with the given inputs and check whether
    the resulting graph contains a mutable op applied to an input. This is
    a bit restrictive as the branch must be traceable.
    """
    try:
        gm = make_fx(branch, pre_dispatch=pre_dispatch)(*inputs)
    except UnsupportedAliasMutationException:
        # this can happen when a nested cond_op is
        # functionalized
        return True
    except Exception as e:
        raise e

    def _detect_input_mutation(gm):
        input_nodes = set()
        for node in gm.graph.nodes:
            if node.op == "placeholder":
                input_nodes.add(node)
            if node.op == "call_function":
                target = node.target
                if (
                    isinstance(target, torch._ops.OpOverload)
                    and target._schema.is_mutable
                ):
                    for arg in node.args:
                        if arg in input_nodes:
                            return True

        for _, module in gm.named_children():
            if isinstance(module, torch.fx.GraphModule):
                if _detect_input_mutation(module):
                    return True

        return False

    return _detect_input_mutation(gm)
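

# Illustrative sketch (hypothetical `_example_*` helper): a branch that applies
# an in-place op to its argument shows up in the traced graph as a mutable
# OpOverload on a placeholder, so it is flagged as potentially mutating.
def _example_has_potential_branch_input_mutation():
    def mutating_branch(x):
        return x.add_(1)

    assert _has_potential_branch_input_mutation(mutating_branch, (torch.randn(3),))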


def _has_potential_branch_input_alias(branch, inputs, pre_dispatch=False):
    """
    Dispatch-trace the branch with the given inputs and check whether
    the resulting graph has an output that aliases a branch input. This is
    a bit restrictive as the branch must be traceable.
    """
    try:
        gm = make_fx(branch, pre_dispatch=pre_dispatch)(*inputs)
    except UnsupportedAliasMutationException:
        # this can happen when a nested cond_op is
        # functionalized
        return True
    except Exception as e:
        raise e

    def _detect_input_alias(gm):
        input_storages = set()
        for node in gm.graph.nodes:
            # We need to check for the existence of "val" because we reuse this logic
            # for the map operator, where num_mapped_args is a scalar
            # and doesn't have a "val" meta.
            if node.op == "placeholder" and "val" in node.meta:
                input_storages.add(StorageWeakRef(node.meta["val"]._typed_storage()))
            if node.op == "output":

                def check_alias(out):
                    if out is not None and "val" in out.meta:
                        out_storage = StorageWeakRef(out.meta["val"]._typed_storage())
                        return out_storage in input_storages
                    return False

                if any(pytree.tree_leaves(pytree.tree_map(check_alias, node.args))):
                    return True

        for _, module in gm.named_children():
            if isinstance(module, torch.fx.GraphModule) and _detect_input_alias(module):
                return True

        return False

    return _detect_input_alias(gm)
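

# Illustrative sketch (hypothetical `_example_*` helper): the simplest aliasing
# branch returns its input unchanged, so the traced graph's output refers to the
# placeholder itself and the storage check reports an alias.
def _example_has_potential_branch_input_alias():
    assert _has_potential_branch_input_alias(lambda x: x, (torch.randn(2, 2),))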


def unique_graph_id(proxy_mode, prefix):
    """Returns a unique name and id for a graph to be added to a proxy_mode tracer"""
    # There are probably better ways - I know that create_arg has some self incrementing name
    # magic to it, but since we explicitly have to get the name for register_module,
    # I was not sure how to do that. This kinda simulates it.
    next_name = None
    i = 0
    while not next_name:
        candidate = f"{prefix}_{i}"
        if hasattr(proxy_mode.tracer.root, candidate):
            i += 1
        else:
            next_name = candidate
    return i, next_name
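

# Illustrative sketch (hypothetical `_example_*` helper): unique_graph_id only
# probes proxy_mode.tracer.root with hasattr, so a SimpleNamespace stand-in is
# enough to show the numbering scheme; it is not a real proxy mode.
def _example_unique_graph_id():
    import types

    root = torch.nn.Module()
    root.wrap_body_0 = torch.nn.Module()  # pretend one subgraph was already registered
    fake_proxy_mode = types.SimpleNamespace(tracer=types.SimpleNamespace(root=root))
    assert unique_graph_id(fake_proxy_mode, prefix="wrap_body") == (1, "wrap_body_1")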


def _from_fun(t):
    from torch._functorch.aot_autograd import from_fun
    from torch._subclasses.functional_tensor import FunctionalTensor

    if isinstance(t, torch.Tensor):
        if t.dtype != torch.bool:
            return torch.empty_strided(
                t.size(),
                t.stride(),
                dtype=t.dtype,
                requires_grad=t.requires_grad,
            )
        else:
            # A clone of a functional tensor produces a functional tensor,
            # but we want to avoid that, so we clone a non-functional version
            maybe_unfunc_t = t
            if isinstance(t, FunctionalTensor):
                torch._sync(t)
                maybe_unfunc_t = from_fun(t)
            elif torch._is_functional_tensor(t):
                # We need to handle both types of functionalization here:
                # these are the tensors that came from the user,
                # which could be either FunctionalTensorWrapper or FunctionalTensor
                torch._sync(t)
                maybe_unfunc_t = torch._from_functional_tensor(t)
            return maybe_unfunc_t.clone()
    return t
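

# Illustrative sketch (hypothetical `_example_*` helper): for an ordinary
# (non-bool, non-functional) tensor, _from_fun manufactures a fresh
# uninitialized tensor with matching size/stride/dtype/requires_grad.
def _example_from_fun():
    t = torch.randn(2, 3, requires_grad=True)
    out = _from_fun(t)
    assert out.shape == t.shape
    assert out.dtype == t.dtype
    assert out.requires_grad
    assert out is not t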


def clone_outputs_aliasing_inputs(args):
    input_storage = {
        StorageWeakRef(arg._typed_storage())
        for arg in args
        if isinstance(arg, torch.Tensor)
    }

    def maybe_clone(t):
        if (
            isinstance(t, torch.Tensor)
            and StorageWeakRef(t._typed_storage()) in input_storage
        ):
            return t.clone()
        return t

    return maybe_clone
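

# Illustrative sketch (hypothetical `_example_*` helper): outputs sharing storage
# with any of the given args are cloned by the returned callable, while unrelated
# tensors (and non-tensors) pass through untouched.
def _example_clone_outputs_aliasing_inputs():
    x = torch.randn(4)
    maybe_clone = clone_outputs_aliasing_inputs((x,))
    view_out = maybe_clone(x.view(2, 2))
    assert view_out.data_ptr() != x.data_ptr()  # the aliasing view was cloned
    unrelated = torch.randn(2)
    assert maybe_clone(unrelated) is unrelated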


def prepare_fw_with_masks(fn):
    def fw_with_masks(*args):
        fw_out = fn(*args)
        return fw_out, [
            isinstance(ret, torch.Tensor) and ret.requires_grad
            for ret in fw_out
        ]

    return fw_with_masks
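

# Illustrative sketch (hypothetical `_example_*` helper): the wrapper returns the
# forward outputs together with a per-output mask marking which outputs are
# tensors that participate in autograd.
def _example_prepare_fw_with_masks():
    def fw(x, y):
        return x * 2, y + 1

    x = torch.randn(2, requires_grad=True)
    y = torch.randn(2)
    _, masks = prepare_fw_with_masks(fw)(x, y)
    assert masks == [True, False]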


# TODO: The parameter use_output_and_grad_bw is required because some operations
# that utilize this function, such as while_loop, may require (grad, fwd_outputs)
def create_fw_bw_graph(fn, use_output_and_grad_bw, fw_inputs, fw_outputs):
    from torch._functorch.aot_autograd import AOTConfig, create_joint

    # Note:[HOP create fw_bw graph] We create "clean" environments for make_fx by suspending all dispatch keys
    # between Autograd and Python key. Currently, we only suspend functionalization but more can be
    # added when required. We will encounter two problems if we don't suspend functionalization:
    #
    # 1. make_fx fails to capture operations on input: the inputs are wrapped as _to_functional_tensor_wrapper,
    # but they will be unwrapped before entering ProxyTorchDispatchMode as part of the dispatching.
    # However, it's the outside wrapper that the tracer creates proxies for. This causes the tracer
    # to fail to fetch the proxy for the inputs and to fail to capture any operations on them.
    #
    # 2. make_fx fails to capture output: the outputs after ProxyTorchDispatchMode are further
    # wrapped as FunctionalTensorWrapper in the Functionalize key after return. However, the tracer
    # only associates the inner tensor with a proxy in ProxyTorchDispatchMode. Therefore,
    # when creating the output node, it fails to associate the wrapped tensor with its proxy.
    # Instead, it will create a _tensor_constant as output.

    dummy_aot_config = AOTConfig(
        fw_compiler=None,  # type: ignore[arg-type]
        bw_compiler=None,  # type: ignore[arg-type]
        partition_fn=None,  # type: ignore[arg-type]
        decompositions={},
        num_params_buffers=0,
        aot_id=0,
        keep_inference_input_mutations=False,
    )

    example_grad = [_from_fun(out) for out in fw_outputs]
    num_grads = len(example_grad)
    fw_graph = _maybe_reenter_make_fx(fn)(*fw_inputs)

    def joint_fn(*joint_operands_grads):
        if use_output_and_grad_bw:
            grads = joint_operands_grads[0]
            inputs = joint_operands_grads[1][-1:]
        else:
            grads = joint_operands_grads[:num_grads]
            inputs = joint_operands_grads[num_grads:]

        joint = create_joint(prepare_fw_with_masks(fn), aot_config=dummy_aot_config)
        _, grads = joint(
            list(inputs),
            [grad for grad in grads if grad is not None and grad.requires_grad],
        )

        # In order to keep the backward graph of map functional,
        # we clone outputs that alias inputs
        maybe_clone = clone_outputs_aliasing_inputs(joint_operands_grads)

        return pytree.tree_map(maybe_clone, grads)

    if use_output_and_grad_bw:
        example_xs_out = list(fw_inputs) + list(fw_outputs)
        joint_graph = _maybe_reenter_make_fx(joint_fn)(
            (list(example_grad), list(example_xs_out))
        )
    else:
        example_xs_out = list(fw_inputs)
        joint_graph = _maybe_reenter_make_fx(joint_fn)(
            *(list(example_grad) + list(example_xs_out))
        )

    return fw_graph, joint_graph
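

# Illustrative sketch (hypothetical `_example_*` helper, shown with plain eager
# tensors rather than the functional tensors HOPs normally pass in): given a
# forward fn plus example inputs and outputs, create_fw_bw_graph returns a traced
# forward graph and a joint graph that also computes input gradients.
def _example_create_fw_bw_graph():
    def fw(x):
        return (x.sin(),)

    xs = (torch.randn(3, requires_grad=True),)
    fw_graph, joint_graph = create_fw_bw_graph(fw, False, xs, fw(*xs))
    assert isinstance(fw_graph, torch.fx.GraphModule)
    assert isinstance(joint_graph, torch.fx.GraphModule)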


def _unstack_pytree(xs):
    flat_xs, inspec = pytree.tree_flatten(xs)
    if not all(isinstance(xs, torch.Tensor) for xs in flat_xs):
        raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}")

    if not all(xs.shape[0] == flat_xs[0].shape[0] for xs in flat_xs):
        raise RuntimeError(
            f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}"
        )

    a = zip(*flat_xs)

    pytrees = []
    for leaves in a:
        pytrees.append(pytree.tree_unflatten(leaves, inspec))
    return pytrees
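

# Illustrative sketch (hypothetical `_example_*` helper): _unstack_pytree splits
# a pytree of stacked tensors along dim 0 into a list of per-slice pytrees.
def _example_unstack_pytree():
    xs = {"a": torch.arange(6).reshape(3, 2), "b": torch.ones(3)}
    slices = _unstack_pytree(xs)
    assert len(slices) == 3
    assert torch.equal(slices[0]["a"], torch.tensor([0, 1]))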


def _stack_pytree(pytrees):
    flat_out = []
    out_spec = None
    for pt in pytrees:
        flat_pt, out_spec = pytree.tree_flatten(pt)
        flat_out.append(flat_pt)
    assert out_spec is not None
    b = zip(*flat_out)
    stacked_out = []
    for leaves in b:
        if all(isinstance(leaf, torch.Tensor) for leaf in leaves):
            stacked_out.append(torch.stack(leaves))
        elif all(leaf is None for leaf in leaves):
            # The backward graph can return None outputs when the forward inputs
            # don't require grad. When we eagerly execute the backward graph, we need
            # to call _stack_pytree on its output, therefore we need to handle None outputs.
            stacked_out.append(None)  # type: ignore[arg-type]
        else:
            raise RuntimeError(f"Cannot stack {leaves}.")
    return pytree.tree_unflatten(stacked_out, out_spec)
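

# Illustrative sketch (hypothetical `_example_*` helper): _stack_pytree is the
# inverse of _unstack_pytree, restacking per-slice pytrees along a new dim 0.
def _example_stack_pytree():
    slices = [{"a": torch.zeros(2)}, {"a": torch.ones(2)}]
    stacked = _stack_pytree(slices)
    assert stacked["a"].shape == (2, 2)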