xref: /aosp_15_r20/external/pytorch/torch/_higher_order_ops/map.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._dispatch.python import suspend_functionalization
from torch._functorch.aot_autograd import AOTConfig, create_joint
from torch._higher_order_ops.utils import (
    _has_potential_branch_input_alias,
    _has_potential_branch_input_mutation,
    _maybe_run_with_interpreter,
    reenter_make_fx,
    UnsupportedAliasMutationException,
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._subclasses.functional_tensor import disable_functional_mode
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    make_fx,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)

from .utils import (
    _from_fun,
    _stack_pytree,
    _unstack_pytree,
    clone_outputs_aliasing_inputs,
    prepare_fw_with_masks,
)


# TODO: We add this to prevent dynamo from tracing into map_wrapper;
# remove the wrapper call when it's ready.
class MapWrapper(HigherOrderOperator):
    def __init__(self):
        super().__init__("map")

    def __call__(self, xs, *args):
        return map_wrapper(xs, *args)


class MapImpl(HigherOrderOperator):
    def __init__(self):
        super().__init__("map_impl")

    def __call__(self, *args, **kwargs):
        return super().__call__(*args, **kwargs)


map = MapWrapper()

map_impl = MapImpl()
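
# A minimal usage sketch of the eager `map` wrapper (illustrative only; `body`, `xs`,
# and `y` are hypothetical names, not part of this module):
#
#   def body(x, y):
#       return x.sin() + y
#
#   xs = torch.randn(4, 3)
#   y = torch.randn(3)
#   out = map(body, xs, y)  # applies `body` to each xs[i]; out has shape (4, 3)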

dummy_aot_config = AOTConfig(
    fw_compiler=None,  # type: ignore[arg-type]
    bw_compiler=None,  # type: ignore[arg-type]
    partition_fn=None,  # type: ignore[arg-type]
    decompositions={},
    num_params_buffers=0,
    aot_id=0,
    keep_inference_input_mutations=False,
)


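# create_fw_bw_graph traces two graphs from `f` on a single example slice of the mapped
# inputs: `fw_graph`, which computes the forward outputs for one slice, and
# `joint_graph`, which takes (one slice of xs, the matching output grads, pos_args) and
# returns the gradients of the forward inputs.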
def create_fw_bw_graph(f, num_mapped_args, *args):
    mapped_xs = args[:num_mapped_args]
    pos_args = args[num_mapped_args:]

    # See Note [HOP create fw_bw graph] in create_fw_bw_graph in utils.py

    with suspend_functionalization(), disable_functional_mode():
        with disable_proxy_modes_tracing():
            unwrapped_mapped_xs = pytree.tree_map(_from_fun, mapped_xs)
            example_xs = _unstack_pytree(unwrapped_mapped_xs)[0]

            example_pos_args = [
                _from_fun(arg) if isinstance(arg, torch.Tensor) else arg
                for arg in pos_args
            ]
            example_flat_out = pytree.tree_map(
                _from_fun, f(*example_xs, *example_pos_args)
            )
            if any(
                not isinstance(out, torch.Tensor)
                for out in example_flat_out
                if out is not None
            ):
                raise RuntimeError(
                    "Expected outputs of map to contain only tensors or None. "
                    f"Got types {[type(out) for out in example_flat_out]}."
                )
            example_grad = [_from_fun(out) for out in example_flat_out]

            fw_graph = make_fx(f)(*example_xs, *example_pos_args)

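        # joint_f computes, for a single slice, the gradients of f's inputs given the
        # slice's mapped inputs, the corresponding output grads, and the positional args.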
        def joint_f(*example_args):
            joint_mapped_args = example_args[:joint_num_mapped]
            args = example_args[joint_num_mapped:]

            mapped_input = joint_mapped_args[:num_mapped_args]
            mapped_grads = joint_mapped_args[num_mapped_args:]

            joint = create_joint(prepare_fw_with_masks(f), aot_config=dummy_aot_config)
            _, grads = joint(
                list(mapped_input) + list(args),
                [
                    grad
                    for grad in mapped_grads
                    if grad is not None and grad.requires_grad
                ],
            )

            # To keep map functional for the backward graph,
            # we clone outputs that alias inputs.
            maybe_clone = clone_outputs_aliasing_inputs(example_args)

            return pytree.tree_map(maybe_clone, grads)

        joint_num_mapped = len(example_grad) + len(example_xs)
        joint_graph = make_fx(joint_f)(*example_xs, *example_grad, *example_pos_args)
        return fw_graph, joint_graph


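# map_wrapper flattens the pytree of mapped inputs, validates that every leaf is a
# tensor with a consistent, non-zero leading dimension, and unflattens the result of
# map_impl back into the body function's output pytree structure.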
def map_wrapper(f, xs, *args):
    flat_xs, xs_spec = pytree.tree_flatten(xs)
    if not all(isinstance(t, torch.Tensor) for t in flat_xs):
        raise RuntimeError(f"Mapped xs can only consist of tensors. Got xs {flat_xs}.")

    num_mapped_args = len(flat_xs)
    shapes = [xs.shape for xs in flat_xs]
    leading_dim_size = shapes[0][0]
    if leading_dim_size == 0:
        raise RuntimeError("Leading dimensions of mapped xs cannot be 0.")

    if any(cur_shape[0] != leading_dim_size for cur_shape in shapes):
        raise RuntimeError(
            f"Leading dimensions of mapped xs must be consistent. Got shapes {shapes}."
        )

    out_spec = None

    def flat_fn(*flat_args):
        xs = pytree.tree_unflatten(list(flat_args[:num_mapped_args]), xs_spec)
        unflattened_out = f(xs, *flat_args[num_mapped_args:])
        flat_out, tmp_out_spec = pytree.tree_flatten(unflattened_out)

        nonlocal out_spec
        out_spec = tmp_out_spec
        return flat_out

    return pytree.tree_unflatten(
        map_impl(flat_fn, flat_xs, args), out_spec  # type: ignore[arg-type]
    )


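# MapAutogradOp hooks map into autograd: forward runs fw_graph through map_impl below
# the Autograd dispatch key; backward maps joint_graph over the saved mapped inputs
# together with the incoming output gradients.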
class MapAutogradOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, fw_graph, joint_graph, num_mapped_args, *flat_args):
        ctx.save_for_backward(*flat_args)
        ctx._joint_graph = joint_graph
        ctx._num_mapped_args = num_mapped_args
        with torch._C._AutoDispatchBelowAutograd():
            return (
                *map_impl(
                    fw_graph, flat_args[:num_mapped_args], flat_args[num_mapped_args:]
                ),
            )

    @staticmethod
    def backward(ctx, *flat_grads):
        fw_args = ctx.saved_tensors
        fw_mapped_args = fw_args[: ctx._num_mapped_args]
        pos_args = fw_args[ctx._num_mapped_args :]

        grads = map_impl(
            ctx._joint_graph,
            fw_mapped_args + flat_grads,
            pos_args,
        )
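        # The three leading Nones correspond to forward's non-tensor inputs
        # (fw_graph, joint_graph, num_mapped_args), which receive no gradients.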
        return None, None, None, *grads


def trace_map(proxy_mode, func_overload, f, xs, pos_args):
    leading_dim_size = xs[0].shape[0]

    example_input = _unstack_pytree(xs)[0]
    body_graph = f

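    # Re-trace the body on a single example slice under the current tracer and register
    # it as a submodule so the emitted map_impl node can reference it by name.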
    body_graph = reenter_make_fx(body_graph)(*example_input, *pos_args)

    next_name = proxy_mode.tracer.get_fresh_qualname("body_graph_")

    proxy_mode.tracer.root.register_module(next_name, body_graph)

    with disable_proxy_modes_tracing():
        example_outs = body_graph(*example_input, *pos_args)

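        # Each example output corresponds to a single mapped slice; expand it along a
        # new leading dimension so the traced map node carries the stacked output shape.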
        def expand_tensor(t):
            if isinstance(t, torch.Tensor):
                return t.expand(leading_dim_size, *t.shape)
            return t

        expanded_outs = pytree.tree_map(expand_tensor, example_outs)

    node_args = (body_graph, list(xs), list(pos_args))
    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", func_overload, proxy_args, {}, name="map_impl"
    )
    return track_tensor_tree(
        expanded_outs, out_proxy, constant=None, tracer=proxy_mode.tracer
    )


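# Dense (eager) semantics: apply f to each slice of the mapped tensors along dim 0 and
# stack the per-slice outputs. Roughly (a sketch, assuming a single mapped tensor `x`):
#
#   torch.stack([f(x[i], *pos_args) for i in range(x.shape[0])])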
@map_impl.py_impl(DispatchKey.CompositeExplicitAutograd)
def map_dense(f, xs, pos_args):
    pytrees = []
    for inp in _unstack_pytree(xs):
        pytrees.append(f(*inp, *pos_args))
    return _stack_pytree(pytrees)


@map_impl.py_impl(DispatchKey.Autograd)
def map_autograd(f, xs, pos_args):
    num_mapped_args = len(xs)
    fw_graph, bw_graph = create_fw_bw_graph(f, num_mapped_args, *xs, *pos_args)
    flat_out = MapAutogradOp.apply(fw_graph, bw_graph, num_mapped_args, *xs, *pos_args)
    return flat_out


@map_impl.py_impl(ProxyTorchDispatchMode)
def map_proxy_torch_dispatch_mode(mode, f, xs, args):
    return trace_map(mode, map_impl, f, xs, args)


@map_impl.py_impl(FakeTensorMode)
def map_fake_tensor_mode(mode, f, xs, args):
    with mode:
        return map_dense(f, xs, args)


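# Functionalization: unwrap the functional tensors, reject body functions that mutate or
# alias their inputs, then run map_impl on the unwrapped values and re-wrap the outputs.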
@map_impl.py_functionalize_impl
def map_functionalize(ctx, f, xs, pos_args):
    unwrapped_xs = ctx.unwrap_tensors(xs)
    unwrapped_args = ctx.unwrap_tensors(pos_args)
    wrapped_fn = ctx.functionalize(_maybe_run_with_interpreter(f))

    with ctx.redispatch_to_next():
        with disable_proxy_modes_tracing():
            example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
        if _has_potential_branch_input_mutation(
            f, example_inputs, pre_dispatch=pre_dispatch
        ):
            raise UnsupportedAliasMutationException("torch.map is mutating the input!")

        if _has_potential_branch_input_alias(
            f, example_inputs, pre_dispatch=pre_dispatch
        ):
            raise UnsupportedAliasMutationException("torch.map is aliasing the input!")

        map_return = map_impl(wrapped_fn, unwrapped_xs, unwrapped_args)
        return ctx.wrap_tensors(map_return)