xref: /aosp_15_r20/external/pytorch/torch/_higher_order_ops/associative_scan.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import functools
3import itertools
4from typing import Callable, List
5
6import torch
7import torch._prims_common as utils
8import torch._subclasses.functional_tensor
9import torch.utils._pytree as pytree
10from torch._C import DispatchKey
11from torch._higher_order_ops.utils import (
12    _maybe_run_with_interpreter,
13    _set_compilation_env,
14    autograd_not_implemented,
15    reenter_make_fx,
16    unique_graph_id,
17)
18from torch._inductor.utils import is_pointwise_use
19from torch._ops import HigherOrderOperator
20from torch._subclasses.fake_tensor import FakeTensorMode
21from torch.fx.experimental.proxy_tensor import (
22    disable_proxy_modes_tracing,
23    ProxyTorchDispatchMode,
24    track_tensor_tree,
25)
26
27
# Shorthand for ``torch._ops.ops.aten``; code in this file reaches ``aten.slice``
# and ``aten.clone`` through this alias.
aten = torch._ops.ops.aten
29
30
31def wrap_combine_fn_flat(*args, combine_fn, spec, num_leaves):
32    assert len(args) == 2 * num_leaves
33    lhs = pytree.tree_unflatten(args[:num_leaves], spec)
34    rhs = pytree.tree_unflatten(args[num_leaves:], spec)
35    combined = combine_fn(lhs, rhs)
36    combined_leaves = pytree.tree_leaves(combined)
37    assert num_leaves == len(combined_leaves)
38    return combined_leaves
39
40
def _interleave(a, b, dim):
    """Interleave ``a`` and ``b`` along ``dim`` as ``a[0], b[0], a[1], b[1], ...``.

    ``b`` may be exactly one element shorter than ``a`` along ``dim``. In that
    case ``b`` is padded by one element at the end of ``dim`` so both tensors
    can be stacked, and the surplus trailing element is sliced off afterwards.

    NOTE(review): ``dim + 1`` is used as the stacking axis, so ``dim`` appears
    to be expected as a non-negative index -- confirm against callers.
    """
    # https://stackoverflow.com/questions/60869537/how-can-i-interleave-5-pytorch-tensors
    b_is_shorter = a.shape[dim] == b.shape[dim] + 1
    if b_is_shorter:
        # ``torch.nn.functional.pad`` expects (before, after) pairs listed from
        # the last dimension backwards; set only the "after" slot of ``dim``.
        pad_spec = [0] * (b.ndim * 2)
        pad_spec[(b.ndim - dim - 1) * 2 + 1] = 1
        b = torch.nn.functional.pad(b, pad_spec)

    # Stacking on a new axis right after ``dim`` and flattening the two axes
    # back together alternates the elements of ``a`` and ``b``.
    merged = torch.flatten(
        torch.stack([a, b], dim=dim + 1), start_dim=dim, end_dim=dim + 1
    )
    if b_is_shorter:
        # TODO: find torch alternative for slice_along dim for torch.jit.script to work
        valid_len = b.shape[dim] + a.shape[dim] - 1
        merged = aten.slice(merged, dim, 0, valid_len)
    return merged
57
58
def safe_map(f, *args):
    """Apply ``f`` element-wise across several equally long iterables.

    Unlike the builtin ``map``, which silently stops at the shortest input,
    this raises when the inputs differ in length.

    Args:
        f: Callable invoked as ``f(x0, x1, ...)`` with one element per iterable.
        *args: One or more iterables; each is materialized into a list.

    Returns:
        A list with the results of ``f`` in input order.

    Raises:
        ValueError: If the iterables do not all have the same length. The
            message lists the length of every iterable.
    """
    args = [list(arg) for arg in args]
    expected_len = len(args[0])
    for arg in args[1:]:
        if len(arg) != expected_len:
            # The original message was missing the ``f`` prefix, so the
            # placeholder was printed verbatim instead of the actual lengths.
            raise ValueError(f"length mismatch: {list(map(len, args))}")

    return [f(*group) for group in zip(*args)]
70
71
class AssociativeScanOp(HigherOrderOperator):
    """``HigherOrderOperator`` subclass for the ``associative_scan`` operator.

    The module instantiates this class once as ``associative_scan_op``.
    """

    def __init__(self):
        # Hand the operator name to the ``HigherOrderOperator`` constructor.
        super().__init__("associative_scan")

    def __call__(self, combine_fn, input, dim):
        # Plain pass-through to the base class; the override spells out the
        # three positional arguments the operator is called with.
        return super().__call__(combine_fn, input, dim)
78
79
# Module-level operator instance. The ``py_impl`` / ``py_functionalize_impl``
# registrations in this file attach their implementations to this object.
associative_scan_op = AssociativeScanOp()
81
82
def associative_scan(
    combine_fn: Callable[[pytree.PyTree, pytree.PyTree], pytree.PyTree],
    input: pytree.PyTree,
    dim: int,
    reverse: bool = False,
    combine_mode: str = "pointwise",
) -> torch.Tensor:
    r"""
    Performs an inclusive scan with an associative pointwise combine function.

    .. warning::
        `torch.associative_scan` is a prototype feature in PyTorch. It currently
        does not support autograd and you may run into miscompiles.
        Read more about feature classification at:
        https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

    This operator requires runtime code generation and so requires support for
    ``torch.compile``. Further, only CUDA device codegen is supported at the moment.

    Args:
        combine_fn (Callable): A binary callable with type ``(Tensor, Tensor) -> Tensor``,
            or if input is a pytree ``(pytree, pytree) -> pytree``.
            This function must be pure, pointwise, and satisfy the associative property.
        input (torch.Tensor): The input tensor, or nested pytree of tensors.
            All inputs are expected to have the same shape.
        dim (int): the dimension to scan over
        reverse (bool): A boolean stating if the scan should be reversed with respect to the dimension.
        combine_mode (str): A string indicating whether the ``combine_fn`` is ``pointwise`` or ``generic``.
            If ``combine_mode=pointwise``, ``combine_fn`` must be pure, may only contain pointwise operations
            and ``input`` must be CUDA tensors.
            In all other cases ``combine_mode=generic`` should be used.
            Note: ``combine_mode=pointwise`` is more efficient than ``combine_mode=generic``.

    Returns:
        The scanned result, rebuilt with the tree structure of ``input``
        (a single tensor when ``input`` is a single tensor).

    Raises:
        AssertionError: If ``combine_fn`` is not callable, ``dim`` is not an int,
            ``combine_mode`` is unknown, ``input`` has no leaves, a leaf is not a
            tensor, the leaves differ in shape, or the output of ``combine_fn``
            does not match the input in number of leaves or in shape.
        ValueError: If ``combine_mode='pointwise'`` and not every leaf is on CUDA.

    Example::

        def add(x: torch.Tensor, y: torch.Tensor):
            return x + y

        cumsum = associative_scan(add, x, dim)

    """
    # These messages must be f-strings; otherwise the placeholders are printed
    # verbatim instead of the offending value.
    assert callable(combine_fn), f"combine_fn must be a callable, but got {combine_fn}"
    assert isinstance(dim, int), f"dim must be an int, but got {type(dim)}"
    assert combine_mode in [
        "pointwise",
        "generic",
    ], f"combine_mode must be 'pointwise' or 'generic', but got {combine_mode}"

    # Outside of compilation, re-enter this function through ``torch.compile``;
    # everything below this point therefore runs while compiling.
    if not torch._dynamo.is_compiling():
        with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
            return torch.compile(associative_scan, fullgraph=True)(
                combine_fn, input, dim, reverse=reverse, combine_mode=combine_mode
            )

    leaves, spec = pytree.tree_flatten(input)

    # Check that every leaf is a tensor *before* reading ``.device`` below, so
    # a non-tensor leaf reports the intended message instead of an
    # AttributeError.
    assert len(leaves) >= 1, "expected at least 1 input leaf"
    assert all(
        isinstance(x, torch.Tensor) for x in leaves
    ), "input leaves must be a Tensor"

    if combine_mode == "pointwise" and not all(
        leaf.device.type == "cuda" for leaf in leaves
    ):
        raise ValueError(
            "For combine_mode='pointwise', all input tensors need to be on CUDA"
        )

    # A reversed scan is a forward scan over the flipped input, flipped back
    # at the end.
    if reverse:
        leaves = [torch.flip(elem, [dim]) for elem in leaves]

    shape = leaves[0].shape
    ndim = len(shape)
    dim = utils.canonicalize_dim(ndim, dim)

    for x in leaves[1:]:
        assert x.shape == shape, "All input tensors must have the same shape"

    # Call ``combine_fn`` once with the input as both operands to validate the
    # number and shape of its outputs.
    out = combine_fn(
        pytree.tree_unflatten(leaves, spec),
        pytree.tree_unflatten(leaves, spec),
    )
    out_leaves = pytree.tree_leaves(out)
    assert len(leaves) == len(
        out_leaves
    ), "The pytree of the output of the operator needs to match the input pytree"
    for x in out_leaves:
        assert (
            x.shape == shape
        ), "The pytree of the output of the operator needs to match the input pytree"

    # From here on ``combine_fn`` takes and returns flat leaves.
    combine_fn = functools.partial(
        wrap_combine_fn_flat, combine_fn=combine_fn, spec=spec, num_leaves=len(leaves)
    )

    if combine_mode == "generic":
        result_flat = generic_associative_scan(combine_fn, leaves, dim)
    else:
        result_flat = associative_scan_op(combine_fn, leaves, dim)

    if reverse:
        result_flat = [torch.flip(elem, [dim]) for elem in result_flat]

    return pytree.tree_unflatten(result_flat, spec)
183
184
def generic_associative_scan(operator, elems_flat, dim=0):
    r"""
    Compute an inclusive associative scan over flat tensor leaves along ``dim``.

    The scan is evaluated recursively. Neighbouring elements along ``dim`` are
    combined pairwise with ``operator``, the half-length result is scanned
    recursively, the remaining positions are produced with one more call to
    ``operator``, and finally both halves are interleaved.

    Args:
        operator (Callable): Flat combine function, called as
            ``operator(*lhs_leaves, *rhs_leaves)``; its result is consumed as a
            sequence holding one tensor per leaf.
            It is expected to be pure and to satisfy the associative property.
        elems_flat (List[torch.Tensor]): The tensors obtained by flattening the
            ``input`` pytree given to ``associative_scan``.
            All tensors are expected to have the same shape.
        dim (int): the dimension to scan over.
            NOTE(review): ``_interleave`` uses ``dim + 1`` as an axis, so a
            non-negative ``dim`` appears to be expected; ``associative_scan``
            canonicalizes ``dim`` before calling this function.

    Returns:
        A list with one scanned tensor per entry of ``elems_flat``. If the size
        along ``dim`` is smaller than 2, ``elems_flat`` itself is returned.

    Example::

        def add(x: torch.Tensor, y: torch.Tensor):
            return x + y

        elems = [0.0, 1.0, 2.0, 3.0]   (a single leaf, scanned along dim 0)

        Level 1:
            # combine neighbours (0, 1) and (2, 3)
            reduced = [1.0, 5.0]
            Level 2 (scan of ``reduced``):
                # combine neighbours (1, 5)
                reduced = [6.0]            # length 1 -> returned as is
                odd_elems = [6.0]
                # nothing is left to combine, so only the first element
                # of this level's input is kept
                even_elems = [1.0]
                res = [1.0, 6.0]
            odd_elems = [1.0, 6.0]
            # combine odd_elems[:-1] = [1.0] with elems[2::2] = [2.0] and
            # prepend the first element of ``elems``
            even_elems = [0.0, 3.0]
            res = [0.0, 1.0, 3.0, 6.0]

    """

    def _with_first_element(elem, result):
        """Put the first slice of ``elem`` along ``dim`` in front of ``result``."""
        if result.shape.numel() > 0:
            if elem.shape[dim] > 0:
                return torch.cat([aten.slice(elem, dim, 0, 1), result], dim=dim)
            return result
        # Jax allows/ignores concat with 0-dim, Pytorch does not
        return aten.slice(elem, dim, 0, 1)

    def _scan(elems):
        """Recursively scan ``elems`` along ``dim``."""
        length = elems[0].shape[dim]
        if length < 2:
            return elems

        # Combine each element at an even index with its right neighbour.
        pair_lhs = [aten.slice(t, dim, 0, -1, 2) for t in elems]
        pair_rhs = [aten.slice(t, dim, 1, None, 2) for t in elems]

        # The scan of the pairwise results gives the final values at the odd
        # positions of this level.
        odd_elems = _scan(operator(*pair_lhs, *pair_rhs))

        # Even positions (except the first): previous odd result combined with
        # the element at that position. For an even length the last odd result
        # has no successor and is dropped.
        if length % 2 == 0:
            carry = [aten.slice(t, dim, 0, -1) for t in odd_elems]
        else:
            carry = odd_elems
        tail = [aten.slice(t, dim, 2, None, 2) for t in elems]
        even_elems = operator(*carry, *tail)

        # The first element of a scan is the same as the first element
        # of the original `elems`.
        even_elems = [
            _with_first_element(elem, result)
            for (elem, result) in zip(elems, even_elems)
        ]

        interleave_along_dim = functools.partial(_interleave, dim=dim)
        return safe_map(interleave_along_dim, even_elems, odd_elems)

    return _scan(elems_flat)
279
280
def trace_associative_scan(
    proxy_mode, func_overload, combine_fn: Callable, input: List[torch.Tensor], dim: int
):
    """Record an ``associative_scan`` call in the graph traced by ``proxy_mode``.

    ``combine_fn`` is traced into its own graph via ``reenter_make_fx``, the
    resulting graph is validated and registered on the tracer's root module,
    and a ``call_function`` proxy targeting ``func_overload`` is created.

    Args:
        proxy_mode: The active proxy mode; its ``tracer`` is used throughout.
        func_overload: Target of the created ``call_function`` node.
        combine_fn: Flat combine function, traced with two sample tensors per
            entry of ``input``.
        input: The flat input tensors.
        dim: The dimension to scan over.

    Returns:
        The value produced by ``track_tensor_tree`` for clones of ``input``
        associated with the new proxy.

    Raises:
        ValueError: If a node of the traced combine graph has a user that is
            neither accepted by ``is_pointwise_use`` nor an output node.
        AssertionError: If the combine graph does not have exactly one output
            node with a single argument, if the number of outputs differs from
            ``len(input)``, or if an output dtype differs from the dtype of
            the corresponding input.
    """

    def _sample_like(x):
        """Uninitialized tensor mirroring ``x`` (dtype, device, requires_grad)."""
        return torch.empty_like(
            x,
            dtype=x.dtype,
            device=x.device,
            requires_grad=x.requires_grad,
        )

    with disable_proxy_modes_tracing():
        # The combine function takes the leaves twice: once for each operand.
        sample_inputs = [_sample_like(x) for x in itertools.chain(input, input)]
        combine_graph = reenter_make_fx(combine_fn)(*sample_inputs)

    outputs = None
    for node in combine_graph.graph.nodes:
        if node.op == "output":
            assert outputs is None
            assert len(node.args) == 1
            outputs = node.args[0]

        for use in node.users:
            if not (is_pointwise_use(use) or use.op == "output"):
                raise ValueError(
                    "For combine_mode='pointwise', the combine_fn needs to be pointwise"
                )

    assert outputs is not None
    num_outputs, num_inputs = len(outputs), len(input)
    assert (
        num_outputs == num_inputs
    ), f"expected combine_fn to return {num_inputs} results but got {num_outputs}"

    for inp, out_node in zip(input, outputs):
        out_dtype = out_node.meta["tensor_meta"].dtype
        assert out_dtype == inp.dtype, (
            f"combine_fn output type mismatch, expected {inp.dtype} "
            f"but got {out_dtype}"
        )

    _, combine_graph_name = unique_graph_id(proxy_mode, prefix="scan_combine_graph")

    tracer = proxy_mode.tracer
    tracer.root.register_module(combine_graph_name, combine_graph)

    proxy_args = pytree.tree_map(tracer.unwrap_proxy, (combine_graph, input, dim))
    out_proxy = tracer.create_proxy(
        "call_function", func_overload, proxy_args, {}, name="associative_scan"
    )

    # Clones of the inputs stand in for the outputs when tracking the proxy.
    with disable_proxy_modes_tracing():
        out = [aten.clone(x) for x in input]

    return track_tensor_tree(out, out_proxy, constant=None, tracer=tracer)
334
335
@associative_scan_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def associative_scan_op_dense(combine_fn, input, dim):
    """Implementation registered for ``DispatchKey.CompositeExplicitAutograd``.

    Raises:
        NotImplementedError: Always; there is no eager implementation of
            ``associative_scan``.
    """
    raise NotImplementedError("associative_scan is not implemented for eager")
339
340
# Register the handler built by ``autograd_not_implemented`` (created with
# ``deferred_error=True``) for ``DispatchKey.Autograd``.
# NOTE(review): the exact error behavior is defined by that helper in
# torch._higher_order_ops.utils -- confirm there.
associative_scan_op.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(associative_scan_op, deferred_error=True)
)
344
345
@associative_scan_op.py_impl(ProxyTorchDispatchMode)
def associative_scan_proxy_mode(mode, combine_fn, input, dim):
    """Implementation registered for ``ProxyTorchDispatchMode``.

    Forwards to ``trace_associative_scan`` with ``associative_scan_op`` as the
    call target and returns its result.
    """
    return trace_associative_scan(mode, associative_scan_op, combine_fn, input, dim)
349
350
@associative_scan_op.py_impl(FakeTensorMode)
def assoiciative_scan_fake_tensor_mode(mode, combine_fn, input, dim):
    """Implementation registered for ``FakeTensorMode``.

    Returns one clone per tensor in ``input``, created while ``mode`` is
    active. ``combine_fn`` and ``dim`` are not used.
    """
    # NOTE: the misspelled function name is kept as-is so that any existing
    # reference to it keeps working.
    with mode:
        cloned_inputs = [tensor.clone() for tensor in input]
    return cloned_inputs
355
356
@associative_scan_op.py_functionalize_impl
def associative_scan_functionalize(ctx, combine_fn, input, dim):
    """Functionalization implementation of ``associative_scan_op``.

    Unwraps ``input`` via ``ctx``, functionalizes ``combine_fn`` and calls
    ``associative_scan_op`` again inside ``ctx.redispatch_to_next()``; the
    result is wrapped again with ``ctx.wrap_tensors`` before it is returned.
    """
    unwrapped_input = ctx.unwrap_tensors(input)
    with ctx.redispatch_to_next():
        interpreted_combine_fn = _maybe_run_with_interpreter(combine_fn)
        functional_combine_fn = ctx.functionalize(interpreted_combine_fn)
        scanned = associative_scan_op(functional_combine_fn, unwrapped_input, dim)
    return ctx.wrap_tensors(scanned)
366