# mypy: allow-untyped-defs
import functools
import math
import operator
from typing import *  # noqa: F403

import torch
import torch.nn.functional as F
from torch.fx.operator_schemas import normalize_function
from torch.nested._internal.sdpa import jagged_scaled_dot_product_attention

from .nested_tensor import NestedTensor


__all__: List[Any] = []

JAGGED_OPS_TABLE: Dict[Any, Any] = {}


# Simplifying assumption: we assume that the batch dim is always the left-most
# dim, and the ragged dim is always the second dim.
def _outer_to_inner_dim(ndim, dim):
    assert dim >= 0 and dim < ndim
    return 0 if dim < 2 else dim - 1
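# Illustrative mapping (comment only; assumes an NJT with outer shape
# (B, j0, D0, D1), so ndim=4): the batch and ragged dims both collapse onto
# inner dim 0 of the values tensor, and later dims shift down by one:
#   _outer_to_inner_dim(4, 0) == 0  # batch dim
#   _outer_to_inner_dim(4, 1) == 0  # ragged dim
#   _outer_to_inner_dim(4, 2) == 1
#   _outer_to_inner_dim(4, 3) == 2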


def _wrap_jagged_dim(
    ndim, dim, op_name, convert_to_inner_dim=True, allow_batch_dim=False
):
    from torch._prims_common import canonicalize_dims

    wrapped = canonicalize_dims(ndim, dim)
    if wrapped == 1:
        raise RuntimeError(f"{op_name}(): not supported for NestedTensor on dim=1")
    elif wrapped == 0 and not allow_batch_dim:
        raise RuntimeError(f"{op_name}(): not supported for NestedTensor on dim=0")
    return _outer_to_inner_dim(ndim, wrapped) if convert_to_inner_dim else wrapped
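# Example behavior (illustrative, assuming a 3D NJT of shape (B, j0, D)):
# a negative dim canonicalizes first, then maps into the inner dim space:
#   _wrap_jagged_dim(3, -1, "op") == 1   # outer dim 2 -> inner dim 1
#   _wrap_jagged_dim(3, 1, "op")         # raises: ragged dim unsupported
#   _wrap_jagged_dim(3, 0, "op")         # raises unless allow_batch_dim=True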


def _wrap_jagged_dims(ndim, dims, op_name, ragged_idx=1):
    """
    For NestedTensor operators,
    wraps dimensions to non-negative values,
    and returns metadata related to reduction dimension(s).
    """
    from torch._prims_common import canonicalize_dims

    assert isinstance(
        dims, (tuple, list)
    ), f"_wrap_jagged_dims(): cannot iterate over dimensions of type {type(dims)}"

    wrapped_dims = [
        canonicalize_dims(ndim, d) for d in dims
    ]  # convert all indices to non-negative values

    operate_on_batch = 0 in wrapped_dims
    operate_on_ragged = ragged_idx in wrapped_dims
    operate_on_non_batch = any(d != 0 and d != ragged_idx for d in wrapped_dims)

    outer_to_inner_dim = tuple(
        _outer_to_inner_dim(ndim, d) for d in wrapped_dims if d != 0
    )

    return outer_to_inner_dim, operate_on_batch, operate_on_ragged, operate_on_non_batch
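# Example (illustrative), for an NJT of shape (B, j0, D) with ragged_idx=1:
#   _wrap_jagged_dims(3, (0, 1), "sum") == ((0,), True, True, False)
# i.e. reducing over (batch, ragged) maps to inner dim 0 of the values,
# touches the batch and ragged dims, and touches no non-batch dims.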


def check_schema(schema_str: str, func, *args, **kwargs) -> None:
    named_arg_types = schema_str.split(", ")
    num_optional_args = [x.endswith("?") for x in named_arg_types].count(True)
    min_args = len(named_arg_types) - num_optional_args

    # special case: an ellipsis allows for any number of unchecked args at the end
    if named_arg_types[-1] == "...":
        named_arg_types = named_arg_types[:-1]
    else:
        if not (len(args) >= min_args and len(args) <= len(named_arg_types)):
            raise ValueError(
                f"NestedTensor {func.__name__}({schema_str}): expected at least {min_args} "
                f"arguments and at most {len(named_arg_types)} arguments, but got: "
                f"{len(args)} arguments"
            )

    arg_type_check_fns = {
        "t": lambda x: isinstance(x, torch.Tensor) and not isinstance(x, NestedTensor),
        "jt": lambda x: isinstance(x, NestedTensor)
        and x._lengths is None
        and x._ragged_idx == 1,  # ops with "jt" require contiguous JT only
        "jt_all": lambda x: isinstance(
            x, NestedTensor
        ),  # ops with "jt_all" can accept all kinds of JT
        "any": lambda x: True,
    }
    for i, named_arg_type in enumerate(named_arg_types):
        name, arg_type = named_arg_type.split(": ")
        is_optional = arg_type.endswith("?")
        normalized_arg_type = arg_type[:-1] if is_optional else arg_type
        if normalized_arg_type not in arg_type_check_fns.keys():
            raise AssertionError(f"Unknown arg type: {normalized_arg_type}")

        if i >= len(args):
            if not is_optional:
                raise ValueError(
                    f"NestedTensor {func.__name__}({schema_str}) "
                    f"missing required argument: {name}"
                )
            continue

        _check_fn = arg_type_check_fns[normalized_arg_type]

        def check_fn(x, is_optional=is_optional):
            if is_optional:
                return x is None or _check_fn(x)
            else:
                return _check_fn(x)

        if not check_fn(args[i]):
            type_to_desc = {
                "t": "tensor",
                "t?": "optional tensor",
                "jt": "contiguous jagged layout NestedTensor",
                "jt_all": "jagged layout NestedTensor",
                "any": "<any type>",
            }

            raise ValueError(
                f"NestedTensor {func.__name__}({schema_str}): expected {name} to be a "
                f"{type_to_desc[arg_type]}"
            )
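# Example (illustrative): a schema string like "self: jt, weight: t, bias: t?"
# asserts that arg 0 is a contiguous jagged-layout NestedTensor, arg 1 is a
# plain tensor, and arg 2 is an optional plain tensor; a trailing ", ..."
# leaves any remaining args unchecked:
#   check_schema("self: jt, weight: t, bias: t?", func, nt, weight)  # ok
#   check_schema("self: jt, weight: t, bias: t?", func, weight, nt)  # raises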


def check_ragged_dim_same(
    func, a: NestedTensor, a_name: str, b: NestedTensor, b_name: str
) -> None:
    # Calling into .shape here
    if a._size[a._ragged_idx] != b._size[b._ragged_idx]:
        raise RuntimeError(
            f"NestedTensor {func.__name__}: expected {a_name} and {b_name} to have the "
            "same exact offsets tensor."
        )


# returns True if the raggedness-relevant portions of the NT shape
# match those of the specified size
def raggedness_matches(nt, size):
    end = nt._ragged_idx + 1
    nt_ragged = nt._size[:end]
    size_ragged = size[:end]
    return len(nt_ragged) == len(size_ragged) and (
        all(ns == s or s == -1 for ns, s in zip(nt_ragged, size_ragged))
    )
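# Example (illustrative): for nt of shape (B, j0, D) with ragged_idx=1, the
# sizes (B, j0, D') and (B, -1, D') match (only dims up to and including the
# ragged dim are compared, with -1 as a wildcard), while (B2, j0, D') with
# B2 != B does not.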


def squeeze_leading_ones(t):
    # Note: [ Squeezing leading ones ]
    #
    # Squeeze leading ones from t.
    #
    # We want:
    #   (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?)
    #   (B, j0, ?, ?) + (1, 1, 1, ?, ?) -> (1, B, j0, ?, ?)  (not yet supported)
    #
    # 1) Squeeze extra ones and grab values from NT
    #   (1, 1, ?, ?) -> (?, ?)   and   (sum(*), ?, ?) -> (B, j0, ?, ?)
    # 2) Do dense broadcasting:
    #   (sum(*), ?, ?) + (?, ?) -> (sum(*), ?, ?)
    # 3) Construct nested tensor
    #   (sum(*), ?, ?) -> (B, j0, ?, ?)
    #
    # If unsqueezing on the 0th dim becomes supported, we would unsqueeze
    # at step (4) and we would need to update this function to record how
    # many ones we unsqueezed.
    while t.dim() > 0 and t.shape[0] == 1:
        t = t.squeeze(0)
    return t


def register_func(tables, aten_ops, schema_str):
    if not isinstance(aten_ops, list):
        aten_ops = [aten_ops]
    if not isinstance(tables, list):
        tables = [tables]

    def wrapper(func):
        for aten_op in aten_ops:

            def get_inner(aten_op):
                def inner(*args, **kwargs):
                    check_schema(schema_str, func, *args, **kwargs)
                    return func(aten_op, *args, **kwargs)

                return inner

            for table in tables:
                table[aten_op] = get_inner(aten_op)
        return func

    return wrapper


register_jagged_func = functools.partial(register_func, JAGGED_OPS_TABLE)
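# Usage sketch (mirrors the registrations below): decorating a handler wires
# it into JAGGED_OPS_TABLE keyed on the aten op, with check_schema() run
# against the schema string on every call before the handler sees the args:
#
#   @register_jagged_func(torch.ops.aten.linear.default,
#                         "input: jt, weight: t, bias: t?")
#   def linear_default(func, *args, **kwargs):
#       ...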


def lookup_jagged(func, *args, **kwargs) -> Optional[Callable]:
    dispatch_func = JAGGED_OPS_TABLE.get(func, None)
    if dispatch_func is not None:
        return dispatch_func

    # Handle pointwise fallbacks
    if torch.Tag.pointwise in func.tags:
        # Assume there aren't additional tensors that aren't the "unary/binary" args
        num_tensor_args = sum(isinstance(x, torch.Tensor) for x in args)
        if num_tensor_args == 1:
            # Build up the check schema string. The first tensor arg is assumed to be
            # an NJT and other args are sent through as-is.
            schema_parts = []
            for arg in func._schema.arguments:
                if isinstance(arg.type, torch.TensorType):
                    schema_parts.append(f"{arg.name}: jt_all")
                    break
                else:
                    schema_parts.append(f"{arg.name}: any")
            schema_parts.append("...")
            check_schema_str = ", ".join(schema_parts)
            check_schema(check_schema_str, func, *args, **kwargs)
            return functools.partial(jagged_unary_pointwise, func)
        elif num_tensor_args == 2:
            check_schema("lhs: any, rhs: any, ...", func, *args, **kwargs)
            return functools.partial(jagged_binary_pointwise, func)

    return None


def extract_kwargs(arg):
    kwargs = {
        "offsets": arg.offsets(),
        "_metadata_cache": arg._metadata_cache,
        "_ragged_idx": arg._ragged_idx,
    }
    return kwargs


def jagged_unary_pointwise(func, *args, **kwargs):
    # assume if we get here that there is a single NJT input in the args
    njt = next(arg for arg in args if isinstance(arg, NestedTensor))
    return NestedTensor(
        func(*(arg._values if arg is njt else arg for arg in args), **kwargs),
        **extract_kwargs(njt),
    )


def jagged_binary_pointwise(func, *args, **kwargs):
    a, b = args[0], args[1]
    assert isinstance(a, NestedTensor) or isinstance(b, NestedTensor)

    mismatch_error_msg = (
        "cannot call binary pointwise function {} with inputs of shapes {} and {}"
    )
    # a is NT, b is NT
    if isinstance(a, NestedTensor) and isinstance(b, NestedTensor):
        # ex: (B, j0, D) + (B, j0, D)
        # ex: (B, j0, D) + (B, j0, 1)
        if raggedness_matches(a, b._size):
            return NestedTensor(
                func(a._values, b._values, *args[2:], **kwargs), **extract_kwargs(a)
            )
        raise RuntimeError(mismatch_error_msg.format(func.__name__, a._size, b._size))
    # either a is NT or b is NT at this point
    a_is_nt = isinstance(a, NestedTensor)
    extracted_kwargs = extract_kwargs(a) if a_is_nt else extract_kwargs(b)

    # === Handle broadcasting across the batch / ragged dims ===

    # Easy case: take advantage of pre-existing broadcasting logic
    # ex: (B, j0, ?, ?) + (?) -> (B, j0, ?, ?)
    # ex: (B, j0, ?, ?) + (?, ?) -> (B, j0, ?, ?)
    # ex: (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?)
    nt, t = (a, b) if a_is_nt else (b, a)
    # See Note: [ Squeezing leading ones ]
    if t.dim() > nt.dim():
        raise NotImplementedError("NYI: broadcasting NT with T with larger dim")
    t_squeezed = squeeze_leading_ones(t)
    if nt.dim() >= t_squeezed.dim() + 2:
        lhs, rhs = (nt._values, t_squeezed) if a_is_nt else (t_squeezed, nt._values)
        return NestedTensor(func(lhs, rhs, *args[2:], **kwargs), **extracted_kwargs)

    # Harder case: do manual broadcasting over unbound components
    # when NT dim == non-NT dim
    # ex: (B, j0, D_0, D_1) + (B, 1, D_0, D_1) -> (B, j0, D_0, D_1)
    if a.dim() == b.dim():
        # ex: (B, j0, D_0, D_1) + (1, 1, D_0, D_1) -> should
        # be (B, j0, D_0, D_1) but not yet supported
        if a.shape[0] != b.shape[0]:
            raise RuntimeError(
                mismatch_error_msg.format(func.__name__, a.shape, b.shape)
            )

        # need to use offsets to broadcast across ragged dim properly
        # NB: inefficient fallback here; Triton codegen can help this
        # TODO: Make this work with autograd
        outputs = []
        for a_comp, b_comp in zip(a.unbind(), b.unbind()):
            outputs.append(func(a_comp, b_comp, *args[2:], **kwargs))
        new_values = torch.cat(outputs, dim=0)
        return NestedTensor(new_values, **extracted_kwargs)

    # ex: (B, j0, D_0, D_1) + (A, B, 1, D_0, D_1) -> error because this breaks the
    # invariant that the ragged dim immediately follows the left-most batch dim
    raise RuntimeError(mismatch_error_msg.format(func.__name__, a.shape, b.shape))
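# Worked example (illustrative): for NT a of shape (B, j0, D) and dense b of
# shape (1, 1, D), squeeze_leading_ones(b) has shape (D,); since
# nt.dim() (3) >= t_squeezed.dim() (1) + 2, the easy case applies and the op
# runs directly on a._values of shape (sum(j0), D) broadcast against (D,),
# rewrapped with a's offsets. A dense b of shape (B, 1, D) instead takes the
# per-component unbind() fallback.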


def jagged_torch_function(func, *args, **kwargs):
    # SDPA has special kernels that handle nested tensors.
    # Dispatch to the correct implementation here
    if func is torch._C._nn.scaled_dot_product_attention:
        return jagged_scaled_dot_product_attention(*args, **kwargs)

    if func.__name__ == "apply_":
        func(args[0]._values, *args[1:], **kwargs)
        return args[0]

    # Handle flatten() here because it's CompositeImplicit.
    if func.__name__ == "flatten":

        def _flatten_sig(input, start_dim=0, end_dim=-1):
            pass

        _, new_kwargs = normalize_function(  # type: ignore[misc]
            _flatten_sig, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
        )

        inp = new_kwargs.pop("input")

        # NB: stay in outer dim space because we're going to redispatch on a NT input
        start_dim = _wrap_jagged_dim(
            inp.dim(), new_kwargs["start_dim"], "flatten", convert_to_inner_dim=False
        )
        end_dim = _wrap_jagged_dim(
            inp.dim(), new_kwargs["end_dim"], "flatten", convert_to_inner_dim=False
        )

        if start_dim == end_dim:
            return inp

        product = functools.reduce(operator.mul, inp.shape[start_dim : end_dim + 1])
        new_shape = (*inp.shape[:start_dim], product, *inp.shape[end_dim + 1 :])

        return inp.reshape(*new_shape)

    raise NotImplementedError(func)
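# Example (illustrative): flattening an NJT of shape (B, j0, W, H) with
# start_dim=2 computes new_shape = (B, j0, W * H) in outer dim space and
# redispatches reshape on the NJT, so the ragged dim is handled by the
# registered view handling below.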


@register_jagged_func(
    [
        torch.ops.aten.is_non_overlapping_and_dense.default,
        torch.ops.aten.sym_size.default,
        torch.ops.aten.dim.default,
        torch.ops.aten.numel.default,
        torch.ops.aten.sym_numel.default,
        torch.ops.aten.sym_stride.default,
        torch.ops.aten.sym_storage_offset.default,
    ],
    "self: jt_all",
)
def tensor_attr_supported_getter(func, *args, **kwargs):
    if func == torch.ops.aten.is_non_overlapping_and_dense.default:
        return False

    if func == torch.ops.aten.sym_size.default:
        return args[0]._size

    if func == torch.ops.aten.dim.default:
        return len(args[0]._size)

    if func in (torch.ops.aten.sym_numel.default, torch.ops.aten.numel.default):
        if args[0]._lengths is not None:
            return int(sum(args[0]._lengths) * math.prod(args[0]._size[2:]))
        return args[0]._values.numel()

    if func == torch.ops.aten.sym_stride.default:
        return args[0]._strides

    if func == torch.ops.aten.sym_storage_offset.default:
        return args[0]._values.storage_offset()


@register_jagged_func(torch.ops.prim.layout.default, "self: jt_all")
def prim_layout_default(func, *args, **kwargs):
    return torch.jagged


@register_jagged_func(
    [torch.ops.aten.size.default],
    "self: jt_all",
)
def tensor_attr_unsupported_getter(func, *args, **kwargs):
    if func == torch.ops.aten.size.default:
        raise RuntimeError(
            "NestedTensor does not support directly calling torch.ops.aten.size; "
            "please use `nested_tensor.size()` instead."
        )


@register_jagged_func(torch.ops.aten.is_contiguous.default, "self: jt_all")
def is_contiguous_general(func, *args, **kwargs):
    from torch._prims_common import is_contiguous_for_memory_format

    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")

    # If created from narrow() check for lengths
    if inp.lengths() is not None:
        return False

    new_kwargs["memory_format"] = new_kwargs.get(
        "memory_format", torch.contiguous_format
    )
    if new_kwargs["memory_format"] == torch.preserve_format:
        return True
    return is_contiguous_for_memory_format(inp._values, **new_kwargs)


register_jagged_func(
    torch.ops.aten.is_contiguous.memory_format, "self: jt_all, memory_format: any?"
)(is_contiguous_general)


@register_jagged_func(
    torch.ops.aten.clone.default, "input: jt_all, memory_format: any?"
)
def clone_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    new_meta = extract_kwargs(inp)

    if inp._lengths is not None:
        if new_kwargs["memory_format"] == torch.contiguous_format:
            # need to copy to remove "holes" non-contiguity / lengths metadata
            # TODO: write a kernel for this
            from .nested_tensor import jagged_from_list

            # TODO: We probably want the output to have the same ragged structure / nested int.
            assert (
                inp._ragged_idx == 1
            ), "NJT with ragged_idx != 1 not supported for contiguous clone"
            contig, _ = jagged_from_list(inp.unbind(), offsets=None)
            return contig
        else:
            # need to preserve any lengths metadata present
            new_meta["lengths"] = inp._lengths

    return NestedTensor(func(inp._values, **new_kwargs), **new_meta)


@register_jagged_func(torch.ops.aten.linear.default, "input: jt, weight: t, bias: t?")
def linear_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(
    torch.ops.aten.linear_backward.default,
    "self: jt, grad_output: jt, weight: t, output_mask: any",
)
def linear_backward_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    grad_output = new_kwargs.pop("grad_output")
    weight = new_kwargs.pop("weight")

    check_ragged_dim_same(func, inp, "self", grad_output, "grad_output")
    ds = NestedTensor(
        torch.matmul(grad_output._values, weight), **extract_kwargs(grad_output)
    )
    dw = torch.matmul(grad_output._values.transpose(-2, -1), inp._values)
    db = None  # NYI: gradient for bias, need to reduce over ragged dim
    return (ds, dw, db)


@register_jagged_func(torch.ops.aten.to.dtype, "input: jt_all, dtype: any")
def to_dtype(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(torch.ops.aten._to_copy.default, "self: jt_all")
def to_copy_default(func, *args, **kwargs):
    from .nested_tensor import _tensor_symint_registry

    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    # don't change layout
    new_kwargs.pop("layout")

    new_values = func(inp._values, **new_kwargs)
    new_offsets = inp._offsets.to(device=new_values.device)

    from torch._subclasses.fake_tensor import FakeTensor
    from torch._subclasses.functional_tensor import (
        FunctionalTensor,
        mb_unwrap_functional_tensor,
    )

    if isinstance(new_offsets, (FakeTensor, FunctionalTensor)):
        # Temporary hack until we have the union find
        tgt = mb_unwrap_functional_tensor(new_offsets)
        src = mb_unwrap_functional_tensor(inp._offsets)
        tgt.nested_int_memo = src.nested_int_memo
    else:
        _tensor_symint_registry[new_offsets] = _tensor_symint_registry[inp._offsets]
    inp_kwargs = extract_kwargs(inp)
    inp_kwargs["offsets"] = new_offsets

    return NestedTensor(new_values, **inp_kwargs)


@register_jagged_func(
    torch.ops.aten.copy_.default, "self: jt_all, src: jt_all, non_blocking: any?"
)
def copy_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")
    src = new_kwargs.pop("src")
    if inp._size != src._size:
        raise RuntimeError(
            "copy_ only supports NestedTensors that have the same size and the exact same offsets tensor."
        )
    inp.values().copy_(src.values())
    return inp


register_jagged_func(torch.ops.aten.detach.default, "self: jt_all")(
    jagged_unary_pointwise
)


@register_jagged_func(
    [
        torch.ops.aten.empty_like.default,
        torch.ops.aten.ones_like.default,
        torch.ops.aten.zeros_like.default,
        torch.ops.aten.randn_like.default,
    ],
    "self: jt_all",
)
def like_factory_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    # The output of a like-factory op on an NJT keeps the jagged layout, so rather
    # than force users to specify one, assume jagged. The kwarg must be set to
    # strided here because we redispatch on the dense values below.
    new_kwargs["layout"] = torch.strided

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(torch.ops.aten.zero_.default, "self: jt_all")
def zero__default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    func(inp._values)
    return inp


@register_jagged_func(
    torch.ops.aten._softmax.default, "self: jt_all, dim: any, half_to_float: any"
)
def _softmax_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    if isinstance(new_kwargs["dim"], tuple):
        raise RuntimeError(
            "softmax(): not supported for dimensions of type 'tuple' for NestedTensor"
        )

    inp = new_kwargs.pop("input")

    (
        new_kwargs["dim"],
        reduce_on_batch,
        reduce_on_ragged,
        reduce_on_non_batch,
    ) = _wrap_jagged_dims(
        inp.dim(),
        (new_kwargs["dim"],),
        "softmax",
        inp._ragged_idx,
    )

    if reduce_on_batch:
        raise RuntimeError(
            "softmax(): not supported when reducing across the batch dimension for NestedTensor"
        )

    if reduce_on_ragged and inp._ragged_idx > 1:
        raise RuntimeError(
            "softmax(): not supported when reducing along the ragged dimension for ragged_idx > 1 for NestedTensor"
        )

    if reduce_on_ragged and inp._lengths is not None:
        raise RuntimeError(
            "softmax(): not supported where lengths is not None "
            + "if reducing across the ragged dimension for NestedTensor"
        )

    new_kwargs["dim"] = new_kwargs["dim"][
        0
    ]  # torch.softmax takes in the reduction dimension as an integer

    if reduce_on_ragged:
        padded_softmax_values = torch.nn.functional.softmax(
            torch.ops.aten._jagged_to_padded_dense_forward(
                inp._values.reshape(
                    inp._values.shape[0], -1
                ),  # values are required to be 2D tensors for j2pd
                [inp._offsets],
                max_lengths=[inp._max_seqlen],  # max length of ragged dimension
                padding_value=float("-inf"),  # e^-inf = 0
            ),
            dim=inp._ragged_idx,
        )

        softmax_values = torch.ops.aten._padded_dense_to_jagged_forward(
            padded_softmax_values,
            [inp._offsets],
            total_L=inp._values.shape[
                0
            ],  # providing this parameter helps avoid a GPU/CPU sync
        ).reshape(
            -1, *inp._values.shape[1:]
        )  # expand softmax_values back to original shape (inp._values.shape)

        return NestedTensor(softmax_values, **extract_kwargs(inp))

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
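# Worked example (illustrative): softmax over the ragged dim of a (B, j0, D)
# NJT pads the (flattened) values to a dense (B, max_seqlen, D) batch with
# -inf, so padded positions contribute e^-inf = 0 to the normalizer;
# torch.softmax then runs over the padded dim, and the result is converted
# back to jagged with the original offsets and reshaped to the values' shape.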


@register_jagged_func(
    torch.ops.aten._softmax_backward_data.default,
    "grad_output: jt, output: jt, dim: any, input_dtype: any",
)
def _softmax_backward(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    grad_out = new_kwargs.pop("grad_output")
    output = new_kwargs.pop("output")
    return NestedTensor(
        func(grad_out._values, output._values, **new_kwargs), **extract_kwargs(grad_out)
    )


@register_jagged_func(
    torch.ops.aten.native_dropout.default, "self: jt, float: any, train: any?"
)
def native_dropout_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    out1, out2 = func(inp._values, **new_kwargs)
    return (
        NestedTensor(out1, **extract_kwargs(inp)),
        NestedTensor(out2, **extract_kwargs(inp)),
    )


@register_jagged_func(
    torch.ops.aten.native_dropout_backward.default,
    "grad_output: jt, mask: jt, scale: any",
)
def native_dropout_backward_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    grad_output = new_kwargs.pop("grad_output")
    mask = new_kwargs.pop("mask")
    return NestedTensor(
        func(grad_output._values, mask._values, **new_kwargs),
        **extract_kwargs(grad_output),
    )


@register_jagged_func(torch.ops.aten.prod.dim_int, "self: jt, dim: any, keepdim: any?")
def prod_dim_int(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    # TODO: Figure out how to handle this better
    # keep_dim is required to keep it in jagged format
    if not new_kwargs["keepdim"]:
        raise RuntimeError("prod(): keepdim=True must be set for NestedTensor")
    dim = new_kwargs["dim"]
    new_kwargs["dim"] = _wrap_jagged_dim(len(inp._size), dim, "prod")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(
    torch.ops.aten.split.Tensor, "self: jt, split_size: any, dim: any"
)
def split_tensor(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    new_kwargs["dim"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim"], "split")

    return tuple(
        NestedTensor(values=x, **extract_kwargs(inp))
        for x in func(inp._values, **new_kwargs)
    )


@register_jagged_func(
    torch.ops.aten.split_with_sizes.default, "self: jt, split_sizes: any, dim: any"
)
def split_with_sizes_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    new_kwargs["dim"] = _wrap_jagged_dim(
        inp.dim(), new_kwargs["dim"], "split_with_sizes"
    )

    return [
        NestedTensor(values=x, **extract_kwargs(inp))
        for x in func(inp._values, **new_kwargs)
    ]


@register_jagged_func(
    torch.ops.aten.narrow.default, "self: jt, dim: any, start: any, length: any"
)
def narrow(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")

    dim = _wrap_jagged_dim(inp.dim(), new_kwargs["dim"], "narrow")
    values = func(
        inp._values,
        dim=dim,
        start=new_kwargs["start"],
        length=new_kwargs["length"],
    )
    return NestedTensor(values, **extract_kwargs(inp))


@register_jagged_func(torch.ops.aten.chunk.default, "self: jt, chunks: any, dim: any?")
def chunk_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    new_kwargs["dim"] = _wrap_jagged_dim(
        inp.dim(), new_kwargs["dim"], "chunk", allow_batch_dim=True
    )

    if new_kwargs["dim"] == 0:
        chunks = new_kwargs["chunks"]

        # get _offsets of the chunks
        lengths = inp._offsets.diff()
        chunked_lengths = lengths.chunk(chunks)
        chunked_offsets = [torch.cumsum(x, dim=0) for x in chunked_lengths]
        chunked_offsets = [F.pad(x, (1, 0), value=0) for x in chunked_offsets]  # type: ignore[arg-type]
        nested_kwargs = [
            {"offsets": per_offsets, "_ragged_idx": inp._ragged_idx}
            for per_offsets in chunked_offsets
        ]

        # get _values of the chunks
        split_sizes = [x.sum().item() for x in chunked_lengths]
        chunk_values = inp._values.split(split_sizes)

        # NB: tensor.chunk(chunks) may return fewer than `chunks` pieces, so
        # iterate over what was actually produced rather than `chunks` itself
        return [
            NestedTensor(values=chunk_values[i], **(nested_kwargs[i]))
            for i in range(len(chunked_lengths))
        ]
    else:
        return [
            NestedTensor(values=x, **extract_kwargs(inp))
            for x in func(inp._values, **new_kwargs)
        ]
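# Worked example (illustrative): chunking a (B=3, j0, D) NJT with offsets
# [0, 3, 5, 8] on dim=0 with chunks=2 gives lengths [3, 2, 3], chunked as
# ([3, 2], [3]); the rebuilt per-chunk offsets are [0, 3, 5] and [0, 3], and
# _values is split into pieces of 5 and 3 rows, yielding two smaller NJTs.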


@register_jagged_func(torch.ops.aten.unbind.int, "self: jt_all, dim: any?")
def unbind_int(func, *args, **kwargs):
    # Note that this specializes on the length of the offsets
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    dim = new_kwargs["dim"]
    if dim != 0:
        raise RuntimeError("unbind(): only supported for NestedTensor on dim=0")

    inp = new_kwargs.pop("input")
    values = inp.values()
    offsets = inp.offsets()
    lengths = inp.lengths()
    ragged_idx = inp._ragged_idx

    if lengths is None:
        return torch.split(values, offsets.diff().tolist(), dim=(ragged_idx - 1))

    if ragged_idx <= 0:
        raise RuntimeError(
            "unbind(): nested tensor ragged_idx out of bounds (should be >= 1)"
        )
    for i in range(lengths.shape[0]):
        if offsets[i] + lengths[i] > values.shape[ragged_idx - 1]:
            raise RuntimeError(
                "unbind(): nested tensor offsets and lengths do not match ragged_idx dimension"
            )
    return [
        torch.narrow(values, dim=(ragged_idx - 1), start=offsets[i], length=lengths[i])
        for i in range(lengths.shape[0])
    ]
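# Worked example (illustrative): for values of shape (8, D), offsets
# [0, 3, 5, 8], and lengths=None, unbind returns views of shapes (3, D),
# (2, D), (3, D) via torch.split; with lengths=[2, 2, 3] (an NJT with holes),
# torch.narrow instead extracts rows [0:2], [3:5], and [5:8].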


@register_jagged_func(torch.ops.aten.squeeze.dim, "self: jt, dim: any")
def squeeze_dim(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    values = inp._values

    new_kwargs["dim"] = _wrap_jagged_dim(len(inp._size), new_kwargs["dim"], "squeeze")
    return NestedTensor(func(values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(torch.ops.aten.unsqueeze.default, "self: jt, dim: any")
def unsqueeze_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    values = inp._values

    # Account for collapsed jagged dim
    dim = new_kwargs["dim"]
    new_kwargs["dim"] = _wrap_jagged_dim(len(inp._size) + 1, dim, "unsqueeze")
    return NestedTensor(func(values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(torch.ops.aten.cat.default, "tensors: any, dim: any")
def cat_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    tensors = new_kwargs.pop("tensors")

    # Convert any non-nested to nested
    nested = [t for t in tensors if t.is_nested]
    assert len(nested) > 0
    first = nested[0]
    tensors = [t if t.is_nested else t.expand_as(first) for t in tensors]

    # Account for collapsed jagged dim
    dim = new_kwargs["dim"]
    new_kwargs["dim"] = _wrap_jagged_dim(len(first.shape), dim, "cat")

    return NestedTensor(
        func([t._values for t in tensors], **new_kwargs), **extract_kwargs(tensors[0])
    )


@register_jagged_func(torch.ops.aten.matmul.default, "self: jt, other: any")
def matmul_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    other = new_kwargs.pop("other")

    if inp.is_nested and not other.is_nested:
        return NestedTensor(
            func(inp._values, other, **new_kwargs), **extract_kwargs(inp)
        )
    elif inp.is_nested and other.is_nested:
        # BMM with equivalent ragged dims between the two inputs
        if inp.dim() > 3 and other.dim() > 3 and raggedness_matches(inp, other._size):
            return NestedTensor(func(inp._values, other._values), **extract_kwargs(inp))

    raise RuntimeError(
        f"matmul(): not supported between inputs of shapes {inp._size} and {other.shape}"
    )


@register_jagged_func(
    torch.ops.aten.expand.default, "self: jt, size: any, implicit: any?"
)
def expand_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    size = new_kwargs["size"]

    assert ("implicit" not in new_kwargs) or (not new_kwargs.pop("implicit"))
    if not raggedness_matches(inp, size):
        raise RuntimeError(f"expand(): cannot expand shape {inp._size} -> {size}")

    expand_arg = [-1, *size[2:]]
    return NestedTensor(func(inp._values, expand_arg), **extract_kwargs(inp))


@register_jagged_func(torch.ops.aten.expand_as.default, "self: t, other: jt")
def expand_as_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    other = new_kwargs.pop("other")

    return NestedTensor(func(inp, other._values), **extract_kwargs(other))


@register_jagged_func(torch.ops.aten.where.self, "condition: jt, self: jt, other: jt")
def where_self(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    condition = new_kwargs.pop("condition")
    inp = new_kwargs.pop("input")
    other = new_kwargs.pop("other")

    assert condition._size == other._size == inp._size

    return NestedTensor(
        func(condition._values, inp._values, other._values, **new_kwargs),
        **extract_kwargs(condition),
    )


@register_jagged_func(torch.ops.aten._pin_memory.default, "self: jt, device: any?")
def _pin_memory_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(torch.ops.aten.is_pinned.default, "self: jt, device: any?")
def is_pinned_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    return func(inp._values, **new_kwargs)


@register_jagged_func(
    torch.ops.aten.is_same_size.default, "self: jt_all, other: jt_all"
)
def is_same_size_default(func, *args, **kwargs):
    return args[0]._size == args[1]._size


@register_jagged_func(
    torch.ops.aten.sum.dim_IntList,
    "self: jt_all, dim: any?, keepdim: any?, dtype: any?",
)
def sum_dim_IntList(func, *args, **kwargs):
    """
    Performs a sum along the provided tensor dimension.
    Returns a dense tensor if the ragged dimension is reduced away, else returns a nested tensor.
    """
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")

    (
        new_kwargs["dim"],
        reduce_on_batch,
        reduce_on_ragged,
        reduce_on_non_batch,
    ) = _wrap_jagged_dims(
        inp.dim(),
        new_kwargs["dim"],
        "sum",
        inp._ragged_idx,
    )

    if reduce_on_ragged and inp._lengths is not None:
        raise RuntimeError(
            "sum(): not supported where lengths is not None "
            + "if reducing across the ragged dimension for NestedTensor"
        )

    if reduce_on_ragged:  # raggedness reduced away --> return dense tensor
        if (
            reduce_on_batch
        ):  # reduction cases: (batch, ragged), (batch, ragged, non-batch), etc.
            out = func(
                inp._values, **new_kwargs
            )  # no need to read offsets --> apply sum directly on values
        else:
            if (
                reduce_on_non_batch
            ):  # invalid reduction cases: (ragged, non-batch), etc.
                raise RuntimeError(
                    "sum(): not supported along a ragged and non-batch dimension for NestedTensor"
                )
            # reduction cases: (ragged)
            values_ragged_dim_outer = inp._values.permute(
                inp._ragged_idx - 1,  # outer dimension
                *range(0, inp._ragged_idx - 1),
                *range(inp._ragged_idx, inp.dim() - 1),
            )  # shift reduction dimension of values backward to outer dimension

            # _jagged_to_padded_dense_forward requires values to be a 2D tensor
            # with the ragged dimension as the 0th dimension
            padded = torch.ops.aten._jagged_to_padded_dense_forward(
                values_ragged_dim_outer.reshape(values_ragged_dim_outer.shape[0], -1),
                [inp._offsets],
                max_lengths=[inp._max_seqlen],
            )

            padded_ragged_dim_original = padded.view(
                padded.shape[0],
                inp._max_seqlen,
                *values_ragged_dim_outer.shape[
                    1:
                ],  # expand non-batch dimensions of padded tensor
            ).permute(
                0,
                *range(2, inp._ragged_idx + 1),
                1,
                *range(inp._ragged_idx + 1, inp.dim()),
            )  # shift reduction dimension of padded tensor forward to original ragged dimension

            out = torch.sum(
                padded_ragged_dim_original,
                dim=inp._ragged_idx,
            )  # need to read offsets --> pad jagged dimension and apply sum

        if new_kwargs["keepdim"]:
            # TODO: fix this bug: we should be unsqueezing on ragged_idx, not dim 0
            out = out.unsqueeze(0)
        return out
    else:  # raggedness preserved --> return nested tensor
        if (
            reduce_on_batch
        ):  # invalid reduction cases: (batch), (batch, non-batch), etc.
            raise RuntimeError(
                "sum(): reducing along the batch dimension without also reducing "
                "the ragged dimension is not supported for NestedTensor"
            )
        # reduction cases: (non-batch), (non-batch, non-batch), etc.
        return NestedTensor(
            func(inp._values, **new_kwargs), **extract_kwargs(inp)
        )  # apply sum directly on values
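# Worked example (illustrative): summing a (B, j0, D) NJT over dim=1 (the
# ragged dim only) pads the values to a dense (B, max_seqlen, D) with zeros
# via _jagged_to_padded_dense_forward, then sums over dim 1 to get a dense
# (B, D); zero padding leaves the per-row sums unchanged.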


@register_jagged_func(
    torch.ops.aten.transpose.int, "self: jt_all, dim0: any, dim1: any"
)
def transpose_int(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    from torch._prims_common import canonicalize_dims

    inp = new_kwargs.pop("input")
    dim0, dim1 = canonicalize_dims(inp.dim(), (new_kwargs["dim0"], new_kwargs["dim1"]))

    if inp._lengths is not None:
        raise ValueError(
            "transpose(): not supported on jagged layout nested tensor with holes"
        )

    # To support the SDPA API, inputs need to have the ragged idx transposed to dim 2
    # instead of 1, although the internal Flash and mem-efficient implementations will
    # use the inputs with raggedness in dim 1.
    if dim0 == inp._ragged_idx or dim1 == inp._ragged_idx:
        if dim0 == 0 or dim1 == 0:
            raise ValueError(
                "Transpose is not supported on the batch dimension for jagged NT"
            )
        if dim0 == inp._ragged_idx:
            to_dim = dim1
        else:
            to_dim = dim0
        inp_kwargs = extract_kwargs(inp)
        inp_kwargs["_ragged_idx"] = to_dim
        return NestedTensor(
            inp.values().transpose(
                _outer_to_inner_dim(len(inp._size), dim0),
                _outer_to_inner_dim(len(inp._size), dim1),
            ),
            **inp_kwargs,
        )

    new_kwargs["dim0"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim0"], "transpose")
    new_kwargs["dim1"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim1"], "transpose")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
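# Example (illustrative): transposing a (B, j0, D) NJT with dim0=1, dim1=2
# swaps inner dims 0 and 1 of the values and records _ragged_idx=2, producing
# an NJT of shape (B, D, j0) as expected by the SDPA frontend above.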


@register_jagged_func(torch.ops.aten.permute.default, "self: jt_all, dims: any")
def permute_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")
    dims = new_kwargs.pop("dims")
    inp_kwargs = extract_kwargs(inp)
    inp_dim = len(inp._size)

    # The first two checks are the same as the checks in the normal permute implementation
    if inp_dim != len(dims):
        raise ValueError(
            f"permute(): number of dimensions in the tensor input ({inp_dim}) "
            + f"does not match the length of the desired ordering of dimensions ({len(dims)}).",
        )

    from torch._prims_common import canonicalize_dims

    canonicalized_dims = canonicalize_dims(inp_dim, dims)

    if len(canonicalized_dims) != len(set(canonicalized_dims)):
        raise ValueError("permute(): duplicate dims are not allowed.")

    if inp._lengths is not None:
        raise ValueError(
            "permute(): not supported on jagged layout nested tensor with holes"
        )
    if canonicalized_dims[0] != 0:
        raise ValueError(
            "Permute is not supported on the batch dimension for jagged NT"
        )
    inp_kwargs["_ragged_idx"] = canonicalized_dims.index(inp._ragged_idx)
    inner_dims = [_outer_to_inner_dim(inp_dim, dim) for dim in canonicalized_dims[1:]]
    new_kwargs["dims"] = inner_dims
    return NestedTensor(func(inp._values, **new_kwargs), **inp_kwargs)


@register_jagged_func(
    [torch.ops.aten.view.default, torch.ops.aten._unsafe_view.default],
    "self: jt_all, size: any",
)
def view_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    size = new_kwargs.pop("size")

    if inp._ragged_idx != 1 and tuple(inp._size) != tuple(size):
        raise RuntimeError(
            f"view(): does not support ragged_idx != 1 except when inp._size == size. "
            f"inp._size is ({inp._size}) and size is ({size})."
        )

    # Ensure specified size still includes batch and ragged dims
    if len(size) < 3 or not raggedness_matches(inp, size):
        raise RuntimeError(f"view(): cannot view shape {inp._size} as {size}")

    # outer size: the size of the NT, e.g. [3, j0, 10]
    # inner size: the size of the values, e.g. [8, 10] (e.g. for offsets = [0, 3, 5, 8])
    # this function gets inner_size[inner_idx] for a given inner_idx.
    #
    # example: for outer size [a, b, c, j0, d, e, f]
    #                         assume that j0 is ragged, others are concrete integers
    #                         and ragged_idx=3
    # inner size will be      [b, c, inp._values.size(ragged_idx), d, e, f]
    # therefore:
    #    inner_size[0] = outer_size[1]
    #    inner_size[1] = outer_size[2]
    #    inner_size[2] = inp._values.size(ragged_idx - 1)
    #    inner_size[3] = outer_size[4]
    #    inner_size[4] = outer_size[5]
    def get_inner_size(inner_idx):
        nonlocal inp, size
        if inner_idx == inp._ragged_idx - 1:
            return inp._values.size(inner_idx)
        else:
            return size[inner_idx + 1]

    inner_size = [get_inner_size(i) for i in range(len(size) - 1)]

    return NestedTensor(func(inp._values, inner_size), **extract_kwargs(inp))


@register_jagged_func(
    torch.ops.aten.native_layer_norm.default,
    "input: jt_all, normalized_shape: any, weight: any?, bias: any?, eps: any",
)
def native_layer_norm_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    if inp.dim() <= 2:
        raise RuntimeError(
            "layer_norm(): not supported for NestedTensor objects with 2 or fewer dimensions"
        )

    normalized_shape = new_kwargs["normalized_shape"]
    ragged_size = inp.shape[inp._ragged_idx]

    num_dims_not_normalized = inp.dim() - len(normalized_shape)

    if (
        num_dims_not_normalized == 0
    ):  # error if trying to normalize over the batch dimension
        raise RuntimeError(
            "layer_norm(): not supported when normalizing over the batch dimension for NestedTensor"
        )

    if ragged_size in normalized_shape and inp._lengths is not None:
        raise RuntimeError(
            "layer_norm(): not supported where lengths is not None if operating on the ragged dimension for NestedTensor"
        )

    if (
        ragged_size in normalized_shape
    ):  # special handling for normalizing over the ragged dimension
        padded_input = torch.ops.aten._jagged_to_padded_dense_forward(
            inp._values.flatten(
                start_dim=inp._ragged_idx
            ),  # _jagged_to_padded_dense_forward requires values to be a 2D tensor
            [inp._offsets],
            max_lengths=[inp._max_seqlen],  # max length of ragged dimension
        )

        padded_mask = torch.ops.aten._jagged_to_padded_dense_forward(
            torch.ones((inp._values.shape[0], 1), device=inp.device, dtype=inp.dtype),
            [inp._offsets],
            max_lengths=[inp._max_seqlen],  # max length of ragged dimension
        ).expand(
            padded_input.shape
        )  # mask elements outside of the ragged dimension and expand to the same shape as padded input (3D dense tensor)

        ragged_lengths = (
            inp._offsets.diff().unsqueeze(1).unsqueeze(1) * padded_input.shape[2]
        )  # ragged dim * inner dim, since we sum over dims (1, 2) (the layer on which we normalize)

        mean = (
            torch.sum(
                padded_input,
                dim=(1, 2),
                keepdim=True,
            )
            / ragged_lengths
        )  # a sum over (1, 2) ensures layer norm, whereas a sum over (1) would be an instance norm

        padded_normalized = (
            padded_input - mean
        ) * padded_mask  # mask elements outside of the ragged dimension size for correct variance calculation

        variance = (
            torch.sum(
                torch.square(padded_normalized),
                dim=(1, 2),
                keepdim=True,
            )
            / ragged_lengths
        )  # a sum over (1, 2) ensures layer norm, whereas a sum over (1) would be an instance norm

        std = torch.sqrt(variance + new_kwargs["eps"])
        padded_layer_norm = padded_normalized / std

        jagged_layer_norm_values = torch.ops.aten._padded_dense_to_jagged_forward(
            padded_layer_norm,
            [inp._offsets],
            total_L=inp._values.shape[
                0
            ],  # providing this parameter helps avoid a GPU/CPU sync
        ).unflatten(
            -1, inp.shape[inp._ragged_idx + 1 :]
        )  # unflatten last dimension back into original nested tensor shape, e.g. (B, *, WH) --> (B, *, W, H)

        return (
            NestedTensor(jagged_layer_norm_values, **extract_kwargs(inp)),
            mean,
            std,
        )

    output, mean, std = func(inp._values, **new_kwargs)
    return (NestedTensor(output, **extract_kwargs(inp)), mean, std)


@register_jagged_func(
    torch.ops.aten.native_layer_norm_backward.default,
    "grad_out: jt, input: jt, normalized_shape: any, mean: any, rstd: any, weight: any?, bias: any?, output_mask: any",
)
def native_layer_norm_backward_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    grad_out = new_kwargs.pop("grad_out")
    inp = new_kwargs.pop("input")
    d_input, d_gamma, d_beta = func(grad_out._values, inp._values, **new_kwargs)
    if d_input is None:
        return (None, d_gamma, d_beta)

    return (NestedTensor(d_input, **extract_kwargs(inp)), d_gamma, d_beta)


@register_jagged_func(torch.ops.aten.select.int, "self: jt, dim: any, index: any")
def select_int(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    new_kwargs["dim"] = _wrap_jagged_dim(
        inp.dim(), new_kwargs["dim"], "select", allow_batch_dim=True
    )

    # handle batch dim slicing via unbind() for now
    # TODO: make this more efficient
    if new_kwargs["dim"] == 0:
        return inp.unbind()[new_kwargs["index"]]

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(
    torch.ops.aten.slice.Tensor,
    "self: jt, dim: any?, start: any?, end: any?, step: any?",
)
def slice_tensor(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    new_kwargs["dim"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim"], "slice")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(
    torch.ops.aten.convolution.default,
    "input: jt, weight: t, bias: t?, stride: any, padding: any, "
    "dilation: any, transposed: any, output_padding: any, groups: any",
)
def convolution_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(
    torch.ops.aten.mean.dim, "self: jt_all, dim: any?, keepdim: any?, dtype: any?"
)
def mean_dim(func, *args, **kwargs):
    """
    Performs a mean along the provided tensor dimension.
    Returns a dense tensor if the ragged dimension is reduced away, else returns a nested tensor.
    """
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    if len(new_kwargs["dim"]) > 1:
        raise RuntimeError(
            "mean(): not supported across multiple dimensions for NestedTensor"
        )

    inp = new_kwargs.pop("input")

    (
        new_kwargs["dim"],
        reduce_on_batch,
        reduce_on_ragged,
        reduce_on_non_batch,
    ) = _wrap_jagged_dims(
        inp.dim(),
        new_kwargs["dim"],
        "mean",
        inp._ragged_idx,
    )

    if reduce_on_batch:
        raise RuntimeError(
            "mean(): reducing along the batch dimension without also reducing "
            "the ragged dimension is not supported for NestedTensor"
        )
1450
1451    if reduce_on_ragged and inp._lengths is not None:
1452        raise RuntimeError(
1453            "mean(): not supported where lengths is not None "
1454            + "if reducing across the ragged dimension for NestedTensor"
1455        )
1456
1457    if not new_kwargs["keepdim"]:
1458        raise RuntimeError("mean(): not supported when keepdim=False for NestedTensor")
1459
1460    if reduce_on_ragged:  # raggedness reduced away
1461        torch_sum = torch.sum(inp, dim=inp._ragged_idx, keepdim=new_kwargs["keepdim"])
1462
1463        # for every non-batch dimension,
1464        #   unsqueeze lengths into the same shape as the PyTorch sum,
1465        #   as the extra dimensions must all be divided by the same length
1466        lengths = inp._offsets.diff()
1467        for _ in range(inp.dim() - 2):
1468            lengths = lengths.unsqueeze(-1)
1469
1470        return torch_sum / lengths.broadcast_to(torch_sum.shape)
1471
1472    return NestedTensor(
1473        func(inp._values, **new_kwargs), **extract_kwargs(inp)
1474    )  # raggedness preserved
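
# Illustrative sketch (hypothetical values): reducing over the ragged dim
# computes sum(nt, dim=ragged_idx) / lengths, where lengths = offsets.diff(),
# and the result is a dense tensor since the raggedness is reduced away.
#
#   nt = torch.nested.nested_tensor(
#       [torch.randn(2, 5), torch.randn(3, 5)], layout=torch.jagged
#   )
#   nt.mean(dim=1, keepdim=True)  # dense tensor of shape (2, 1, 5)
#   nt.mean(dim=2, keepdim=True)  # jagged NestedTensor of shape (2, j1, 1)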


@register_jagged_func(torch.ops.aten.stack.default, "tensors: any, dim: any")
def stack_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    # guaranteed this is non-empty if we got here
    tensors = new_kwargs.pop("tensors")
    for t in tensors:
        if not isinstance(t, NestedTensor):
            raise RuntimeError("stack(): expected all inputs to be nested tensors")

        if t.dim() != tensors[0].dim():
            raise RuntimeError(
                "stack(): expected all nested tensors to have the same dim"
            )

        if not raggedness_matches(t, tensors[0].shape):
            raise RuntimeError(
                "stack(): expected all nested tensors to have the same nested structure"
            )

    new_kwargs["dim"] = _wrap_jagged_dim(
        tensors[0].dim() + 1, new_kwargs["dim"], "stack"
    )

    return NestedTensor(
        func([t._values for t in tensors], **new_kwargs), **extract_kwargs(tensors[0])
    )
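
# Illustrative sketch (hypothetical values): all inputs must share the same
# nested structure; the stack runs on the values buffers, with dim wrapped
# against ndim + 1 to account for the newly added dimension.
#
#   a = torch.nested.nested_tensor(
#       [torch.randn(2, 5), torch.randn(3, 5)], layout=torch.jagged
#   )
#   b = torch.nested.nested_tensor_from_jagged(
#       torch.randn_like(a.values()), a.offsets()
#   )
#   torch.stack([a, b], dim=2)  # jagged NestedTensor of shape (2, j1, 2, 5)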


@register_jagged_func(
    torch.ops.aten.embedding.default,
    "weight: t, indices: jt, padding_idx: any?, scale_grad_by_freq: any?, sparse: any?",
)
def embedding_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    # guaranteed this is non-empty if we got here
    indices = new_kwargs.pop("indices")
    weight = new_kwargs.pop("weight")

    return NestedTensor(
        func(weight, indices._values, **new_kwargs), **extract_kwargs(indices)
    )
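
# Illustrative sketch (hypothetical values): the lookup runs on the flattened
# indices buffer, so each component is embedded independently.
#
#   weight = torch.randn(10, 4)
#   indices = torch.nested.nested_tensor(
#       [torch.tensor([0, 1]), torch.tensor([2, 3, 4])], layout=torch.jagged
#   )
#   F.embedding(indices, weight)  # jagged NestedTensor of shape (2, j1, 4)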


@register_jagged_func(
    [
        torch.ops.aten.values.default,
        torch.ops.aten._nested_get_values.default,
    ],
    "self: jt_all",
)
def values_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    # TODO: Handle inference mode properly.
    # See https://github.com/pytorch/pytorch/issues/112024#issuecomment-1779554292
    return inp._values.detach()


@register_jagged_func(torch.ops.aten.all.default, "self: jt_all")
def all_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    return func(inp._values)


@register_jagged_func(
    torch.ops.aten._nested_view_from_jagged.default,
    "values: t, offsets: t, dummy: jt_all, lengths: t?, ragged_idx: any?, min_seqlen: t?, max_seqlen: t?",
)
def _nested_view_from_jagged_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    values, offsets, lengths = (
        new_kwargs["input"],
        new_kwargs["offsets"],
        new_kwargs["lengths"],
    )
    ragged_idx = new_kwargs["ragged_idx"]
    min_seqlen = new_kwargs["min_seqlen"]
    max_seqlen = new_kwargs["max_seqlen"]
    metadata_cache = {}
    if min_seqlen is not None:
        metadata_cache["min_seqlen"] = min_seqlen
    if max_seqlen is not None:
        metadata_cache["max_seqlen"] = max_seqlen

    return NestedTensor(
        values,
        offsets,
        lengths=lengths,
        _ragged_idx=ragged_idx,
        _metadata_cache=metadata_cache,
    )
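
# Illustrative sketch (hypothetical values): this is the view op underlying
# torch.nested.nested_tensor_from_jagged(), which rewraps a dense values
# buffer plus offsets into a jagged NestedTensor.
#
#   values = torch.randn(5, 4)  # 2 + 3 = 5 total rows
#   offsets = torch.tensor([0, 2, 5])
#   nt = torch.nested.nested_tensor_from_jagged(values, offsets)
#   nt.shape  # (2, j1, 4)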


@register_jagged_func(torch.ops.aten._nested_get_offsets.default, "self: jt_all")
def _nested_get_offsets(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    return inp._offsets


@register_jagged_func(torch.ops.aten._nested_get_lengths.default, "self: jt_all")
def _nested_get_lengths(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    return inp._lengths


@register_jagged_func(torch.ops.aten._nested_get_ragged_idx.default, "self: jt_all")
def _nested_get_ragged_idx(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    return inp._ragged_idx


@register_jagged_func(torch.ops.aten._nested_get_min_seqlen.default, "self: jt_all")
def _nested_get_min_seqlen(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    return inp._metadata_cache.get("min_seqlen", None)


@register_jagged_func(torch.ops.aten._nested_get_max_seqlen.default, "self: jt_all")
def _nested_get_max_seqlen(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    return inp._metadata_cache.get("max_seqlen", None)


# If a section of the NestedTensor is fully masked out, we still retain the
# section with a length of 0
@register_jagged_func(torch.ops.aten.masked_select.default, "self: jt, mask: any")
def masked_select_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")
    mask = new_kwargs.pop("mask")

    if inp.ndim > 2:
        raise RuntimeError("masked_select(): only supported for 2-D inputs currently")
    elif inp.shape != mask.shape:
        raise RuntimeError(
            f"Mask with shape {mask.shape} is not compatible with input's shape {inp.shape}"
        )
    res_values = inp._values.masked_select(mask.values())
    mask_cumsum = F.pad(mask.values().cumsum(dim=0), (1, 0))  # type: ignore[arg-type]

    args = extract_kwargs(inp)
    args["offsets"] = mask_cumsum[inp._offsets]
    return NestedTensor(
        values=res_values,
        **args,
    )
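
# Illustrative sketch (hypothetical values): the new offsets come from indexing
# the cumulative mask counts at the old offsets, so fully masked-out components
# are kept as length-0 components.
#
#   nt = torch.nested.nested_tensor(
#       [torch.tensor([1, 2]), torch.tensor([3, 4, 5])], layout=torch.jagged
#   )
#   mask = nt > 2
#   torch.masked_select(nt, mask)  # components: [] and [3, 4, 5]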


# Make the dummy available on the C++ side.
@register_jagged_func(torch.ops.aten._nested_get_jagged_dummy.default, "self: any")
def _nested_get_jagged_dummy(func, *args, **kwargs):
    from torch.nested._internal.nested_tensor import _nt_view_dummy

    return _nt_view_dummy()


with torch.library._scoped_library("aten", "IMPL") as aten:
    aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "CPU")
    aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "CUDA")
    aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "Meta")