1# mypy: allow-untyped-defs 2import functools 3import itertools 4from typing import Callable, List 5 6import torch 7import torch._prims_common as utils 8import torch._subclasses.functional_tensor 9import torch.utils._pytree as pytree 10from torch._C import DispatchKey 11from torch._higher_order_ops.utils import ( 12 _maybe_run_with_interpreter, 13 _set_compilation_env, 14 autograd_not_implemented, 15 reenter_make_fx, 16 unique_graph_id, 17) 18from torch._inductor.utils import is_pointwise_use 19from torch._ops import HigherOrderOperator 20from torch._subclasses.fake_tensor import FakeTensorMode 21from torch.fx.experimental.proxy_tensor import ( 22 disable_proxy_modes_tracing, 23 ProxyTorchDispatchMode, 24 track_tensor_tree, 25) 26 27 28aten = torch._ops.ops.aten 29 30 31def wrap_combine_fn_flat(*args, combine_fn, spec, num_leaves): 32 assert len(args) == 2 * num_leaves 33 lhs = pytree.tree_unflatten(args[:num_leaves], spec) 34 rhs = pytree.tree_unflatten(args[num_leaves:], spec) 35 combined = combine_fn(lhs, rhs) 36 combined_leaves = pytree.tree_leaves(combined) 37 assert num_leaves == len(combined_leaves) 38 return combined_leaves 39 40 41def _interleave(a, b, dim): 42 # https://stackoverflow.com/questions/60869537/how-can-i-interleave-5-pytorch-tensors 43 if b_trunc := (a.shape[dim] == b.shape[dim] + 1): 44 pad = ( 45 [0] * ((b.ndim - dim - 1) * 2 + 1) 46 + [1] 47 + [0] * (b.ndim * 2 - ((b.ndim - dim - 1) * 2 + 2)) 48 ) 49 b = torch.nn.functional.pad(b, pad) 50 51 stacked = torch.stack([a, b], dim=dim + 1) 52 interleaved = torch.flatten(stacked, start_dim=dim, end_dim=dim + 1) 53 if b_trunc: 54 # TODO: find torch alternative for slice_along dim for torch.jit.script to work 55 interleaved = aten.slice(interleaved, dim, 0, b.shape[dim] + a.shape[dim] - 1) 56 return interleaved 57 58 59def safe_map(f, *args): 60 args = list(map(list, args)) 61 n = len(args[0]) 62 for arg in args[1:]: 63 if len(arg) != n: 64 raise ValueError("length mismatch: {list(map(len, args))}") 65 66 def nf(a): 67 return f(*a) 68 69 return list(map(nf, zip(*args))) 70 71 72class AssociativeScanOp(HigherOrderOperator): 73 def __init__(self): 74 super().__init__("associative_scan") 75 76 def __call__(self, combine_fn, input, dim): 77 return super().__call__(combine_fn, input, dim) 78 79 80associative_scan_op = AssociativeScanOp() 81 82 83def associative_scan( 84 combine_fn: Callable[[pytree.PyTree, pytree.PyTree], pytree.PyTree], 85 input: pytree.PyTree, 86 dim: int, 87 reverse: bool = False, 88 combine_mode: str = "pointwise", 89) -> torch.Tensor: 90 r""" 91 Performs an inclusive scan with an associative pointwise combine function. 92 93 .. warning:: 94 `torch.associative_scan` is a prototype feature in PyTorch. It currently 95 does not support autograd and you may run into miscompiles. 96 Read more about feature classification at: 97 https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype 98 99 This operator requires runtime code generation and so requires support for 100 ``torch.compile``. Further, only CUDA device codegen is supported at the moment. 101 102 Args: 103 combine_fn (Callable): A binary callable with type ``(Tensor, Tensor) -> Tensor``, 104 or if input is a pytree ``(pytree, pytree) -> pytree``. 105 This function must be pure, pointwise, and satisfy the associative property. 106 input (torch.Tensor): The input tensor, or nested pytree of tensors. 107 All inputs are expected to have the same shape. 108 dim (int): the dimension to scan over 109 reverse (bool): A boolean stating if the scan should be reversed with respect to the dimension. 110 combine_mode (str): A string indicating whether the ``combine_fn`` is ``pointwise`` or ``generic``. 111 If ``combine_mode=pointwise``, ``combine_fn`` must be pure, may only contain pointwise operations 112 and ``input`` must be CUDA tensors. 113 In all other cases ``combine_mode=generic`` should be used. 114 Note: ``combine_mode=pointwise`` is more efficient than ``combine_mode=generic``. 115 116 117 Example:: 118 119 def add(x: torch.Tensor, y: torch.Tensor): 120 return x + y 121 122 cumsum = associative_scan(add, x, dim) 123 124 """ 125 assert callable(combine_fn), "combine_fn must be a callable, but got {combine_fn}" 126 assert isinstance(dim, int), "dim must be an int, but got {type(dim)}" 127 assert combine_mode in ["pointwise", "generic"] 128 129 if not torch._dynamo.is_compiling(): 130 with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(): 131 return torch.compile(associative_scan, fullgraph=True)( 132 combine_fn, input, dim, reverse=reverse, combine_mode=combine_mode 133 ) 134 135 leaves, spec = pytree.tree_flatten(input) 136 137 if combine_mode == "pointwise" and not all(l.device.type == "cuda" for l in leaves): 138 raise ValueError( 139 "For combine_mode='pointwise', all input tensors need to be on CUDA" 140 ) 141 142 assert len(leaves) >= 1, "expected at least 1 input leaf" 143 assert all( 144 isinstance(x, torch.Tensor) for x in leaves 145 ), "input leaves must be a Tensor" 146 147 if reverse: 148 leaves = [torch.flip(elem, [dim]) for elem in leaves] 149 150 shape = leaves[0].shape 151 ndim = len(shape) 152 dim = utils.canonicalize_dim(ndim, dim) 153 154 for x in leaves[1:]: 155 assert x.shape == shape, "All input tensors must have the same shape" 156 157 out = combine_fn( 158 pytree.tree_unflatten(leaves, spec), 159 pytree.tree_unflatten(leaves, spec), 160 ) 161 out_leaves, tree_out = pytree.tree_flatten(out) 162 assert len(leaves) == len( 163 out_leaves 164 ), "The pytree of the output of the operator needs to match the input pytree" 165 for x in out_leaves: 166 assert ( 167 x.shape == shape 168 ), "The pytree of the output of the operator needs to match the input pytree" 169 170 combine_fn = functools.partial( 171 wrap_combine_fn_flat, combine_fn=combine_fn, spec=spec, num_leaves=len(leaves) 172 ) 173 174 if combine_mode == "generic": 175 result_flat = generic_associative_scan(combine_fn, leaves, dim) 176 else: 177 result_flat = associative_scan_op(combine_fn, leaves, dim) 178 179 if reverse: 180 result_flat = [torch.flip(elem, [dim]) for elem in result_flat] 181 182 return pytree.tree_unflatten(result_flat, spec) 183 184 185def generic_associative_scan(operator, elems_flat, dim=0): 186 r""" 187 This function performs the associative_scan operation. 188 The algorithm works by recursively collecting neighbours of ``elems_flat`` and subsequently 189 applying the ``operator`` on all pairs in parallel along ``dim``. 190 The results of the recursive calls are later combined. 191 192 Args: 193 operator (Callable): A binary callable with type ``(Tensor, Tensor) -> Tensor``, 194 or if input is a pytree ``(pytree, pytree) -> pytree``. 195 This function must be pure, pointwise, and satisfy the associative property. 196 elems_flat (torch.Tensor): A list of torch.Tensors converted from the pytree of 197 ``input`` provided to ``associative_scan``. 198 All inputs are expected to have the same shape. 199 dim (int): the dimension to scan over 200 201 202 Example:: 203 204 def add(x: torch.Tensor, y: torch.Tensor): 205 return x + y 206 207 elems_flat = torch.tensor([0.0, 1.0, 2.0, 3.0]) 208 209 First iteration of _scan -> 210 # odd_elems -> apply operator on all neighbours 211 # odd_elems = operator([torch.tensor([0.0, 2.0])], 212 # [torch.tensor([1.0, 3.0])]) 213 odd_elems = torch.tensor([1.0, 5.0]) 214 Second iteration of _scan -> 215 # odd_elems = operator([torch.tensor([1.0])], 216 # [torch.tensor([5.0])]) 217 odd_elems = torch.tensor([6.0]) 218 # even_elems -> apply operator on all odd_elems and 219 # every second element of ``elems``, starting from the second element. 220 # even_elems is expanded with the first element of ``elems`` 221 even_elems = [1.0] 222 # Merges odd_elems and even_elems 223 res = torch.tensor([1.0, 6.0]) 224 # even_elems -> apply operator on all odd_elems and 225 # every second element of ``elems``, starting from the second element. 226 # even_elems is expanded with the first element of ``elems`` 227 even_elems = [0.0, 3.0] 228 # Merges odd_elems and even_elems 229 res = torch.tensor([0.0, 1.0, 3.0, 6.0]) 230 231 """ 232 233 def _scan(elems): 234 """Perform the actual recursive scan on ``elems``.""" 235 num_elems = elems[0].shape[dim] 236 237 if num_elems < 2: 238 return elems 239 240 reduced_elems = operator( 241 *[aten.slice(elem, dim, 0, -1, 2) for elem in elems], 242 *[aten.slice(elem, dim, 1, None, 2) for elem in elems], 243 ) 244 245 # Recursively compute scan for partially reduced tensors. 246 odd_elems = _scan(reduced_elems) 247 248 if num_elems % 2 == 0: 249 even_elems = operator( 250 *[aten.slice(e, dim, 0, -1) for e in odd_elems], 251 *[aten.slice(e, dim, 2, None, 2) for e in elems], 252 ) 253 else: 254 even_elems = operator( 255 *odd_elems, 256 *[aten.slice(e, dim, 2, None, 2) for e in elems], 257 ) 258 259 # The first element of a scan is the same as the first element 260 # of the original `elems`. 261 even_elems = [ 262 torch.cat([aten.slice(elem, dim, 0, 1), result], dim=dim) 263 if result.shape.numel() > 0 and elem.shape[dim] > 0 264 else result 265 if result.shape.numel() > 0 266 else aten.slice( 267 elem, dim, 0, 1 268 ) # Jax allows/ignores concat with 0-dim, Pytorch does not 269 for (elem, result) in zip(elems, even_elems) 270 ] 271 272 return list( 273 safe_map(functools.partial(_interleave, dim=dim), even_elems, odd_elems) 274 ) 275 276 scans = _scan(elems_flat) 277 278 return scans 279 280 281def trace_associative_scan( 282 proxy_mode, func_overload, combine_fn: Callable, input: List[torch.Tensor], dim: int 283): 284 with disable_proxy_modes_tracing(): 285 sample_inputs = [ 286 torch.empty_like( 287 x, 288 dtype=x.dtype, 289 device=x.device, 290 requires_grad=x.requires_grad, 291 ) 292 for x in itertools.chain(input, input) 293 ] 294 combine_graph = reenter_make_fx(combine_fn)(*sample_inputs) 295 296 outputs = None 297 for node in combine_graph.graph.nodes: 298 if node.op == "output": 299 assert outputs is None 300 assert len(node.args) == 1 301 outputs = node.args[0] 302 303 if not all(is_pointwise_use(use) or use.op == "output" for use in node.users): 304 raise ValueError( 305 "For combine_mode='pointwise', the combine_fn needs to be pointwise" 306 ) 307 308 assert outputs is not None 309 assert len(outputs) == len( 310 input 311 ), f"expected combine_fn to return {len(input)} results but got {len(outputs)}" 312 313 for i, o in zip(input, outputs): 314 o_meta = o.meta["tensor_meta"] 315 assert o_meta.dtype == i.dtype, ( 316 f"combine_fn output type mismatch, expected {i.dtype} " 317 + f"but got {o_meta.dtype}" 318 ) 319 320 _, combine_graph_name = unique_graph_id(proxy_mode, prefix="scan_combine_graph") 321 322 proxy_mode.tracer.root.register_module(combine_graph_name, combine_graph) 323 324 args = (combine_graph, input, dim) 325 proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args) 326 out_proxy = proxy_mode.tracer.create_proxy( 327 "call_function", func_overload, proxy_args, {}, name="associative_scan" 328 ) 329 330 with disable_proxy_modes_tracing(): 331 out = [aten.clone(x) for x in input] 332 333 return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) 334 335 336@associative_scan_op.py_impl(DispatchKey.CompositeExplicitAutograd) 337def associative_scan_op_dense(combine_fn, input, dim): 338 raise NotImplementedError("associative_scan is not implemented for eager") 339 340 341associative_scan_op.py_impl(DispatchKey.Autograd)( 342 autograd_not_implemented(associative_scan_op, deferred_error=True) 343) 344 345 346@associative_scan_op.py_impl(ProxyTorchDispatchMode) 347def associative_scan_proxy_mode(mode, combine_fn, input, dim): 348 return trace_associative_scan(mode, associative_scan_op, combine_fn, input, dim) 349 350 351@associative_scan_op.py_impl(FakeTensorMode) 352def assoiciative_scan_fake_tensor_mode(mode, combine_fn, input, dim): 353 with mode: 354 return [x.clone() for x in input] 355 356 357@associative_scan_op.py_functionalize_impl 358def associative_scan_functionalize(ctx, combine_fn, input, dim): 359 unwrapped_input = ctx.unwrap_tensors(input) 360 with ctx.redispatch_to_next() as m: 361 functional_combine_fn = ctx.functionalize( 362 _maybe_run_with_interpreter(combine_fn) 363 ) 364 ret = associative_scan_op(functional_combine_fn, unwrapped_input, dim) 365 return ctx.wrap_tensors(ret) 366