1# mypy: allow-untyped-defs 2import functools 3import math 4import operator 5from typing import * # noqa: F403 6 7import torch 8import torch.nn.functional as F 9from torch.fx.operator_schemas import normalize_function 10from torch.nested._internal.sdpa import jagged_scaled_dot_product_attention 11 12from .nested_tensor import NestedTensor 13 14 15__all__: List[Any] = [] 16 17JAGGED_OPS_TABLE: Dict[Any, Any] = {} 18 19 20# Simplifying assumption: we assume that the batch dim is always the left-most 21# dim, and the ragged dim is always the second dim. 22def _outer_to_inner_dim(ndim, dim): 23 assert dim >= 0 and dim < ndim 24 return 0 if dim < 2 else dim - 1 25 26 27def _wrap_jagged_dim( 28 ndim, dim, op_name, convert_to_inner_dim=True, allow_batch_dim=False 29): 30 from torch._prims_common import canonicalize_dims 31 32 wrapped = canonicalize_dims(ndim, dim) 33 if wrapped == 1: 34 raise RuntimeError(f"{op_name}(): not supported for NestedTensor on dim=1") 35 elif wrapped == 0 and not allow_batch_dim: 36 raise RuntimeError(f"{op_name}(): not supported for NestedTensor on dim=0") 37 return _outer_to_inner_dim(ndim, wrapped) if convert_to_inner_dim else wrapped 38 39 40def _wrap_jagged_dims(ndim, dims, op_name, ragged_idx=1): 41 """ 42 For NestedTensor operators, 43 wraps dimensions to non-negative values, 44 and returns metadata related to reduction dimension(s). 
    Returns a tuple ``(inner_dims, operate_on_batch, operate_on_ragged,
    operate_on_non_batch)`` where ``inner_dims`` are the values-space dims
    with the batch dim dropped.
    """
    from torch._prims_common import canonicalize_dims

    assert isinstance(
        dims, (tuple, list)
    ), f"_wrap_jagged_dims(): cannot iterate over dimensions of type {type(dims)}"

    wrapped_dims = [
        canonicalize_dims(ndim, d) for d in dims
    ]  # convert all indices to non-negative values

    # Classify which kinds of dims are being operated on.
    operate_on_batch = 0 in wrapped_dims
    operate_on_ragged = ragged_idx in wrapped_dims
    operate_on_non_batch = any(d != 0 and d != ragged_idx for d in wrapped_dims)

    # Drop the batch dim (it has no inner equivalent) and convert the rest to
    # values-space dims.
    outer_to_inner_dim = tuple(
        _outer_to_inner_dim(ndim, d) for d in wrapped_dims if d != 0
    )

    return outer_to_inner_dim, operate_on_batch, operate_on_ragged, operate_on_non_batch


def check_schema(schema_str: str, func, *args, **kwargs) -> None:
    """Validate ``args`` against a lightweight schema string such as
    ``"self: jt, dim: any, keepdim: any?"``.

    Arg types: ``t`` (dense tensor), ``jt`` (contiguous jagged NT), ``jt_all``
    (any jagged NT), ``any`` (unchecked). A trailing ``?`` marks the arg
    optional; a final ``...`` entry allows unchecked extra args.

    Raises ValueError when validation fails.
    """
    named_arg_types = schema_str.split(", ")
    num_optional_args = [x.endswith("?") for x in named_arg_types].count(True)
    min_args = len(named_arg_types) - num_optional_args

    # special case: ellipses allows for any number of unchecked args at the end
    if named_arg_types[-1] == "...":
        named_arg_types = named_arg_types[:-1]
    else:
        if not (len(args) >= min_args and len(args) <= len(named_arg_types)):
            raise ValueError(
                f"NestedTensor {func.__name__}({schema_str}): expected at least {min_args} "
                f"arguments and at most {len(named_arg_types)} arguments, but got: "
                f"{len(args)} arguments"
            )

    arg_type_check_fns = {
        "t": lambda x: isinstance(x, torch.Tensor) and not isinstance(x, NestedTensor),
        "jt": lambda x: isinstance(x, NestedTensor)
        and x._lengths is None
        and x._ragged_idx == 1,  # ops with "jt" require contiguous JT only
        "jt_all": lambda x: isinstance(
            x, NestedTensor
        ),  # ops with "jt_all" can accept all kinds of JT
        "any": lambda x: True,
    }
    for i, named_arg_type in enumerate(named_arg_types):
        name, arg_type = named_arg_type.split(": ")
        is_optional = arg_type.endswith("?")
        normalized_arg_type = arg_type[:-1] if is_optional else arg_type
        if normalized_arg_type not in arg_type_check_fns.keys():
            raise AssertionError(f"Unknown arg type: {normalized_arg_type}")

        if i >= len(args):
            if not is_optional:
                raise ValueError(
                    f"NestedTensor {func.__name__}({schema_str}) "
                    f"missing required argument: {name}"
                )
            continue

        _check_fn = arg_type_check_fns[normalized_arg_type]

        # Default arg binds is_optional at definition time (avoids the
        # late-binding closure pitfall).
        def check_fn(x, is_optional=is_optional):
            if is_optional:
                return x is None or _check_fn(x)
            else:
                return _check_fn(x)

        if not check_fn(args[i]):
            type_to_desc = {
                "t": "tensor",
                "t?": "optional tensor",
                "jt": "contiguous jagged layout NestedTensor",
                "jt_all": "jagged layout NestedTensor",
                "any": "<any type>",
            }

            raise ValueError(
                f"NestedTensor {func.__name__}({schema_str}): expected {name} to be a "
                f"{type_to_desc[arg_type]}"
            )


def check_ragged_dim_same(
    func, a: NestedTensor, a_name: str, b: NestedTensor, b_name: str
) -> None:
    """Raise unless ``a`` and ``b`` agree on their (symbolic) ragged size."""
    # Calling into .shape here
    if a._size[a._ragged_idx] != b._size[b._ragged_idx]:
        raise RuntimeError(
            f"NestedTensor {func.__name__}: expected {a_name} and {b_name} to have the "
            "same exact offsets tensor."
        )


# returns True if the raggedness-relevant portions of the NT shape
# match those of the specified size (-1 entries in `size` match anything)
def raggedness_matches(nt, size):
    end = nt._ragged_idx + 1
    nt_ragged = nt._size[:end]
    size_ragged = size[:end]
    return len(nt_ragged) == len(size_ragged) and (
        all(ns == s or s == -1 for ns, s in zip(nt_ragged, size_ragged))
    )


def squeeze_leading_ones(t):
    # Note: [ Squeezing leading ones ]
    #
    # Squeeze leading ones from t.
    #
    # We want:
    #   (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?)
    #   (B, j0, ?, ?) + (1, 1, 1, ?, ?) -> (1, B, j0, ?, ?)  (not yet supported)
    #
    # 1) Squeeze extra ones and grab values from NT
    #   (1, 1, ?, ?) -> (?, ?)   and   (sum(*), ?, ?) -> (B, j0, ?, ?)
164 # 2) Do dense broadcasting: 165 # (sum(*), ?, ?) + (?, ?) -> (sum(*), ?, ?) 166 # 3) Construct nested tensor 167 # (sum(*), ?, ?) -> (B, j0, ?, ?) 168 # 169 # If unsqueezing on the 0th dim becomes supported, we would unsqueeze 170 # at step (4) and we would need to update this function to record how 171 # many ones we unsqueezed. 172 while t.dim() > 0 and t.shape[0] == 1: 173 t = t.squeeze(0) 174 return t 175 176 177def register_func(tables, aten_ops, schema_str): 178 if not isinstance(aten_ops, list): 179 aten_ops = [aten_ops] 180 if not isinstance(tables, list): 181 tables = [tables] 182 183 def wrapper(func): 184 for aten_op in aten_ops: 185 186 def get_inner(aten_op): 187 def inner(*args, **kwargs): 188 check_schema(schema_str, func, *args, **kwargs) 189 return func(aten_op, *args, **kwargs) 190 191 return inner 192 193 for table in tables: 194 table[aten_op] = get_inner(aten_op) 195 return func 196 197 return wrapper 198 199 200register_jagged_func = functools.partial(register_func, JAGGED_OPS_TABLE) 201 202 203def lookup_jagged(func, *args, **kwargs) -> Optional[Callable]: 204 dispatch_func = JAGGED_OPS_TABLE.get(func, None) 205 if dispatch_func is not None: 206 return dispatch_func 207 208 # Handle pointwise fallbacks 209 if torch.Tag.pointwise in func.tags: 210 # Assume there aren't additional tensors that aren't the "unary/binary" args 211 num_tensor_args = sum(isinstance(x, torch.Tensor) for x in args) 212 if num_tensor_args == 1: 213 # Build up the check schema string. The first tensor arg is assumed to be 214 # an NJT and other args are sent through as-is. 
            schema_parts = []
            for arg in func._schema.arguments:
                if isinstance(arg.type, torch.TensorType):
                    schema_parts.append(f"{arg.name}: jt_all")
                    break
                else:
                    schema_parts.append(f"{arg.name}: any")
            schema_parts.append("...")
            check_schema_str = ", ".join(schema_parts)
            check_schema(check_schema_str, func, *args, **kwargs)
            return functools.partial(jagged_unary_pointwise, func)
        elif num_tensor_args == 2:
            check_schema("lhs: any, rhs: any, ...", func, *args, **kwargs)
            return functools.partial(jagged_binary_pointwise, func)

    return None


def extract_kwargs(arg):
    """Collect the NestedTensor-constructor metadata needed for an output NT
    sharing ``arg``'s offsets / metadata cache / ragged index."""
    kwargs = {
        "offsets": arg.offsets(),
        "_metadata_cache": arg._metadata_cache,
        "_ragged_idx": arg._ragged_idx,
    }
    return kwargs


def jagged_unary_pointwise(func, *args, **kwargs):
    """Apply a unary pointwise op to the NT's values buffer and rewrap."""
    # assume if we get here that there is a single NJT input in the args
    njt = next(arg for arg in args if isinstance(arg, NestedTensor))
    return NestedTensor(
        func(*(arg._values if arg is njt else arg for arg in args), **kwargs),
        **extract_kwargs(njt),
    )


def jagged_binary_pointwise(func, *args, **kwargs):
    """Apply a binary pointwise op where at least one input is an NT,
    broadcasting a dense input across batch/ragged dims as needed."""
    a, b = args[0], args[1]
    assert isinstance(a, NestedTensor) or isinstance(b, NestedTensor)

    mismatch_error_msg = (
        "cannot call binary pointwise function {} with inputs of shapes {} and {}"
    )
    # a is NT, b is NT
    if isinstance(a, NestedTensor) and isinstance(b, NestedTensor):
        # ex: (B, j0, D) + (B, j0, D)
        # ex: (B, j0, D) + (B, j0, 1)
        if raggedness_matches(a, b._size):
            return NestedTensor(
                func(a._values, b._values, *args[2:], **kwargs), **extract_kwargs(a)
            )
        raise RuntimeError(mismatch_error_msg.format(func.__name__, a._size, b._size))
    # either a is NT or b is NT at this point
    a_is_nt = isinstance(a, NestedTensor)
    extracted_kwargs = extract_kwargs(a) if a_is_nt else extract_kwargs(b)

    # === Handle broadcasting across the batch / ragged dims ===
    # Easy case: take advantage of pre-existing broadcasting logic
    # ex: (B, j0, ?, ?) + (?) -> (B, j0, ?, ?)
    # ex: (B, j0, ?, ?) + (?, ?) -> (B, j0, ?, ?)
    # ex: (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?)
    nt, t = (a, b) if a_is_nt else (b, a)
    # See Note: [ Squeezing leading ones ]
    if t.dim() > nt.dim():
        raise NotImplementedError("NYI: broadcasting NT with T with larger dim")
    t_squeezed = squeeze_leading_ones(t)
    if nt.dim() >= t_squeezed.dim() + 2:
        # Dense broadcasting against nt._values suffices once the dense arg is
        # at least two dims smaller than the NT (batch + ragged collapse into
        # a single values dim).
        lhs, rhs = (nt._values, t_squeezed) if a_is_nt else (t_squeezed, nt._values)
        return NestedTensor(func(lhs, rhs, *args[2:], **kwargs), **extracted_kwargs)

    # Harder case: do manual broadcasting over unbound components
    # when NT dim == non-NT dim
    # ex: (B, j0, D_0, D_1) + (B, 1, D_0, D_1) -> (B, j0, D_0, D_1)
    if a.dim() == b.dim():
        # ex: (B, j0, D_0, D_1) + (1, 1, D_0, D_1) -> should
        # be (B, j0, D_0, D_1) but not yet supported
        if a.shape[0] != b.shape[0]:
            raise RuntimeError(
                mismatch_error_msg.format(func.__name__, a.shape, b.shape)
            )

        # need to use offsets to broadcast across ragged dim properly
        # NB: inefficient fallback here; Triton codegen can help this
        # TODO: Make this work with autograd
        outputs = []
        for a_comp, b_comp in zip(a.unbind(), b.unbind()):
            outputs.append(func(a_comp, b_comp, *args[2:], **kwargs))
        new_values = torch.cat(outputs, dim=0)
        return NestedTensor(new_values, **extracted_kwargs)

    # ex: (B, j0, D_0, D_1) + (A, B, 1, D_0, D_1) -> error because this breaks the invariant
    # that ragged dim is wrt left-most batch dim
    raise RuntimeError(mismatch_error_msg.format(func.__name__, a.shape, b.shape))


def jagged_torch_function(func, *args, **kwargs):
    """__torch_function__ handler for NJTs: covers SDPA, in-place apply_,
    and the CompositeImplicit flatten()."""
    # SDPA has special kernels that handle nested tensors.
    # Dispatch to the correct implementation here
    if func is torch._C._nn.scaled_dot_product_attention:
        return jagged_scaled_dot_product_attention(*args, **kwargs)

    if func.__name__ == "apply_":
        # apply_ mutates values in place; return the original NT.
        func(args[0]._values, *args[1:], **kwargs)
        return args[0]

    # Handle flatten() here because it's CompositeImplicit.
    if func.__name__ == "flatten":

        # Dummy signature used purely for argument normalization.
        def _flatten_sig(input, start_dim=0, end_dim=-1):
            pass

        _, new_kwargs = normalize_function(  # type: ignore[misc]
            _flatten_sig, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
        )

        inp = new_kwargs.pop("input")

        # NB: stay in outer dim space because we're going to redispatch on a NT input
        start_dim = _wrap_jagged_dim(
            inp.dim(), new_kwargs["start_dim"], "flatten", convert_to_inner_dim=False
        )
        end_dim = _wrap_jagged_dim(
            inp.dim(), new_kwargs["end_dim"], "flatten", convert_to_inner_dim=False
        )

        if start_dim == end_dim:
            # Flattening a single dim is a no-op.
            return inp

        product = functools.reduce(operator.mul, inp.shape[start_dim : end_dim + 1])
        new_shape = (*inp.shape[:start_dim], product, *inp.shape[end_dim + 1 :])

        return inp.reshape(*new_shape)

    raise NotImplementedError(func)


@register_jagged_func(
    [
        torch.ops.aten.is_non_overlapping_and_dense.default,
        torch.ops.aten.sym_size.default,
        torch.ops.aten.dim.default,
        torch.ops.aten.numel.default,
        torch.ops.aten.sym_numel.default,
        torch.ops.aten.sym_stride.default,
        torch.ops.aten.sym_storage_offset.default,
    ],
    "self: jt_all",
)
def tensor_attr_supported_getter(func, *args, **kwargs):
    """Metadata getters answerable from the NT's cached size/stride info."""
    if func == torch.ops.aten.is_non_overlapping_and_dense.default:
        return False

    if func == torch.ops.aten.sym_size.default:
        return args[0]._size

    if func == torch.ops.aten.dim.default:
        return len(args[0]._size)

    if func in (torch.ops.aten.sym_numel.default, torch.ops.aten.numel.default):
        if args[0]._lengths is not None:
            # Non-contiguous (narrowed) NT: count only the elements covered by
            # lengths, not everything in the values buffer.
            return int(sum(args[0]._lengths) * math.prod(args[0]._size[2:]))
        return args[0]._values.numel()

    if func == torch.ops.aten.sym_stride.default:
        return args[0]._strides

    if func == torch.ops.aten.sym_storage_offset.default:
        return args[0]._values.storage_offset()


@register_jagged_func(torch.ops.prim.layout.default, "self: jt_all")
def prim_layout_default(func, *args, **kwargs):
    """NJTs always report the jagged layout."""
    return torch.jagged


@register_jagged_func(
    [torch.ops.aten.size.default],
    "self: jt_all",
)
def tensor_attr_unsupported_getter(func, *args, **kwargs):
    if func == torch.ops.aten.size.default:
        raise RuntimeError(
            "NestedTensors does not support directly calling torch.ops.aten.size "
            "please use `nested_tensor.size()` instead."
        )


@register_jagged_func(torch.ops.aten.is_contiguous.default, "self: jt_all")
def is_contiguous_general(func, *args, **kwargs):
    """Contiguity check; NTs created via narrow() (i.e. with lengths set) are
    never considered contiguous."""
    from torch._prims_common import is_contiguous_for_memory_format

    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")

    # If created from narrow() check for lengths
    if inp.lengths() is not None:
        return False

    new_kwargs["memory_format"] = new_kwargs.get(
        "memory_format", torch.contiguous_format
    )
    if new_kwargs["memory_format"] == torch.preserve_format:
        return True
    return is_contiguous_for_memory_format(inp._values, **new_kwargs)


register_jagged_func(
    torch.ops.aten.is_contiguous.memory_format, "self: jt_all, memory_format: any?"
)(is_contiguous_general)


@register_jagged_func(
    torch.ops.aten.clone.default, "input: jt_all, memory_format: any?"
)
def clone_default(func, *args, **kwargs):
    """clone(); a contiguous-format clone of a non-contiguous (lengths) NT
    re-packs the components to drop holes, otherwise clones values as-is."""
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    new_meta = extract_kwargs(inp)

    if inp._lengths is not None:
        if new_kwargs["memory_format"] == torch.contiguous_format:
            # need to copy to remove "holes" non-contiguity / lengths metadata
            # TODO: write a kernel for this
            from .nested_tensor import jagged_from_list

            # TODO: We probably want the output to have the same ragged structure / nested int.
            assert (
                inp._ragged_idx == 1
            ), "NJT with ragged_idx != 1 not supported for contiguous clone"
            contig, _ = jagged_from_list(inp.unbind(), offsets=None)
            return contig
        else:
            # need to preserve any lengths metadata present
            new_meta["lengths"] = inp._lengths

    return NestedTensor(func(inp._values, **new_kwargs), **new_meta)


@register_jagged_func(torch.ops.aten.linear.default, "input: jt, weight: t, bias: t?")
def linear_default(func, *args, **kwargs):
    """linear applied directly to the values buffer (operates on last dim)."""
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(
    torch.ops.aten.linear_backward.default,
    "self: jt, grad_output: jt, weight: t, output_mask: any",
)
def linear_backward_default(func, *args, **kwargs):
    """linear backward computed on values buffers; returns (d_input,
    d_weight, d_bias) with d_bias currently NYI (None)."""
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    grad_output = new_kwargs.pop("grad_output")
    weight = new_kwargs.pop("weight")

    check_ragged_dim_same(func, inp, "self", grad_output, "grad_output")
    ds = NestedTensor(
        torch.matmul(grad_output._values, weight),
        **extract_kwargs(grad_output)
    )
    dw = torch.matmul(grad_output._values.transpose(-2, -1), inp._values)
    db = None  # NYI: gradient for bias, need to reduce over ragged dim
    return (ds, dw, db)


@register_jagged_func(torch.ops.aten.to.dtype, "input: jt_all, dtype: any")
def to_dtype(func, *args, **kwargs):
    """dtype conversion applied directly to the values buffer."""
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(torch.ops.aten._to_copy.default, "self: jt_all")
def to_copy_default(func, *args, **kwargs):
    """Device/dtype copy; moves offsets alongside values and carries the
    nested symint association over to the new offsets tensor."""
    from .nested_tensor import _tensor_symint_registry

    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    # don't change layout
    new_kwargs.pop("layout")

    new_values = func(inp._values, **new_kwargs)
    new_offsets = inp._offsets.to(device=new_values.device)

    from torch._subclasses.fake_tensor import FakeTensor
    from torch._subclasses.functional_tensor import (
        FunctionalTensor,
        mb_unwrap_functional_tensor,
    )

    if isinstance(new_offsets, (FakeTensor, FunctionalTensor)):
        # Temporary hack until we have the union find
        tgt = mb_unwrap_functional_tensor(new_offsets)
        src = mb_unwrap_functional_tensor(inp._offsets)
        tgt.nested_int_memo = src.nested_int_memo
    else:
        _tensor_symint_registry[new_offsets] = _tensor_symint_registry[inp._offsets]
    inp_kwargs = extract_kwargs(inp)
    inp_kwargs["offsets"] = new_offsets

    return NestedTensor(new_values, **inp_kwargs)


@register_jagged_func(
    torch.ops.aten.copy_.default, "self: jt_all, src: jt_all, non_blocking: any?"
)
def copy_default(func, *args, **kwargs):
    """In-place copy of ``src`` into ``self``; both must have the same size
    (and therefore the same nested structure)."""
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")
    src = new_kwargs.pop("src")
    if inp._size != src._size:
        raise RuntimeError(
            "copy_ only supports Nested Tensors that have same size and the exact same offset tensor."
        )
    inp.values().copy_(src.values())
    return inp


# detach is pointwise-like: apply to values, rewrap with the same metadata.
register_jagged_func(torch.ops.aten.detach.default, "self: jt_all")(
    jagged_unary_pointwise
)


@register_jagged_func(
    [
        torch.ops.aten.empty_like.default,
        torch.ops.aten.ones_like.default,
        torch.ops.aten.zeros_like.default,
        torch.ops.aten.randn_like.default,
    ],
    "self: jt_all",
)
def like_factory_default(func, *args, **kwargs):
    """*_like factories: run the factory on the values buffer and rewrap."""
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    # Default layout is technically torch.strided but only jagged is supported here.
    # Rather than force users to specify the layout, assume jagged.
    # This should be set to strided for redispatching on values.
    new_kwargs["layout"] = torch.strided

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(torch.ops.aten.zero_.default, "self: jt_all")
def zero__default(func, *args, **kwargs):
    """In-place zero of the values buffer; returns the original NT."""
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    func(inp._values)
    return inp


@register_jagged_func(
    torch.ops.aten._softmax.default, "self: jt_all, dim: any, half_to_float: any"
)
def _softmax_default(func, *args, **kwargs):
    """softmax over any dim except batch. Reducing over the ragged dim goes
    through a padded dense intermediate (padded with -inf so that the padding
    contributes e^-inf = 0 to the normalization)."""
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    if isinstance(new_kwargs["dim"], tuple):
        raise RuntimeError(
            "softmax(): not supported for dimensions of type 'tuple' for NestedTensor"
        )

    inp = new_kwargs.pop("input")

    (
        new_kwargs["dim"],
        reduce_on_batch,
        reduce_on_ragged,
        reduce_on_non_batch,
    ) = _wrap_jagged_dims(
        inp.dim(),
        (new_kwargs["dim"],),
        "softmax",
        inp._ragged_idx,
    )

    if reduce_on_batch:
        raise RuntimeError(
            "softmax(): not supported when reducing across the batch dimension for NestedTensor"
        )

    if reduce_on_ragged and inp._ragged_idx > 1:
        raise RuntimeError(
            "softmax(): not supported when reducing along the ragged dimension for ragged_idx > 1 for NestedTensor"
        )

    if reduce_on_ragged and inp._lengths is not None:
        raise RuntimeError(
            "softmax(): not supported where lengths is not None "
            + "if reducing across the ragged dimension for NestedTensor"
        )

    new_kwargs["dim"] = new_kwargs["dim"][
        0
    ]  # torch.softmax takes in the reduction dimension as an integer

    if reduce_on_ragged:
        padded_softmax_values = torch.nn.functional.softmax(
            torch.ops.aten._jagged_to_padded_dense_forward(
                inp._values.reshape(
                    inp._values.shape[0], -1
                ),  # values are required to be 2D tensors for j2pd
                [inp._offsets],
                max_lengths=[inp._max_seqlen],  # max length of ragged dimension
                padding_value=float("-inf"),  # e^-inf = 0
            ),
            dim=inp._ragged_idx,
        )

        softmax_values = torch.ops.aten._padded_dense_to_jagged_forward(
            padded_softmax_values,
            [inp._offsets],
            total_L=inp._values.shape[
                0
            ],  # providing this parameter helps avoid a GPU/CPU sync
        ).reshape(
            -1, *inp._values.shape[1:]
        )  # expand softmax_values back to original shape (inp._values.shape)

        return NestedTensor(softmax_values, **extract_kwargs(inp))

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(
    torch.ops.aten._softmax_backward_data.default,
    "grad_output: jt, output: jt, dim: any, input_dtype: any",
)
def _softmax_backward(func, *args, **kwargs):
    """softmax backward applied directly to the values buffers."""
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    grad_out = new_kwargs.pop("grad_output")
    output = new_kwargs.pop("output")
    return NestedTensor(
        func(grad_out._values, output._values, **new_kwargs), **extract_kwargs(grad_out)
    )


@register_jagged_func(
    torch.ops.aten.native_dropout.default, "self: jt, float: any, train: any?"
687) 688def native_dropout_default(func, *args, **kwargs): 689 _, new_kwargs = normalize_function( # type: ignore[misc] 690 func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True 691 ) 692 693 inp = new_kwargs.pop("input") 694 out1, out2 = func(inp._values, **new_kwargs) 695 return ( 696 NestedTensor(out1, **extract_kwargs(inp)), 697 NestedTensor(out2, **extract_kwargs(inp)), 698 ) 699 700 701@register_jagged_func( 702 torch.ops.aten.native_dropout_backward.default, 703 "grad_output: jt, mask: jt, scale: any", 704) 705def native_dropout_backward_default(func, *args, **kwargs): 706 _, new_kwargs = normalize_function( # type: ignore[misc] 707 func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True 708 ) 709 grad_output = new_kwargs.pop("grad_output") 710 mask = new_kwargs.pop("mask") 711 return NestedTensor( 712 func(grad_output._values, mask._values, **new_kwargs), 713 **extract_kwargs(grad_output), 714 ) 715 716 717@register_jagged_func(torch.ops.aten.prod.dim_int, "self: jt, dim: any, keepdim: any?") 718def prod_dim_int(func, *args, **kwargs): 719 _, new_kwargs = normalize_function( # type: ignore[misc] 720 func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True 721 ) 722 723 inp = new_kwargs.pop("input") 724 # TODO: Figure out how to handle this better 725 # keep_dim is required to keep it in jagged format 726 if not new_kwargs["keepdim"]: 727 raise RuntimeError("prod(): keepdim=True must be set for NestedTensor") 728 dim = new_kwargs["dim"] 729 new_kwargs["dim"] = _wrap_jagged_dim(len(inp._size), dim, "prod") 730 731 return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(args[0])) 732 733 734@register_jagged_func( 735 torch.ops.aten.split.Tensor, "self: jt, split_size: any, dim: any" 736) 737def split_tensor(func, *args, **kwargs): 738 _, new_kwargs = normalize_function( # type: ignore[misc] 739 func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True 740 ) 741 742 inp = new_kwargs.pop("input") 743 744 
new_kwargs["dim"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim"], "split") 745 746 return tuple( 747 NestedTensor(values=x, **extract_kwargs(inp)) 748 for x in func(inp._values, **new_kwargs) 749 ) 750 751 752@register_jagged_func( 753 torch.ops.aten.split_with_sizes.default, "self: jt, split_sizes: any, dim: any" 754) 755def split_with_sizes_default(func, *args, **kwargs): 756 _, new_kwargs = normalize_function( # type: ignore[misc] 757 func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True 758 ) 759 760 inp = new_kwargs.pop("input") 761 762 new_kwargs["dim"] = _wrap_jagged_dim( 763 inp.dim(), new_kwargs["dim"], "split_with_sizes" 764 ) 765 766 return [ 767 NestedTensor(values=x, **extract_kwargs(inp)) 768 for x in func(inp._values, **new_kwargs) 769 ] 770 771 772@register_jagged_func( 773 torch.ops.aten.narrow.default, "self: jt, dim: any, start: any, length: any" 774) 775def narrow(func, *args, **kwargs): 776 _, new_kwargs = normalize_function( # type: ignore[misc] 777 func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True 778 ) 779 inp = new_kwargs.pop("input") 780 781 dim = _wrap_jagged_dim(inp.dim(), new_kwargs["dim"], "narrow") 782 values = func( 783 inp._values, 784 dim=dim, 785 start=new_kwargs["start"], 786 length=new_kwargs["length"], 787 ) 788 return NestedTensor(values, **extract_kwargs(inp)) 789 790 791@register_jagged_func(torch.ops.aten.chunk.default, "self: jt, chunks: any, dim: any?") 792def chunk_default(func, *args, **kwargs): 793 _, new_kwargs = normalize_function( # type: ignore[misc] 794 func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True 795 ) 796 797 inp = new_kwargs.pop("input") 798 799 new_kwargs["dim"] = _wrap_jagged_dim( 800 inp.dim(), new_kwargs["dim"], "chunk", allow_batch_dim=True 801 ) 802 803 if new_kwargs["dim"] == 0: 804 chunks = new_kwargs["chunks"] 805 dim0_size = inp._size[0] 806 chunk_size = math.ceil(dim0_size / chunks) 807 808 # get _offsets of the chunks 809 lengths = 
inp._offsets.diff() 810 chunked_lengths = lengths.chunk(chunks) 811 chunked_offsets = [torch.cumsum(x, dim=0) for x in chunked_lengths] 812 chunked_offsets = [F.pad(x, (1, 0), value=0) for x in chunked_offsets] # type: ignore[arg-type] 813 nested_kwargs = [ 814 {"offsets": per_offsets, "_ragged_idx": inp._ragged_idx} 815 for per_offsets in chunked_offsets 816 ] 817 818 # get _values of the chunks 819 split_sizes = [x.sum().item() for x in chunked_lengths] 820 chunk_values = inp._values.split(split_sizes) 821 822 return [ 823 NestedTensor(values=chunk_values[i], **(nested_kwargs[i])) 824 for i in range(0, chunk_size) 825 ] 826 else: 827 return [ 828 NestedTensor(values=x, **extract_kwargs(inp)) 829 for x in func(inp._values, **new_kwargs) 830 ] 831 832 833@register_jagged_func(torch.ops.aten.unbind.int, "self: jt_all, dim: any?") 834def unbind_int(func, *args, **kwargs): 835 # Note that this specializes on the length of the offsets 836 _, new_kwargs = normalize_function( # type: ignore[misc] 837 func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True 838 ) 839 840 dim = new_kwargs["dim"] 841 if dim != 0: 842 raise RuntimeError("unbind(): only supported for NestedTensor on dim=0") 843 844 inp = new_kwargs.pop("input") 845 values = inp.values() 846 offsets = inp.offsets() 847 lengths = inp.lengths() 848 ragged_idx = inp._ragged_idx 849 850 if lengths is None: 851 return torch.split(values, offsets.diff().tolist(), dim=(ragged_idx - 1)) 852 853 if ragged_idx <= 0: 854 raise RuntimeError( 855 "unbind(): nested tensor ragged_idx out of bounds (should be >= 1)" 856 ) 857 for i in range(lengths.shape[0]): 858 if offsets[i] + lengths[i] > values.shape[ragged_idx - 1]: 859 raise RuntimeError( 860 "unbind(): nested tensor offsets and lengths do not match ragged_idx dimension" 861 ) 862 return [ 863 torch.narrow(values, dim=(ragged_idx - 1), start=offsets[i], length=lengths[i]) 864 for i in range(lengths.shape[0]) 865 ] 866 867 
@register_jagged_func(torch.ops.aten.squeeze.dim, "self: jt, dim: any")
def squeeze_dim(func, *args, **kwargs):
    """squeeze over a non-batch, non-ragged dim, applied to values."""
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    values = inp._values

    new_kwargs["dim"] = _wrap_jagged_dim(len(inp._size), new_kwargs["dim"], "squeeze")
    return NestedTensor(func(values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(torch.ops.aten.unsqueeze.default, "self: jt, dim: any")
def unsqueeze_default(func, *args, **kwargs):
    """unsqueeze at a non-batch, non-ragged position, applied to values."""
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    values = inp._values

    # Account for collapsed jagged dim
    dim = new_kwargs["dim"]
    new_kwargs["dim"] = _wrap_jagged_dim(len(inp._size) + 1, dim, "unsqueeze")
    return NestedTensor(func(values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(torch.ops.aten.cat.default, "tensors: any, dim: any")
def cat_default(func, *args, **kwargs):
    """cat over a non-batch, non-ragged dim; dense inputs are first expanded
    to match the first NT's shape."""
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    tensors = new_kwargs.pop("tensors")

    # Convert any non-nested to nested
    nested = [t for t in tensors if t.is_nested]
    assert len(nested) > 0
    first = nested[0]
    tensors = [t if t.is_nested else t.expand_as(first) for t in tensors]

    # Account for collapsed jagged dim
    dim = new_kwargs["dim"]
    new_kwargs["dim"] = _wrap_jagged_dim(len(first.shape), dim, "cat")

    return NestedTensor(
        func([t._values for t in tensors], **new_kwargs), **extract_kwargs(tensors[0])
    )


@register_jagged_func(torch.ops.aten.matmul.default, "self: jt, other: any")
def matmul_default(func, *args, **kwargs):
    """NT @ dense runs on values; NT @ NT requires matching raggedness and
    both inputs > 3D (BMM-style)."""
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    other = new_kwargs.pop("other")

    if inp.is_nested and not other.is_nested:
        return NestedTensor(
            func(inp._values, other, **new_kwargs), **extract_kwargs(inp)
        )
    elif inp.is_nested and other.is_nested:
        # BMM with equivalent ragged dims between the two inputs
        if inp.dim() > 3 and other.dim() > 3 and raggedness_matches(inp, other._size):
            return NestedTensor(func(inp._values, other._values), **extract_kwargs(inp))

    raise RuntimeError(
        f"matmul(): not supported between inputs of shapes {inp._size} and {other.shape}"
    )


@register_jagged_func(
    torch.ops.aten.expand.default, "self: jt, size: any, implicit: any?"
)
def expand_default(func, *args, **kwargs):
    """expand; the batch/ragged portion must match exactly, so only trailing
    dense dims may actually be expanded (leading values dim stays -1)."""
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    size = new_kwargs["size"]

    assert ("implicit" not in new_kwargs) or (not new_kwargs.pop("implicit"))
    if not raggedness_matches(inp, size):
        raise RuntimeError(f"expand(): cannot expand shape {inp._size} -> {size}")

    expand_arg = [-1, *size[2:]]
    return NestedTensor(func(inp._values, expand_arg), **extract_kwargs(inp))


@register_jagged_func(torch.ops.aten.expand_as.default, "self: t, other: jt")
def expand_as_default(func, *args, **kwargs):
    """Expand a dense tensor to an NT's shape by expanding against values."""
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    other = new_kwargs.pop("other")

    return NestedTensor(func(inp, other._values), **extract_kwargs(other))


@register_jagged_func(torch.ops.aten.where.self, "condition: jt, self: jt, other: jt")
def where_self(func, *args, **kwargs):
    """where with all three inputs NTs of identical size; runs on values."""
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    condition = new_kwargs.pop("condition")
    inp = new_kwargs.pop("input")
    other = new_kwargs.pop("other")

    assert condition._size == other._size == inp._size

    return NestedTensor(
        func(condition._values, inp._values, other._values, **new_kwargs),
        **extract_kwargs(condition),
    )


@register_jagged_func(torch.ops.aten._pin_memory.default, "self: jt, device: any?")
def _pin_memory_default(func, *args, **kwargs):
    """Pin the values buffer; rewrap with the same metadata."""
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(torch.ops.aten.is_pinned.default, "self: jt, device: any?")
def is_pinned_default(func, *args, **kwargs):
    """Delegate the pinned-memory query to the values buffer."""
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    return func(inp._values, **new_kwargs)


@register_jagged_func(
    torch.ops.aten.is_same_size.default, "self: jt_all, other: jt_all"
)
def is_same_size_default(func, *args, **kwargs):
    """Compare cached (symbolic) sizes of the two NTs."""
    return args[0]._size == args[1]._size


@register_jagged_func(
    torch.ops.aten.sum.dim_IntList,
    "self: jt_all, dim: any?, keepdim: any?, dtype: any?",
)
def sum_dim_IntList(func, *args, **kwargs):
    """
    Performs a sum along the provided tensor dimension.
    Returns a dense tensor if the ragged dimension is reduced away, else returns a nested tensor.
1028 """ 1029 _, new_kwargs = normalize_function( # type: ignore[misc] 1030 func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True 1031 ) 1032 inp = new_kwargs.pop("input") 1033 1034 ( 1035 new_kwargs["dim"], 1036 reduce_on_batch, 1037 reduce_on_ragged, 1038 reduce_on_non_batch, 1039 ) = _wrap_jagged_dims( 1040 inp.dim(), 1041 new_kwargs["dim"], 1042 "sum", 1043 inp._ragged_idx, 1044 ) 1045 1046 if reduce_on_ragged and inp._lengths is not None: 1047 raise RuntimeError( 1048 "sum(): not supported where lengths is not None " 1049 + "if reducing across the ragged dimension for NestedTensor" 1050 ) 1051 1052 if reduce_on_ragged: # raggedness reduced away --> return dense tensor 1053 if ( 1054 reduce_on_batch 1055 ): # reduction cases: (batch, ragged), (batch, ragged, non-batch), etc. 1056 out = func( 1057 inp._values, **new_kwargs 1058 ) # no need to read offsets --> apply sum directly on values 1059 else: 1060 if ( 1061 reduce_on_non_batch 1062 ): # invalid reduction cases: (ragged, non-batch), etc. 
1063 raise RuntimeError( 1064 "sum(): not supported along a ragged and non-batch dimension for NestedTensor" 1065 ) 1066 # reduction cases: (ragged) 1067 values_ragged_dim_outer = inp._values.permute( 1068 inp._ragged_idx - 1, # outer dimension 1069 *range(0, inp._ragged_idx - 1), 1070 *range(inp._ragged_idx, inp.dim() - 1), 1071 ) # shift reduction dimension of values backward to outer dimension 1072 1073 # _jagged_to_padded_dense_forward requires values to be a 2D tensor 1074 # with the ragged dimension as the 0th dimension 1075 padded = torch.ops.aten._jagged_to_padded_dense_forward( 1076 values_ragged_dim_outer.reshape(values_ragged_dim_outer.shape[0], -1), 1077 [inp._offsets], 1078 max_lengths=[inp._max_seqlen], 1079 ) 1080 1081 padded_ragged_dim_original = padded.view( 1082 padded.shape[0], 1083 inp._max_seqlen, 1084 *values_ragged_dim_outer.shape[ 1085 1: 1086 ], # expand non-batch dimensions of padded tensor 1087 ).permute( 1088 0, 1089 *range(2, inp._ragged_idx + 1), 1090 1, 1091 *range(inp._ragged_idx + 1, inp.dim()), 1092 ) # shift reduction dimension of padded tensor forward to original ragged dimension 1093 1094 out = torch.sum( 1095 padded_ragged_dim_original, 1096 dim=inp._ragged_idx, 1097 ) # need to read offsets --> pad jagged dimension and apply sum 1098 1099 if new_kwargs["keepdim"]: 1100 # TODO: Fix this; it's a bug. should be unsqueezing on ragged_idx 1101 out = out.unsqueeze(0) 1102 return out 1103 else: # raggedness preserved --> return nested tensor 1104 if ( 1105 reduce_on_batch 1106 ): # invalid reduction cases: (batch), (batch, non-batch), etc. 1107 raise RuntimeError( 1108 "sum(): not supported along the batch dimension but not the ragged dimension for NestedTensor" 1109 ) 1110 # reduction cases: (non-batch), (non-batch, non-batch), etc. 
1111 return NestedTensor( 1112 func(inp._values, **new_kwargs), **extract_kwargs(inp) 1113 ) # apply sum directly on values 1114 1115 1116@register_jagged_func( 1117 torch.ops.aten.transpose.int, "self: jt_all, dim0: any, dim1: any" 1118) 1119def transpose_int(func, *args, **kwargs): 1120 _, new_kwargs = normalize_function( # type: ignore[misc] 1121 func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True 1122 ) 1123 1124 from torch._prims_common import canonicalize_dims 1125 1126 inp = new_kwargs.pop("input") 1127 dim0, dim1 = canonicalize_dims(inp.dim(), (new_kwargs["dim0"], new_kwargs["dim1"])) 1128 1129 if inp._lengths is not None: 1130 raise ValueError( 1131 "transpose(): not supported on jagged layout nested tensor with holes" 1132 ) 1133 1134 # To support the SDPA API, inputs need to have the ragged idx transposed to dim 2 1135 # instead of 1, although the internal Flash and mem-effn implementations will 1136 # use the inputs with raggedness in dim 1. 1137 if dim0 == inp._ragged_idx or dim1 == inp._ragged_idx: 1138 if dim0 == 0 or dim1 == 0: 1139 raise ValueError( 1140 "Transpose is not supported on the batch dimension for jagged NT" 1141 ) 1142 if dim0 == inp._ragged_idx: 1143 to_dim = dim1 1144 else: 1145 to_dim = dim0 1146 inp_kwargs = extract_kwargs(inp) 1147 inp_kwargs["_ragged_idx"] = to_dim 1148 return NestedTensor( 1149 inp.values().transpose( 1150 _outer_to_inner_dim(len(inp._size), dim0), 1151 _outer_to_inner_dim(len(inp._size), dim1), 1152 ), 1153 **inp_kwargs, 1154 ) 1155 1156 new_kwargs["dim0"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim0"], "transpose") 1157 new_kwargs["dim1"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim1"], "transpose") 1158 1159 return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp)) 1160 1161 1162@register_jagged_func(torch.ops.aten.permute.default, "self: jt_all, dims: any") 1163def permute_default(func, *args, **kwargs): 1164 _, new_kwargs = normalize_function( # type: ignore[misc] 1165 
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True 1166 ) 1167 inp = new_kwargs.pop("input") 1168 dims = new_kwargs.pop("dims") 1169 inp_kwargs = extract_kwargs(inp) 1170 inp_dim = len(inp._size) 1171 1172 # The first two checks are the same as the checks in the normal permute implementation 1173 if inp_dim != len(dims): 1174 raise ValueError( 1175 f"permute(): number of dimensions in the tensor input ({inp_dim}) " 1176 + f"does not match the length of the desired ordering of dimensions ({len(dims)}).", 1177 ) 1178 1179 from torch._prims_common import canonicalize_dims 1180 1181 canonicalized_dims = canonicalize_dims(inp_dim, dims) 1182 1183 if len(canonicalized_dims) != len(set(canonicalized_dims)): 1184 raise ValueError("permute(): duplicate dims are not allowed.") 1185 1186 if inp._lengths is not None: 1187 raise ValueError( 1188 "permute(): not supported on jagged layout nested tensor with holes" 1189 ) 1190 if canonicalized_dims[0] != 0: 1191 raise ValueError( 1192 "Permute is not supported on the batch dimension for jagged NT" 1193 ) 1194 inp_kwargs["_ragged_idx"] = canonicalized_dims.index(inp._ragged_idx) 1195 inner_dims = [_outer_to_inner_dim(inp_dim, dim) for dim in canonicalized_dims[1:]] 1196 new_kwargs["dims"] = inner_dims 1197 return NestedTensor(func(inp._values, **new_kwargs), **inp_kwargs) 1198 1199 1200@register_jagged_func( 1201 [torch.ops.aten.view.default, torch.ops.aten._unsafe_view.default], 1202 "self: jt_all, size: any", 1203) 1204def view_default(func, *args, **kwargs): 1205 _, new_kwargs = normalize_function( # type: ignore[misc] 1206 func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True 1207 ) 1208 1209 inp = new_kwargs.pop("input") 1210 size = new_kwargs.pop("size") 1211 1212 if inp._ragged_idx != 1 and tuple(inp._size) != tuple(size): 1213 raise RuntimeError( 1214 f"view(): does not support ragged_idx != 1 except when inp._size == size. " 1215 f"inp._size is ({inp._size}) and size is ({size})." 
1216 ) 1217 1218 # Ensure specified size still includes batch and ragged dims 1219 if len(size) < 3 or not raggedness_matches(inp, size): 1220 raise RuntimeError(f"view(): cannot view shape {inp._size} as {size}") 1221 1222 # outer size: the size of the NT, e.g. [3, j0, 10] 1223 # inner size: the size of the values, e.g. [8, 10] (e.g. for offsets = [0, 3, 5, 8]) 1224 # this function gets inner_size[inner_idx] for a given inner_idx. 1225 # 1226 # example: for outer size [a, b, c, j0, d, e, f] 1227 # assume that j0 is ragged, other are concrete integers 1228 # and ragged_idx=3 1229 # inner size will be [b, c, inp._values.size(ragged_idx), d, e, f] 1230 # therefore: 1231 # inner_size[0] = outer_size[1] 1232 # inner_size[1] = outer_size[2] 1233 # inner_size[0] = inp._values.size(ragged_idx - 1) 1234 # inner_size[3] = outer_size[4] 1235 # inner_size[4] = outer_size[5] 1236 def get_inner_size(inner_idx): 1237 nonlocal inp, size 1238 if inner_idx == inp._ragged_idx - 1: 1239 return inp._values.size(inner_idx) 1240 else: 1241 return size[inner_idx + 1] 1242 1243 inner_size = [get_inner_size(i) for i in range(len(size) - 1)] 1244 1245 return NestedTensor(func(inp._values, inner_size), **extract_kwargs(inp)) 1246 1247 1248@register_jagged_func( 1249 torch.ops.aten.native_layer_norm.default, 1250 "input: jt_all, normalized_shape: any, weight: any?, bias: any?, eps: any", 1251) 1252def native_layer_norm_default(func, *args, **kwargs): 1253 _, new_kwargs = normalize_function( # type: ignore[misc] 1254 func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True 1255 ) 1256 1257 inp = new_kwargs.pop("input") 1258 1259 if inp.dim() <= 2: 1260 raise RuntimeError( 1261 "layer_norm(): not supported for NestedTensor objects with 2 or fewer dimensions" 1262 ) 1263 1264 normalized_shape = new_kwargs["normalized_shape"] 1265 ragged_size = inp.shape[inp._ragged_idx] 1266 1267 num_dims_not_normalized = inp.dim() - len(normalized_shape) 1268 1269 if ( 1270 num_dims_not_normalized == 0 
1271 ): # error if trying to normalize over the batch dimension 1272 raise RuntimeError( 1273 "layer_norm(): not supported when normalizing over the batch dimension for NestedTensor" 1274 ) 1275 1276 if ragged_size in normalized_shape and inp._lengths is not None: 1277 raise RuntimeError( 1278 "layer_norm(): not supported where lengths is not None if operating on the ragged dimension for NestedTensor" 1279 ) 1280 1281 if ( 1282 ragged_size in normalized_shape 1283 ): # special handling for normalizing over the ragged dimension 1284 padded_input = torch.ops.aten._jagged_to_padded_dense_forward( 1285 inp._values.flatten( 1286 start_dim=inp._ragged_idx 1287 ), # _jagged_to_padded_dense_forward requires values to be a 2D tensor 1288 [inp._offsets], 1289 max_lengths=[inp._max_seqlen], # max length of ragged dimension 1290 ) 1291 1292 padded_mask = torch.ops.aten._jagged_to_padded_dense_forward( 1293 torch.ones((inp._values.shape[0], 1), device=inp.device, dtype=inp.dtype), 1294 [inp._offsets], 1295 max_lengths=[inp._max_seqlen], # max length of ragged dimension 1296 ).expand( 1297 padded_input.shape 1298 ) # mask elements outside of the ragged dimension and expand to the same shape as padded input (3D dense tensor) 1299 1300 ragged_lengths = ( 1301 inp._offsets.diff().unsqueeze(1).unsqueeze(1) * padded_input.shape[2] 1302 ) # ragged dim * inner dim, since we sum over dims (1, 2) (the layer on which we normalize) 1303 1304 mean = ( 1305 torch.sum( 1306 padded_input, 1307 dim=(1, 2), 1308 keepdim=True, 1309 ) 1310 / ragged_lengths 1311 ) # a sum over (1, 2) ensures layer norm, whereas a sum over (1) would be an instance norm 1312 1313 padded_normalized = ( 1314 padded_input - mean 1315 ) * padded_mask # mask elements outside of the ragged dimension size for correct variance calculation 1316 1317 variance = ( 1318 torch.sum( 1319 torch.square(padded_normalized), 1320 dim=(1, 2), 1321 keepdim=True, 1322 ) 1323 / ragged_lengths 1324 ) # a sum over (1, 2) ensures layer norm, 
whereas a sum over (1) would be an instance norm 1325 1326 std = torch.sqrt(variance + new_kwargs["eps"]) 1327 padded_layer_norm = padded_normalized / std 1328 1329 jagged_layer_norm_values = torch.ops.aten._padded_dense_to_jagged_forward( 1330 padded_layer_norm, 1331 [inp._offsets], 1332 total_L=inp._values.shape[ 1333 0 1334 ], # providing this parameter helps avoid a GPU/CPU sync 1335 ).unflatten( 1336 -1, inp.shape[inp._ragged_idx + 1 :] 1337 ) # unflatten last dimension back into original nested tensor shape, e.g. (B, *, WH) --> (B, *, W, H) 1338 1339 return ( 1340 NestedTensor(jagged_layer_norm_values, **extract_kwargs(inp)), 1341 mean, 1342 std, 1343 ) 1344 1345 output, mean, std = func(inp._values, **new_kwargs) 1346 return (NestedTensor(output, **extract_kwargs(inp)), mean, std) 1347 1348 1349@register_jagged_func( 1350 torch.ops.aten.native_layer_norm_backward.default, 1351 "grad_out: jt, input: jt, normalized_shape: any, mean: any, rstd: any, weight: any?, bias: any?, output_mask: any", 1352) 1353def native_layer_norm_backward_default(func, *args, **kwargs): 1354 _, new_kwargs = normalize_function( # type: ignore[misc] 1355 func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True 1356 ) 1357 grad_out = new_kwargs.pop("grad_out") 1358 inp = new_kwargs.pop("input") 1359 d_input, d_gamma, d_beta = func(grad_out._values, inp._values, **new_kwargs) 1360 if d_input is None: 1361 return (None, d_gamma, d_beta) 1362 1363 return (NestedTensor(d_input, **extract_kwargs(inp)), d_gamma, d_beta) 1364 1365 1366@register_jagged_func(torch.ops.aten.select.int, "self: jt, dim: any, index: any") 1367def select_int(func, *args, **kwargs): 1368 _, new_kwargs = normalize_function( # type: ignore[misc] 1369 func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True 1370 ) 1371 1372 inp = new_kwargs.pop("input") 1373 new_kwargs["dim"] = _wrap_jagged_dim( 1374 inp.dim(), new_kwargs["dim"], "select", allow_batch_dim=True 1375 ) 1376 1377 # handle batch dim 
slicing via unbind() for now 1378 # TODO: make this more efficient 1379 if new_kwargs["dim"] == 0: 1380 return inp.unbind()[new_kwargs["index"]] 1381 1382 return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp)) 1383 1384 1385@register_jagged_func( 1386 torch.ops.aten.slice.Tensor, 1387 "self: jt, dim: any?, start: any?, end: any?, step: any?", 1388) 1389def slice_tensor(func, *args, **kwargs): 1390 _, new_kwargs = normalize_function( # type: ignore[misc] 1391 func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True 1392 ) 1393 1394 inp = new_kwargs.pop("input") 1395 new_kwargs["dim"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim"], "slice") 1396 1397 return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp)) 1398 1399 1400@register_jagged_func( 1401 torch.ops.aten.convolution.default, 1402 "input: jt, weight: t, bias: t?, stride: any, padding: any, " 1403 "dilation: any, transposed: any, output_padding: any, groups: any", 1404) 1405def convolution_default(func, *args, **kwargs): 1406 _, new_kwargs = normalize_function( # type: ignore[misc] 1407 func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True 1408 ) 1409 1410 inp = new_kwargs.pop("input") 1411 1412 return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp)) 1413 1414 1415@register_jagged_func( 1416 torch.ops.aten.mean.dim, "self: jt_all, dim: any?, keepdim: any?, dtype: any?" 1417) 1418def mean_dim(func, *args, **kwargs): 1419 """ 1420 Performs a mean along the provided tensor dimension. 1421 Returns a dense tensor if the ragged dimension is reduced away, else returns a nested tensor. 
1422 """ 1423 _, new_kwargs = normalize_function( # type: ignore[misc] 1424 func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True 1425 ) 1426 1427 if len(new_kwargs["dim"]) > 1: 1428 raise RuntimeError( 1429 "mean(): not supported across multiple dimensions for NestedTensor" 1430 ) 1431 1432 inp = new_kwargs.pop("input") 1433 1434 ( 1435 new_kwargs["dim"], 1436 reduce_on_batch, 1437 reduce_on_ragged, 1438 reduce_on_non_batch, 1439 ) = _wrap_jagged_dims( 1440 inp.dim(), 1441 new_kwargs["dim"], 1442 "mean", 1443 inp._ragged_idx, 1444 ) 1445 1446 if reduce_on_batch: 1447 raise RuntimeError( 1448 "mean(): not supported along the batch dimension but not the ragged dimension for NestedTensor" 1449 ) 1450 1451 if reduce_on_ragged and inp._lengths is not None: 1452 raise RuntimeError( 1453 "mean(): not supported where lengths is not None " 1454 + "if reducing across the ragged dimension for NestedTensor" 1455 ) 1456 1457 if not new_kwargs["keepdim"]: 1458 raise RuntimeError("mean(): not supported when keepdim=False for NestedTensor") 1459 1460 if reduce_on_ragged: # raggedness reduced away 1461 torch_sum = torch.sum(inp, dim=inp._ragged_idx, keepdim=new_kwargs["keepdim"]) 1462 1463 # for every non-batch dimension, 1464 # unsqueeze lengths into the same shape as the PyTorch sum, 1465 # as the extra dimensions must all be divided by the same length 1466 lengths = inp._offsets.diff() 1467 for _ in range(inp.dim() - 2): 1468 lengths = lengths.unsqueeze(-1) 1469 1470 return torch_sum / lengths.broadcast_to(torch_sum.shape) 1471 1472 return NestedTensor( 1473 func(inp._values, **new_kwargs), **extract_kwargs(inp) 1474 ) # raggedness preserved 1475 1476 1477@register_jagged_func(torch.ops.aten.stack.default, "tensors: any, dim: any") 1478def stack_default(func, *args, **kwargs): 1479 _, new_kwargs = normalize_function( # type: ignore[misc] 1480 func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True 1481 ) 1482 1483 # guaranteed this is non-empty if we got 
here 1484 tensors = new_kwargs.pop("tensors") 1485 for t in tensors: 1486 if not isinstance(t, NestedTensor): 1487 raise RuntimeError("stack(): expected all nested tensors inputs") 1488 1489 if t.dim() != tensors[0].dim(): 1490 raise RuntimeError( 1491 "stack(): expected all nested tensors to have the same dim" 1492 ) 1493 1494 if not raggedness_matches(t, tensors[0].shape): 1495 raise RuntimeError( 1496 "stack(): expected all nested tensors to have the same nested structure" 1497 ) 1498 1499 new_kwargs["dim"] = _wrap_jagged_dim( 1500 tensors[0].dim() + 1, new_kwargs["dim"], "stack" 1501 ) 1502 1503 return NestedTensor( 1504 func([t._values for t in tensors], **new_kwargs), **extract_kwargs(tensors[0]) 1505 ) 1506 1507 1508@register_jagged_func( 1509 torch.ops.aten.embedding.default, 1510 "weight: t, indices: jt, padding_idx: any?, scale_grad_by_freq: any?, sparse: any?", 1511) 1512def embedding_default(func, *args, **kwargs): 1513 _, new_kwargs = normalize_function( # type: ignore[misc] 1514 func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True 1515 ) 1516 1517 # guaranteed this is non-empty if we got here 1518 indices = new_kwargs.pop("indices") 1519 weight = new_kwargs.pop("weight") 1520 1521 return NestedTensor( 1522 func(weight, indices._values, **new_kwargs), **extract_kwargs(indices) 1523 ) 1524 1525 1526@register_jagged_func( 1527 [ 1528 torch.ops.aten.values.default, 1529 torch.ops.aten._nested_get_values.default, 1530 ], 1531 "self: jt_all", 1532) 1533def values_default(func, *args, **kwargs): 1534 _, new_kwargs = normalize_function( # type: ignore[misc] 1535 func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True 1536 ) 1537 1538 inp = new_kwargs.pop("input") 1539 1540 # TODO: Handle inference mode properly. 
1541 # See https://github.com/pytorch/pytorch/issues/112024#issuecomment-1779554292 1542 return inp._values.detach() 1543 1544 1545@register_jagged_func(torch.ops.aten.all.default, "self: jt_all") 1546def all_default(func, *args, **kwargs): 1547 _, new_kwargs = normalize_function( # type: ignore[misc] 1548 func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True 1549 ) 1550 1551 inp = new_kwargs.pop("input") 1552 1553 return func(inp._values) 1554 1555 1556@register_jagged_func( 1557 torch.ops.aten._nested_view_from_jagged.default, 1558 "values: t, offsets: t, dummy: jt_all, lengths: t?, ragged_idx: any?, min_seqlen: t?, max_seqlen: t?", 1559) 1560def _nested_view_from_jagged_default(func, *args, **kwargs): 1561 _, new_kwargs = normalize_function( # type: ignore[misc] 1562 func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True 1563 ) 1564 1565 values, offsets, lengths = ( 1566 new_kwargs["input"], 1567 new_kwargs["offsets"], 1568 new_kwargs["lengths"], 1569 ) 1570 ragged_idx = new_kwargs["ragged_idx"] 1571 min_seqlen = new_kwargs["min_seqlen"] 1572 max_seqlen = new_kwargs["max_seqlen"] 1573 metadata_cache = {} 1574 if min_seqlen is not None: 1575 metadata_cache["min_seqlen"] = min_seqlen 1576 if max_seqlen is not None: 1577 metadata_cache["max_seqlen"] = max_seqlen 1578 1579 return NestedTensor( 1580 values, 1581 offsets, 1582 lengths=lengths, 1583 _ragged_idx=ragged_idx, 1584 _metadata_cache=metadata_cache, 1585 ) 1586 1587 1588@register_jagged_func(torch.ops.aten._nested_get_offsets.default, "self: jt_all") 1589def _nested_get_offsets(func, *args, **kwargs): 1590 _, new_kwargs = normalize_function( # type: ignore[misc] 1591 func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True 1592 ) 1593 1594 inp = new_kwargs.pop("input") 1595 return inp._offsets 1596 1597 1598@register_jagged_func(torch.ops.aten._nested_get_lengths.default, "self: jt_all") 1599def _nested_get_lengths(func, *args, **kwargs): 1600 _, new_kwargs = 
normalize_function( # type: ignore[misc] 1601 func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True 1602 ) 1603 1604 inp = new_kwargs.pop("input") 1605 return inp._lengths 1606 1607 1608@register_jagged_func(torch.ops.aten._nested_get_ragged_idx.default, "self: jt_all") 1609def _nested_get_ragged_idx(func, *args, **kwargs): 1610 _, new_kwargs = normalize_function( # type: ignore[misc] 1611 func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True 1612 ) 1613 1614 inp = new_kwargs.pop("input") 1615 return inp._ragged_idx 1616 1617 1618@register_jagged_func(torch.ops.aten._nested_get_min_seqlen.default, "self: jt_all") 1619def _nested_get_min_seqlen(func, *args, **kwargs): 1620 _, new_kwargs = normalize_function( # type: ignore[misc] 1621 func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True 1622 ) 1623 1624 inp = new_kwargs.pop("input") 1625 return inp._metadata_cache.get("min_seqlen", None) 1626 1627 1628@register_jagged_func(torch.ops.aten._nested_get_max_seqlen.default, "self: jt_all") 1629def _nested_get_max_seqlen(func, *args, **kwargs): 1630 _, new_kwargs = normalize_function( # type: ignore[misc] 1631 func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True 1632 ) 1633 1634 inp = new_kwargs.pop("input") 1635 return inp._metadata_cache.get("max_seqlen", None) 1636 1637 1638# If a section of the Nested Tensor is fully masked out we still retain the section with a length of 0 1639@register_jagged_func(torch.ops.aten.masked_select.default, "self: jt, mask: any") 1640def masked_select_default(func, *args, **kwargs): 1641 _, new_kwargs = normalize_function( # type: ignore[misc] 1642 func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True 1643 ) 1644 inp = new_kwargs.pop("input") 1645 mask = new_kwargs.pop("mask") 1646 1647 if inp.ndim > 2: 1648 raise RuntimeError("masked_select only support 2-D selections currently") 1649 elif inp.shape != mask.shape: 1650 raise RuntimeError( 1651 f"Mask with shape {mask.shape} is 
not compatible with input's shape {inp.shape}" 1652 ) 1653 res_values = inp._values.masked_select(mask.values()) 1654 mask_cumsum = F.pad(mask.values().cumsum(dim=0), (1, 0)) # type: ignore[arg-type] 1655 1656 args = extract_kwargs(inp) 1657 args["offsets"] = mask_cumsum[inp._offsets] 1658 return NestedTensor( 1659 values=res_values, 1660 **args, 1661 ) 1662 1663 1664# Make the dummy available on the C++ side. 1665@register_jagged_func(torch.ops.aten._nested_get_jagged_dummy.default, "self: any") 1666def _nested_get_jagged_dummy(func, *args, **kwargs): 1667 from torch.nested._internal.nested_tensor import _nt_view_dummy 1668 1669 return _nt_view_dummy() 1670 1671 1672with torch.library._scoped_library("aten", "IMPL") as aten: 1673 aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "CPU") 1674 aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "CUDA") 1675 aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "Meta") 1676