# File: /aosp_15_r20/external/pytorch/torch/_functorch/_aot_autograd/utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
"""
Contains various utils for AOTAutograd, including those for handling collections.
"""

import dataclasses
import operator
import warnings
from contextlib import nullcontext
from functools import wraps
from typing import Any, Callable, List, Optional, Tuple, Union

import torch
import torch.utils._pytree as pytree
from torch._library.fake_class_registry import FakeScriptObject
from torch._logging import getArtifactLogger
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.proxy_tensor import py_sym_types


KNOWN_TYPES = [
    torch.Tensor,
    BackwardState,
    int,
    str,
    float,
    bool,
    type(None),
    *py_sym_types,
    FakeScriptObject,
    torch.ScriptObject,
]

original_zip = zip

aot_graphs_effects_log = getArtifactLogger(__name__, "aot_graphs_effects")


def strict_zip(*iterables, strict=True, **kwargs):
    if not strict:
        return original_zip(*iterables, **kwargs)

    length = len(iterables[0])
    for iterable in iterables[1:]:
        if len(iterable) != length:
            raise ValueError(
                "The iterables have different lengths and strict mode is enabled."
            )

    return original_zip(*iterables, **kwargs)

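# Illustrative usage sketch (not part of this module's API): strict_zip behaves
# like Python 3.10+'s zip(..., strict=True), but checks lengths eagerly, so it
# only accepts sized inputs.
#
#     list(strict_zip([1, 2], ["a", "b"]))        # [(1, 'a'), (2, 'b')]
#     strict_zip([1, 2], ["a"])                   # raises ValueError
#     strict_zip([1, 2], ["a"], strict=False)     # falls back to plain zip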

def _get_symint_hints(exprs):
    """
    Get the hints of a list/tuple of int/SymInt.
    """
    if isinstance(exprs, (list, tuple)):
        return type(exprs)(_get_symint_hints(e) for e in exprs)
    elif isinstance(exprs, torch.SymInt):
        return exprs.node.shape_env.size_hint(exprs.node.expr)
    else:
        return exprs


def partial_flatten_asdict(obj: Any) -> Any:
    if dataclasses.is_dataclass(obj):
        return {
            field.name: getattr(obj, field.name) for field in dataclasses.fields(obj)
        }
    elif isinstance(obj, (list, tuple)):
        return obj.__class__([partial_flatten_asdict(item) for item in obj])
    elif isinstance(obj, dict):
        return {k: partial_flatten_asdict(v) for k, v in obj.items()}
    else:
        return obj


def normalize_as_list(x):
    if isinstance(x, tuple):
        return list(x)
    elif isinstance(x, list):
        return x
    return [x]


def _get_autocast_states():
    return [
        torch.is_autocast_enabled("cuda"),
        torch.is_autocast_enabled("cpu"),
        torch.get_autocast_dtype("cuda"),
        torch.get_autocast_dtype("cpu"),
        torch.is_autocast_cache_enabled(),
    ]


def make_boxed_func(f):
    def g(args):
        return f(*args)

    g._boxed_call = True  # type: ignore[attr-defined]
    return g


def make_boxed_compiler(compiler):
    @wraps(compiler)
    def f(fx_g, inps):
        out_f = compiler(fx_g, inps)
        fx_g = make_boxed_func(out_f)
        return fx_g

    return f


def call_func_at_runtime_with_args(
    f, args: Union[Tuple[Any], List[Any]], steal_args=False, disable_amp=False
):
    if not steal_args:
        args = list(args)
    assert isinstance(args, list)

    context = torch._C._DisableAutocast if disable_amp else nullcontext
    with context():
        if hasattr(f, "_boxed_call"):
            out = normalize_as_list(f(args))
        else:
            # TODO: Please remove soon
            # https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670
            warnings.warn(
                "Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. "
                "Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. "
                "See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale."
            )
            out = normalize_as_list(f(*args))
    return out

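# Illustrative sketch (hypothetical backend, not part of this module): the
# "boxed" calling convention means the compiled callable takes a single list of
# arguments (which it may clear to free inputs early) instead of *args.
#
#     def my_backend(fx_g, example_inputs):
#         # fx_g.forward takes *args, so wrap it; without the wrapper,
#         # call_func_at_runtime_with_args would emit the warning above.
#         return make_boxed_func(fx_g.forward)
#
#     # Calling a boxed function at runtime:
#     #     out = call_func_at_runtime_with_args(compiled_fn, args, steal_args=True)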

# Inspired by autodidax (thanks!)
class PytreeThunk:
    spec: Optional[pytree.TreeSpec] = None
    # These are some minor micro-optimizations that save about 3-4 us of overhead.
    is_simple: Optional[
        bool
    ] = None  # if the output spec is a tuple/list, we won't bother unflattening it.
    is_really_simple: Optional[bool] = None  # if the output spec is a LeafSpec

    def set(self, spec: pytree.TreeSpec) -> None:
        assert self.spec is None or self.spec == spec
        assert spec is not None
        self.spec: pytree.TreeSpec = spec
        if self.spec.type in {tuple, list} and all(
            child.is_leaf() for child in spec.children_specs
        ):
            self.is_simple = True
        if self.spec.is_leaf():
            self.is_really_simple = True

    def unflatten(self, x: List[Any]) -> Any:
        if self.is_really_simple:
            return x[0]
        if self.is_simple:
            return x
        assert self.spec is not None
        return pytree.tree_unflatten(x, self.spec)

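# Illustrative sketch (not part of this module's API): PytreeThunk records the
# output tree spec the first time it is set and replays it to rebuild the
# original output structure from a flat list of leaves.
#
#     thunk = PytreeThunk()
#     flat, spec = pytree.tree_flatten({"a": 1, "b": (2, 3)})
#     thunk.set(spec)
#     assert thunk.unflatten(flat) == {"a": 1, "b": (2, 3)}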

# Creates a function that returns flattened inputs and outputs.
# Also returns the output tree spec, which is needed to recover the "unflattened"
# output tree structure later.
def create_tree_flattened_fn(fn, args, kwargs=None) -> Tuple[Callable, PytreeThunk]:
    if kwargs is None:
        kwargs = {}
    # Save the args_spec for flat_tensor_args to unflatten while tracing
    _, tensor_args_spec = pytree.tree_flatten((args, kwargs))
    out_spec = PytreeThunk()

    def flat_fn(*flat_args):
        # The inputs are flattened tensor args. Prepare the args in the
        # order that the original function expects. Add static args as well.
        # They will appear as tensor constants in the traced graph.
        nonlocal out_spec
        args, kwargs = pytree.tree_unflatten(flat_args, tensor_args_spec)
        tree_out = fn(*args, **kwargs)
        flat_out, spec = pytree.tree_flatten(tree_out)
        for i in flat_out:
            is_known_type = False
            for j in KNOWN_TYPES:
                if isinstance(i, j):
                    is_known_type = True
                    break
            if not is_known_type:
                raise RuntimeError(
                    f"Found {type(i)} in output, which is not a known type. "
                    "If this type holds tensors, you need to register a pytree for it. "
                    "See https://github.com/pytorch/functorch/issues/475 for a brief "
                    "explanation of why. If you don't need to register a pytree, please "
                    "leave a comment explaining your use case and we'll make this more "
                    "ergonomic to deal with."
                )
        out_spec.set(spec)
        return flat_out

    # Can't use functools.wraps here because the wrapper has a different
    # calling convention.
    if hasattr(fn, "_orig_mod"):
        flat_fn._orig_mod = fn._orig_mod  # type: ignore[attr-defined]

    return flat_fn, out_spec

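# Illustrative sketch (not part of this module's API): create_tree_flattened_fn
# turns a function over arbitrary pytrees into one over flat argument lists,
# which is the form AOTAutograd traces.
#
#     def fn(x, scale):
#         return {"out": x["a"] * scale}
#
#     args = ({"a": torch.ones(2)}, 3.0)
#     flat_fn, out_spec = create_tree_flattened_fn(fn, args)
#     flat_args, _ = pytree.tree_flatten((args, {}))
#     flat_out = flat_fn(*flat_args)          # [tensor([3., 3.])]
#     out = out_spec.unflatten(flat_out)      # {"out": tensor([3., 3.])}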

# This function takes in a tensor t, and returns one of t, t.view(), or t.clone().
# When tracing the joint forward + backward, for any inputs in the graph that are mutated,
# we need to clone them first (and similarly for metadata-only mutations, we need to view them first).
# The idea is that when we trace the backward, we need to pass in the *original* primals
# to autograd.grad(), before they were mutated.
# Note: when we have synthetic base inputs, we need to clone them *before* creating views off of them.
# This means that "idx" here represents the index of the (potentially) synthetic base.
# What we need to do is:
# (1) map the current (post-synthetic-base calling convention) input argument index
#     to its index under the pre-synthetic-base calling convention.
# (2) There could be multiple such indices, if this index corresponds to a synthetic base
#     that has multiple input aliases.
# (3) If any of those corresponding inputs get metadata mutations, then we clone the base.
# (An illustrative sketch follows this function.)
def maybe_to_fresh_input(idx, t, meta):
    if not isinstance(t, torch.Tensor):
        return t
    if idx in meta.mutated_inp_runtime_indices:
        # We only need to bother cloning mutated inputs that participate in autograd.
        mutated_inp_idx = meta.mutated_inp_runtime_indices.index(idx)
        if meta.input_info[idx].requires_grad and meta.input_info[idx].mutates_data:
            # Make sure the primal we pass to autograd.grad()
            # sees the tensor before the mutation
            return t.clone()
        if meta.input_info[idx] and meta.input_info[idx].mutates_metadata:
            # Make sure the primal we pass to autograd.grad()
            # sees the tensor before the metadata mutation
            return t.view(t.shape)
    return t

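# Illustrative sketch (not part of this module's API): for an input that is
# mutated in the traced function, maybe_to_fresh_input hands autograd a fresh
# tensor so that the original (pre-mutation) value is what the joint graph
# differentiates. `meta` below stands in for the real ViewAndMutationMeta
# produced by AOTAutograd's analysis pass.
#
#     x = torch.randn(3, requires_grad=True)
#     # If meta.input_info[0].mutates_data and .requires_grad:
#     #     maybe_to_fresh_input(0, x, meta) returns x.clone()
#     # If only metadata is mutated (e.g. an in-place transpose_()):
#     #     maybe_to_fresh_input(0, x, meta) returns x.view(x.shape)
#     # Otherwise the input is returned unchanged.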

def is_with_effects(node):
    return (
        node.op == "call_function"
        and node.target == torch.ops.higher_order.with_effects
    )


def is_with_effects_op(node, op):
    return is_with_effects(node) and node.args[1] == op


def unlift_tokens(fw_module, fw_metadata, aot_config, bw_module=None):
    # Remove the tokens from the inputs/outputs of the graph, since inductor
    # does not want these extra inputs/outputs. Replace them with
    # _make_token() to create a token and _sink_tokens() to collect the
    # tokens.  See Note [Side-Effectful Tokens in AOTAutograd].
    # (An illustrative sketch of the token pattern appears after this function.)
    # Logic:
    # 1. Inputs are identified as input tokens:
    #    - if used as the first argument of with_effects
    #
    # 2. Outputs are identified as output tokens:
    #    - if produced by getitem(with_effects, 0)
    #
    # 3. Check invariants on the number of input/output tokens:
    # forward:
    #   expected_num_erased_inputs == len(fw_metadata.tokens)
    #   expected_num_erased_outputs == len(fw_metadata.tokens)
    # backward:
    #   expected_num_erased_inputs == fw_metadata.num_backward_tokens
    #   expected_num_erased_outputs == fw_metadata.num_backward_tokens
    num_forward_tokens = len(fw_metadata.tokens)
    num_backward_tokens = fw_metadata.num_backward_tokens

    def rewrite_with_effects_input_token(module, node):
        with module.graph.inserting_before(node):
            new_token_node = module.graph.call_function(
                torch.ops.prims._make_token.default, ()
            )
            new_token_node.meta["val"] = torch.tensor([])
            new_token_node.meta["tensor_meta"] = torch.tensor([])

            args = list(node.args)
            args[0] = new_token_node
            node.args = tuple(args)

    def rewrite_output(module, node, output_token_nodes, other_output_args):
        for output_token_node in output_token_nodes:
            assert (
                output_token_node.op == "call_function"
                and output_token_node.target == operator.getitem
                and output_token_node.args[1] == 0
            )
        with module.graph.inserting_before(node):
            module.graph.call_function(
                torch.ops.prims._sink_tokens.default,
                (output_token_nodes,),
            )
            node.args = (other_output_args,)

    def do(module, subgraph, expected_num_erased):
        num_erased_inputs = 0
        num_erased_outs = 0
        input_nodes = []
        input_token_nodes = set()
        with_effect_nodes = []
        output_token_nodes = []
        other_output_nodes = []
        for i, node in enumerate(module.graph.nodes):
            if node.op == "placeholder":
                input_nodes.append(node)
            elif is_with_effects(node):
                with_effect_nodes.append(node)
                if node.args[0] in input_nodes:
                    input_token_nodes.add(node.args[0])
                    rewrite_with_effects_input_token(module, node)
            elif node.op == "output":
                outs = node.args[0]
                for out in outs:
                    if (
                        isinstance(out, torch.fx.node.Node)
                        and out.op == "call_function"
                        and out.target == operator.getitem
                        and out.args[1] == 0
                        and out.args[0] in with_effect_nodes
                    ):
                        output_token_nodes.append(out)
                    else:
                        other_output_nodes.append(out)

                rewrite_output(module, node, output_token_nodes, other_output_nodes)
                num_erased_outs = len(output_token_nodes)

        for input_token_node in input_token_nodes:
            module.graph.erase_node(input_token_node)

        num_erased_inputs = len(input_token_nodes)

        assert (
            num_erased_inputs == expected_num_erased
        ), f"{subgraph} num_erased_inputs:{num_erased_inputs} {input_token_nodes}!=expected {expected_num_erased}"
        assert (
            num_erased_outs == expected_num_erased
        ), f"{subgraph} num_erased_outs:{num_erased_outs} {output_token_nodes}!=expected {expected_num_erased}"

        module.recompile()

    if num_forward_tokens > 0:
        if aot_config.enable_log:
            from torch._dynamo.utils import lazy_format_graph_code

            aot_graphs_effects_log.debug(
                "%s",
                lazy_format_graph_code(
                    "Forward graph before unlifting tokens",
                    fw_module,
                    aot_config.aot_id,
                    include_stride=True,
                    include_device=True,
                    colored=True,
                ),
            )
        do(
            fw_module,
            "forward",
            num_forward_tokens,
        )

    if bw_module is not None and num_backward_tokens > 0:
        if aot_config.enable_log:
            from torch._dynamo.utils import lazy_format_graph_code

            aot_graphs_effects_log.debug(
                "%s",
                lazy_format_graph_code(
                    "Backward graph before unlifting tokens",
                    bw_module,
                    aot_config.aot_id,
                    include_stride=True,
                    include_device=True,
                    colored=True,
                ),
            )
        do(bw_module, "backward", num_backward_tokens)

    # This is sad, but we need to update the metadata to get rid of
    # the tokens.
    fw_metadata.tokens = {}
    fw_metadata.num_backward_tokens = 0

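# Illustrative sketch (hypothetical graph, not produced by this module): what
# unlift_tokens does to a forward graph with one effectful op. `some_op` is a
# stand-in for an arbitrary side-effectful operator.
#
#     # Before: the token is threaded through as an extra input and output.
#     def forward(self, token, x):
#         with_effects = torch.ops.higher_order.with_effects(token, some_op, x)
#         new_token = with_effects[0]
#         out = with_effects[1]
#         return (new_token, out)
#
#     # After: the token input/output is erased; the token is created inside
#     # the graph with _make_token() and consumed with _sink_tokens().
#     def forward(self, x):
#         token = torch.ops.prims._make_token.default()
#         with_effects = torch.ops.higher_order.with_effects(token, some_op, x)
#         new_token = with_effects[0]
#         out = with_effects[1]
#         torch.ops.prims._sink_tokens.default([new_token])
#         return (out,)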

def root_module_when_exporting_non_strict(flat_fn):
    # When exporting in non-strict mode, we wrap the root module in a specific pattern.
    # See `_aot_export_non_strict` in torch/export/_trace.py.
    # We look for that wrapping pattern here.
    if hasattr(flat_fn, "_orig_mod") and hasattr(flat_fn._orig_mod, "_export_root"):
        return flat_fn._orig_mod._export_root
    else:
        return None


def copy_fwd_metadata_to_bw_nodes(fx_g):
    """
    Input: `fx_g`, the joint fwd+bwd FX graph created by aot_autograd.

    This function walks the graph and copies metadata from forward nodes
    to backward nodes, using the `seq_nr` field as a one-to-many mapping
    from forward node to backward nodes. This metadata is useful for
    performance profiling and debugging.
    """

    def _is_forward_node_with_seq_nr(node):
        # For now, assume that if `nn_module_stack` metadata is populated, the
        # node is from the forward. Ignore nodes without `seq_nr`.
        # TODO(future): there is likely a less brittle way to do this by walking
        # the descendants of graph inputs corresponding to fwd inputs; it didn't
        # seem obvious at first glance how to partition graph inputs into
        # fwd vs bwd without relying on string names.
        return "nn_module_stack" in node.meta and "seq_nr" in node.meta

    def _is_backward_node_with_seq_nr(node):
        # For now, assume that if `nn_module_stack` metadata is not populated,
        # the node is from the backward. Ignore nodes without `seq_nr`.
        # TODO(future): there is likely a less brittle way to do this, same
        # as with the forward.
        return ("nn_module_stack" not in node.meta) and "seq_nr" in node.meta

    fwd_seq_nr_to_node = {}
    for node in fx_g.graph.nodes:
        if not _is_forward_node_with_seq_nr(node):
            continue
        seq_nr = node.meta["seq_nr"]
        if seq_nr in fwd_seq_nr_to_node:
            # If we already saw an op with the current `seq_nr`, that means
            # that the current op did not create an autograd node, and there
            # is no corresponding backward node, so we skip it.
            continue
        fwd_seq_nr_to_node[seq_nr] = node

    for node in fx_g.graph.nodes:
        if not _is_backward_node_with_seq_nr(node):
            continue
        # fwd_node should always exist, but handle non-existence just in case
        fwd_node = fwd_seq_nr_to_node.get(node.meta["seq_nr"])
        if fwd_node is not None:
            node.meta["fwd_nn_module_stack"] = fwd_node.meta["nn_module_stack"]
            node.meta["fwd_source_fn_stack"] = fwd_node.meta.get("source_fn_stack")
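

# Illustrative sketch (not part of this module's API): after this pass, a
# backward node can be attributed to the forward node that created its autograd
# edge. `joint_gm` is a hypothetical joint fwd+bwd GraphModule.
#
#     copy_fwd_metadata_to_bw_nodes(joint_gm)
#     for n in joint_gm.graph.nodes:
#         if "fwd_nn_module_stack" in n.meta:
#             # `n` is a backward node; its metadata now points back at the
#             # forward module hierarchy that produced it.
#             print(n.name, n.meta["fwd_nn_module_stack"])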
447