# Owner(s): ["module: nestedtensor"]

import io
import itertools
import math
import sys
import unittest
from functools import partial
from typing import Optional, Tuple

import numpy as np

import torch
import torch._dynamo
import torch._dynamo.testing
import torch.nn
import torch.nn.functional as F
from torch.nested._internal.nested_tensor import (
    buffer_from_jagged,
    jagged_from_list,
    nested_view_from_values_offsets,
    NestedTensor,
    ViewNestedFromBuffer,
)
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FUSED_ATTENTION,
    SM70OrLater,
    SM80OrLater,
)
from torch.testing._internal.common_device_type import (
    dtypes,
    dtypesIfCUDA,
    instantiate_device_type_tests,
    onlyCPU,
    onlyCUDA,
    ops,
    PYTORCH_CUDA_MEMCHECK,
    skipCPUIf,
    skipCUDAIf,
    skipCUDAIfRocm,
    skipMeta,
)
from torch.testing._internal.common_dtype import floating_types_and_half
from torch.testing._internal.common_utils import (
    decorateIf,
    freeze_rng_state,
    gradcheck,
    instantiate_parametrized_tests,
    IS_FBCODE,
    IS_WINDOWS,
    markDynamoStrictTest,
    NestedTensorTestCase,
    parametrize,
    run_tests,
    skipIfSlowGradcheckEnv,
    skipIfTorchDynamo,
    subtest,
    TEST_WITH_ROCM,
    xfailIfTorchDynamo,
)
from torch.testing._internal.opinfo.definitions.nested import njt_op_db
from torch.utils._pytree import tree_flatten
from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts


# Tests are ported from pytorch/nestedtensor.
# This makes porting as_nested_tensor easier in the future.


def _iter_constructors():
    # yield as_nested_tensor
    yield torch.nested.nested_tensor


# Returns True if the function recompiles between inputs1 and inputs2 with the
# specified dynamic setting.
def _recompiles_for_inputs(fn, inputs1, inputs2, dynamic=True):
    compile_count = [0]

    def counter(gm, example_inputs):
        compile_count[0] += 1
        return gm

    compiled_f = torch.compile(fn, fullgraph=True, backend=counter, dynamic=dynamic)
    compiled_f(*inputs1)
    compiled_f(*inputs2)
    return compile_count[0] > 1
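# Illustrative usage of _recompiles_for_inputs (a sketch only; nothing below is
# executed by the test suite). With dynamic=False a change in input shape is
# generally expected to trigger a recompile, while dynamic=True generally avoids it:
#
#   fn = lambda x: x.sin() + 1
#   a, b = torch.randn(3), torch.randn(5)
#   _recompiles_for_inputs(fn, (a,), (b,), dynamic=False)  # typically True
#   _recompiles_for_inputs(fn, (a,), (b,), dynamic=True)   # typically False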


# Helper function to generate a pair of random nested tensors:
# one is contiguous, the other is not, but they hold the same entries.
# Each output nested tensor consists of
# * `len(ragged_sizes)` matrices
# * matrices[i].shape == (20, ragged_sizes[i])


def random_nt_noncontiguous_pair(ragged_sizes, device="cpu", dtype=torch.float16):
    xs = []
    for size in ragged_sizes:
        xs.append(torch.randn((size, 20), device=device, dtype=dtype))
    # contiguous nested tensor
    ys = []
    for x in xs:
        ys.append(x.transpose(-1, -2))
    nt_contiguous = torch.nested.nested_tensor(ys)
    # noncontiguous nested tensor
    n = len(ragged_sizes)
    nt_noncontiguous = torch.nested.nested_tensor(xs).transpose(-1, -2)
    return nt_contiguous, nt_noncontiguous
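# Illustrative usage of random_nt_noncontiguous_pair (a sketch only; not executed
# by the test suite):
#
#   nt_c, nt_nc = random_nt_noncontiguous_pair((2, 3, 4))
#   nt_c.is_contiguous()   # True
#   nt_nc.is_contiguous()  # False
#   # both have components of shape (20, ragged_size) with equal entries:
#   all(torch.equal(a, b) for a, b in zip(nt_c.unbind(), nt_nc.unbind()))  # True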


# Helper function to pad a noncontiguous nested tensor;
# it can be replaced once to_padded_tensor supports noncontiguous memory.


def noncontiguous_to_padded_tensor(input, shape=None):
    tensors = input.unbind()
    ntensors = len(tensors)
    assert ntensors > 0
    if shape is None:
        shape = []
        for size in tensors[0].shape:
            shape.append(size)
        for i in range(1, ntensors):
            new_shape = tensors[i].shape
            for j in range(len(shape)):
                shape[j] = max(shape[j], new_shape[j])
        shape = [ntensors] + shape
    result = tensors[0].new_zeros(shape)
    for itensor in range(ntensors):
        tensor = tensors[itensor]
        view = result[itensor]
        for idim in range(tensor.dim()):
            view = view.narrow(idim, 0, tensor.size(idim))
        view.copy_(tensor)
    return result
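# Illustrative behavior of noncontiguous_to_padded_tensor (a sketch only; not
# executed by the test suite): each component is copied into a zero-initialized
# result padded up to the per-dimension maximum (or an explicit `shape`).
#
#   nt = torch.nested.nested_tensor([torch.ones(2, 3), torch.ones(4, 1)])
#   noncontiguous_to_padded_tensor(nt).shape  # torch.Size([2, 4, 3])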


# Helper function to generate a random nested tensor


def random_nt(
    device,
    dtype,
    num_tensors,
    max_dims,
    min_dims=None,
    layout=torch.strided,
    require_non_empty=True,
):
    if min_dims is None:
        min_dims = tuple([0] * len(max_dims))

    assert len(max_dims) == len(min_dims)
    for min_dim, max_dim in zip(min_dims, max_dims):
        assert max_dim > min_dim, "random_nt: max_dim must be greater than min_dim"
        assert min_dim >= 0, "random_nt: min_dim must be non-negative"
        if require_non_empty:
            assert not (
                min_dim == 0 and max_dim == 1
            ), "random_nt: zero cannot be the only possible value if require_non_empty is True"

    if require_non_empty:
        # Select a random idx that will be required to be non-empty
        non_zero_idx = torch.randint(low=0, high=num_tensors, size=(1,)).item()

    ts1 = []
    for i in range(num_tensors):
        tensor_dims = []
        for min_dim, max_dim in zip(min_dims, max_dims):
            new_min_dim = min_dim
            if require_non_empty and i == non_zero_idx and min_dim == 0:
                new_min_dim = 1
            tensor_dims.append(
                torch.randint(low=new_min_dim, high=max_dim, size=(1,)).item()
            )
        t1 = torch.randn(tensor_dims, device=device, dtype=dtype)
        ts1.append(t1)

    return torch.nested.nested_tensor(ts1, device=device, dtype=dtype, layout=layout)
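# Illustrative usage of random_nt (a sketch only; not executed by the test suite):
#
#   nt = random_nt(torch.device("cpu"), torch.float32, num_tensors=4, max_dims=(4, 4))
#   nt.size(0)  # 4 components, each 2D with per-dim sizes drawn from [0, 4);
#               # require_non_empty=True guarantees at least one non-empty component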


# Alternate approach to generating a random NT.
# dims should be something like [5, None, 10], with None indicating that a
# random ragged structure should be used
def random_nt_from_dims(
    dims, device=None, dtype=None, layout=torch.strided, requires_grad=False
):
    sizes = [
        [
            d if d is not None else torch.randint(2, 10, size=(1,)).item()
            for d in dims[1:]
        ]
        for _ in range(dims[0])
    ]
    return torch.nested.nested_tensor(
        [torch.randn(*size) for size in sizes],
        device=device,
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
    )
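# Illustrative usage of random_nt_from_dims (a sketch only; not executed by the
# test suite):
#
#   nt = random_nt_from_dims([5, None, 10])
#   # 5 components, each of shape (ragged, 10) with the ragged size drawn from [2, 10)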


# Creates an NT matching another NT's number of components and
# shape / ragged structure for all dims specified to be -1.
def random_nt_from_similar(other, dims=None):
    if dims is None:
        return torch.randn_like(other)
    assert len(dims) == other.dim()
    assert dims[0] == -1 or dims[0] == other.size(0)

    ret_sizes = []
    for t in other.unbind():
        other_size = t.shape
        ret_size = []
        for i, d in enumerate(dims[1:]):
            if d == -1:
                ret_size.append(other_size[i])
            else:
                ret_size.append(d)
        ret_sizes.append(ret_size)

    return torch.nested.nested_tensor(
        [torch.randn(*size) for size in ret_sizes], device=other.device
    )
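# Illustrative usage of random_nt_from_similar (a sketch only; not executed by the
# test suite):
#
#   x = random_nt_from_dims([5, None, 10])
#   y = random_nt_from_similar(x, dims=[-1, -1, 8])
#   # y matches x's batch size and ragged structure, but its last dim is 8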


# Makes naming nice for tests that parametrize over layout.
def layout_name(layout):
    # e.g. "torch.jagged" -> "jagged"
    return repr(layout).split(".")[-1]


def get_op_name(op):
    # e.g. "<OpOverload(op='aten.sum', overload='dim_IntList')>" -> "sum"
    return op.__name__.split(".")[0].split("_")[-1]


# Helper function for test_dummy_mha_with_nt
@torch.fx.wrap
def convert_dense_to_nested_tensor_legacy(values):
    offsets = torch.arange(
        0, values.shape[0] * values.shape[1] + 1, values.shape[1], device=values.device
    )
    metadata_cache = {"max_seqlen": values.shape[1], "min_seqlen": 1}
    nt = ViewNestedFromBuffer.apply(
        values.view(-1, values.shape[-1]), offsets, metadata_cache
    )
    return nt


# Helper function for test_dummy_mha_with_nt
@torch.fx.wrap
def convert_jagged_to_nested_tensor_legacy(
    values: torch.Tensor, offsets: torch.Tensor, max_length: int
) -> torch.Tensor:
    metadata_cache = {"max_seqlen": max_length, "min_seqlen": 1}
    nt = ViewNestedFromBuffer.apply(values, offsets, metadata_cache)
    return nt


# Helper function for test_dummy_mha_with_nt
@torch.fx.wrap
def convert_nt_to_jagged_legacy(nt):
    return buffer_from_jagged(nt)


# Helper function for test_dummy_mha_with_nt
@torch.fx.wrap
def convert_dense_to_nested_tensor(values):
    nt = torch.nested.as_nested_tensor(values, layout=torch.jagged)
    return nt


# Helper function for test_dummy_mha_with_nt
@torch.fx.wrap
def convert_jagged_to_nested_tensor(
    values: torch.Tensor, offsets: torch.Tensor, max_length: int
) -> torch.Tensor:
    nt = torch.nested.nested_tensor_from_jagged(
        values, offsets, lengths=None, min_seqlen=1, max_seqlen=max_length
    )
    return nt


# Helper function for test_dummy_mha_with_nt
def convert_nt_to_jagged(nt):
    return nt.values()
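# Illustrative round trip through the non-legacy helpers above (a sketch only; not
# executed by the test suite), assuming a dense (B, S, D) input:
#
#   dense = torch.randn(2, 3, 8)
#   nt = convert_dense_to_nested_tensor(dense)   # jagged-layout NT with B=2 components
#   values = convert_nt_to_jagged(nt)            # packed values buffer of shape (2 * 3, 8)
#   nt2 = convert_jagged_to_nested_tensor(values, nt.offsets(), max_length=3)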


@markDynamoStrictTest
class TestNestedTensor(NestedTensorTestCase):
    @parametrize("batch_size", [2, 4])
    @parametrize("max_seq_len", [3, 5])
    @parametrize("vocab_size", [10, 20])
    def test_2d_nested_tensor(self, batch_size, max_seq_len, vocab_size):
        data = []
        nested_tensor_ref_list = []
        for _ in range(batch_size):
            if max_seq_len == 0:
                length = 0
            else:
                length = np.random.randint(low=1, high=max_seq_len)
            row = list(np.random.randint(low=0, high=vocab_size, size=(length,)))
            data.append(row)
            nested_tensor_ref_list.append(torch.Tensor(row))
        nested_tensor = torch.nested.nested_tensor(data, dtype=torch.int64)
        nested_tensor_list = nested_tensor.unbind()
        for id in range(batch_size):
            self.assertEqual(
                nested_tensor_list[id], nested_tensor_ref_list[id].type(torch.int64)
            )

    @parametrize("batch_size", [2, 4])
    @parametrize("max_seq_len", [3, 5])
    @parametrize("vocab_size", [10, 20])
    def test_3d_nested_tensor(self, batch_size, max_seq_len, vocab_size):
        data = []
        nested_tensor_ref_list = []
        for _ in range(batch_size):
            if max_seq_len == 0:
                length = 0
            else:
                length = np.random.randint(low=1, high=max_seq_len)
            row = list(np.random.randint(low=0, high=vocab_size, size=(length,)))
            row = [list(item * np.arange(max_seq_len)) for item in row]
            data.append(row)
            nested_tensor_ref_list.append(torch.Tensor(row))
        nested_tensor = torch.nested.nested_tensor(data, dtype=torch.int64)
        nested_tensor_list = nested_tensor.unbind()
        for id in range(batch_size):
            self.assertEqual(
                nested_tensor_list[id], nested_tensor_ref_list[id].type(torch.int64)
            )

    @parametrize("batch_size", [2, 4])
    @parametrize("max_seq_len", [3, 5])
    @parametrize("vocab_size", [10, 20])
    def test_3d_nested_tensor_float(self, batch_size, max_seq_len, vocab_size):
        data = []
        nested_tensor_ref_list = []
        for _ in range(batch_size):
            if max_seq_len == 0:
                length = 0
            else:
                length = np.random.randint(low=1, high=max_seq_len)
            row = list(
                np.random.randint(low=0, high=vocab_size, size=(length,)).astype(float)
            )
            row = [list(item * np.arange(max_seq_len)) for item in row]
            data.append(row)
            nested_tensor_ref_list.append(torch.Tensor(row))
        nested_tensor = torch.nested.nested_tensor(data, dtype=torch.float)
        nested_tensor_list = nested_tensor.unbind()
        for id in range(batch_size):
            self.assertEqual(
                nested_tensor_list[id], nested_tensor_ref_list[id].type(torch.float)
            )

    @torch.inference_mode()
    def _test_unbind_case(self, a, b):
        nt = torch.nested.nested_tensor([a, b])
        a1, b1 = nt.unbind()
        self.assertTrue(a is not a1)
        self.assertTrue(b is not b1)

        nt = torch.nested.nested_tensor([a, b], dtype=a.dtype)
        a1, b1 = nt.unbind(0)
        self.assertEqual(a, a1)
        self.assertEqual(b, b1)

        a = torch.randn((2, 3)).add_(1)
        nt = torch.nested.nested_tensor([a])
        self.assertEqual(a, nt.unbind(0)[0])

    @torch.inference_mode()
    def test_unbind_0(self):
        self._test_unbind_case(torch.tensor([1, 2]), torch.tensor([7, 8]))

    @torch.inference_mode()
    def test_unbind_1(self):
        self._test_unbind_case(torch.tensor([1]), torch.tensor([7]))

    @torch.inference_mode()
    def test_unbind_3(self):
        self._test_unbind_case(torch.tensor([1.0]), torch.tensor([]))

    @torch.inference_mode()
    def test_unbind_4(self):
        self._test_unbind_case(torch.tensor([]), torch.tensor([]))

    @torch.inference_mode()
    def test_unbind_dim(self):
        def _test_fn(unbind_fn):
            a = torch.rand(3, 2)
            b = torch.rand(2, 3)
            nt = torch.nested.nested_tensor([a, b])
            self.assertRaises(RuntimeError, lambda: unbind_fn(nt, 1))

        # Both of these tests are necessary, because we're using
        # torch_function.
        _test_fn(lambda x, dim: x.unbind(dim))
        # TODO: Re-enable this once using torch_dispatch
        # _test_fn(lambda x, dim: torch.unbind(x, dim))

    @torch.inference_mode()
    def test_nested_tensor(self):
        self.assertRaises(
            TypeError, lambda: torch.nested.nested_tensor(torch.tensor([3.0]))
        )
        self.assertRaises(TypeError, lambda: torch.nested.nested_tensor(4.0))

    @torch.inference_mode()
    def test_nested_tensor_matching_dim(self):
        self.assertRaisesRegex(
            RuntimeError,
            "Found dimension 1 for Tensor at index 1 and dimension 0 for Tensor at index 0.",
            lambda: torch.nested.nested_tensor([torch.tensor(1.0), torch.tensor([])]),
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Found dimension 1 for Tensor at index 2 and dimension 0 for Tensor at index 1.",
            lambda: torch.nested.nested_tensor(
                [torch.tensor(1.0), torch.tensor(2.0), torch.tensor([])]
            ),
        )

    @torch.inference_mode()
    def test_default_nested_tensor(self):
        self.assertRaises(TypeError, lambda: torch.nested.nested_tensor())
        default_nested_tensor = torch.nested.nested_tensor([])
        default_tensor = torch.tensor([])
        # self.assertEqual(default_nested_tensor.nested_dim(), 1)
        # self.assertEqual(default_nested_tensor.nested_size(), ())
        self.assertEqual(default_nested_tensor.dim(), default_tensor.dim())
        self.assertEqual(default_nested_tensor.layout, default_tensor.layout)
        self.assertEqual(default_nested_tensor.device, default_tensor.device)
        self.assertEqual(default_nested_tensor.dtype, default_tensor.dtype)
        self.assertEqual(
            default_nested_tensor.requires_grad, default_tensor.requires_grad
        )
        self.assertIsNone(default_tensor.grad)
        # TODO: Re-enable once we have a performance driven
        # use case and implementation.
        # self.assertEqual(default_nested_tensor.is_pinned(),
        #                  default_tensor.is_pinned())

    @torch.inference_mode()
    def test_dim(self):
        for constructor in _iter_constructors():
            a1 = constructor([])
            self.assertEqual(a1.dim(), 1)
            a1 = constructor([torch.tensor(3.0)])
            self.assertEqual(a1.dim(), 1)
            a1 = constructor([torch.tensor([1, 2, 3, 4])])
            self.assertEqual(a1.dim(), 2)

    @unittest.skipIf(IS_FBCODE, "numel is not virtual in fbcode.")
    @torch.inference_mode()
    def test_numel(self):
        for constructor in _iter_constructors():
            a1 = constructor([])
            self.assertEqual(a1.numel(), 0)
            a1 = constructor([torch.tensor(3.0), torch.tensor(4.0)])
            self.assertEqual(a1.numel(), 2)
            a1 = constructor([torch.randn(2, 2, 2)])
            self.assertEqual(a1.numel(), 8)
            a1 = constructor([torch.randn([1, 2, 3]), torch.randn(3, 2, 1)])
            self.assertEqual(a1.numel(), 12)
            a1 = constructor([torch.randn([1, 1, 3]), torch.randn(3, 2, 4)])
            self.assertEqual(a1.numel(), 27)
            a1 = constructor([torch.randn([5, 5, 5]), torch.randn(6, 6, 6)])
            self.assertEqual(a1.numel(), 341)

            # Interesting edge case
            a1 = constructor([torch.randn([1, 2, 3]), torch.randn(1, 2, 0)])
            self.assertEqual(a1.numel(), 6)

    @torch.inference_mode()
    def test_size(self):
        for constructor in _iter_constructors():
            a1 = constructor([])
            self.assertRaisesRegex(
                RuntimeError,
                "NestedTensorImpl doesn't support sizes",
                lambda: a1.size(),
            )

    def test_size_dim(self):
        a = torch.nested.nested_tensor([])
        self.assertEqual(a.size(0), 0)

        a = torch.nested.nested_tensor([torch.tensor(1)])
        self.assertEqual(a.size(0), 1)

        a = torch.nested.nested_tensor([torch.tensor(1), torch.tensor(2)])
        self.assertEqual(a.size(0), 2)

        a = torch.nested.nested_tensor([torch.rand(1, 2), torch.rand(1, 8)])
        self.assertEqual(a.size(0), 2)
        self.assertEqual(a.size(1), 1)
        self.assertRaisesRegex(
            RuntimeError,
            "Given dimension 2 is irregular and does not have a size",
            lambda: a.size(2),
        )

        a = torch.nested.nested_tensor([torch.rand(3, 4), torch.rand(5, 4)])
        self.assertEqual(a.size(0), 2)
        self.assertRaisesRegex(
            RuntimeError,
            "Given dimension 1 is irregular and does not have a size",
            lambda: a.size(1),
        )
        self.assertEqual(a.size(2), 4)

    @unittest.skipIf(IS_FBCODE, "stride is not virtual in fbcode.")
    @torch.inference_mode()
    def test_stride(self):
        for constructor in _iter_constructors():
            a1 = constructor([])
            self.assertRaisesRegex(
                RuntimeError,
                "NestedTensorImpl doesn't support strides",
                lambda: a1.stride(),
            )

    @unittest.skipIf(IS_FBCODE, "is_contiguous is not virtual in fbcode.")
    @torch.inference_mode()
    def test_is_contiguous(self):
        # Test empty case
        nt_empty = torch.nested.nested_tensor([])
        assert nt_empty.is_contiguous()
        self.assertEqual(nt_empty, nt_empty.contiguous())

        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7))

        # Test contiguous case
        assert nt_contiguous.is_contiguous()
        self.assertEqual(nt_contiguous, nt_contiguous.contiguous())

        # Test non_contiguous case
        assert not nt_noncontiguous.is_contiguous()
        self.assertEqual(nt_contiguous, nt_noncontiguous.contiguous())

        # Test querying by memory_format
        self.assertTrue(
            nt_contiguous.is_contiguous(memory_format=torch.contiguous_format)
        )
        self.assertTrue(
            not nt_noncontiguous.is_contiguous(memory_format=torch.contiguous_format)
        )

    @torch.inference_mode()
    def test_repr_string(self):
        a = torch.nested.nested_tensor([])
        expected = "nested_tensor([\n\n])"
        self.assertEqual(str(a), expected)
        self.assertEqual(repr(a), expected)

        a = torch.nested.nested_tensor([torch.tensor(1.0)])
        expected = "nested_tensor([\n  tensor(1.)\n])"
        self.assertEqual(str(a), expected)
        self.assertEqual(repr(a), expected)

        a = torch.nested.nested_tensor([torch.tensor([[1, 2]]), torch.tensor([[4, 5]])])
        expected = "nested_tensor([\n  tensor([[1, 2]]),\n  tensor([[4, 5]])\n])"
        self.assertEqual(str(a), expected)
        self.assertEqual(repr(a), expected)

    def test_to_padded_tensor_on_empty_tensor(self):
        nt = torch.nested.nested_tensor([])
        empty = torch.nested.to_padded_tensor(nt, 4)
        self.assertEqual(empty, torch.tensor([]))

    def test_nested_namespace(self):
        nt = torch.nested.nested_tensor([torch.randn(2, 3), torch.randn(4, 5)])
        result = nt.to_padded_tensor(4)
        nested_namespace_result = torch.nested.to_padded_tensor(nt, 4)
        self.assertEqual(result, nested_namespace_result)

    def test_to(self):
        ntensors = 4
        nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4))

        def test_copy_behavior(t, non_blocking=False):
            self.assertIs(t, t.to(t, non_blocking=non_blocking))
            self.assertIs(t, t.to(t.dtype, non_blocking=non_blocking))
            self.assertIs(t, t.to(torch.empty_like(t), non_blocking=non_blocking))
            self.assertIsNot(t, t.to(t, non_blocking=non_blocking, copy=True))
            self.assertIsNot(t, t.to(t.dtype, non_blocking=non_blocking, copy=True))
            self.assertIsNot(
                t, t.to(torch.empty_like(t), non_blocking=non_blocking, copy=True)
            )

            devices = [t.device]
            if t.device.type == "cuda":
                if t.device.index == -1:
                    devices.append(f"cuda:{torch.cuda.current_device()}")
                elif t.device.index == torch.cuda.current_device():
                    devices.append("cuda")
            for device in devices:
                self.assertIs(t, t.to(device, non_blocking=non_blocking))
                self.assertIs(t, t.to(device, t.dtype, non_blocking=non_blocking))
                self.assertIsNot(t, t.to(device, non_blocking=non_blocking, copy=True))
                self.assertIsNot(
                    t, t.to(device, t.dtype, non_blocking=non_blocking, copy=True)
                )

        test_copy_behavior(nt)
        self.assertEqual(nt.device, nt.to("cpu").device)
        self.assertEqual(nt.device, nt.to("cpu", dtype=torch.float32).device)
        self.assertIs(torch.float32, nt.to("cpu", dtype=torch.float32).dtype)
        self.assertEqual(nt.device, nt.to(torch.float32).device)
        self.assertIs(torch.float32, nt.to(dtype=torch.float32).dtype)

        def test_data_ptr(getter):
            self.assertEqual(getter(nt), getter(nt.to("cpu")))
            self.assertEqual(
                getter(nt), getter(nt.to(dtype=nt.dtype, device=nt.device, copy=False))
            )
            self.assertEqual(getter(nt), getter(nt.to("cpu", copy=False)))
            self.assertNotEqual(getter(nt), getter(nt.to("cpu", copy=True)))

        test_data_ptr(lambda nt: nt.data_ptr())

        if torch.cuda.is_available():
            for non_blocking in [True, False]:
                for cuda in [
                    "cuda",
                    "cuda:0" if torch.cuda.device_count() == 1 else "cuda:1",
                ]:
                    nt2 = random_nt(cuda, torch.float32, ntensors, (4, 4))
                    test_copy_behavior(nt2, non_blocking)
                    self.assertEqual(
                        nt2.device, nt2.to(cuda, non_blocking=non_blocking).device
                    )
                    self.assertEqual(
                        nt.device, nt2.to("cpu", non_blocking=non_blocking).device
                    )
                    self.assertEqual(
                        nt2.device, nt.to(cuda, non_blocking=non_blocking).device
                    )
                    self.assertIs(
                        torch.int32,
                        nt2.to(
                            "cpu", dtype=torch.int32, non_blocking=non_blocking
                        ).dtype,
                    )
                    self.assertEqual(
                        nt.device,
                        nt2.to(
                            "cpu", dtype=torch.int32, non_blocking=non_blocking
                        ).device,
                    )
                    self.assertIs(torch.int32, nt2.to(dtype=torch.int32).dtype)
                    self.assertEqual(nt2.device, nt2.to(dtype=torch.int32).device)

    def test_copy_(self):
        ntensors = 4
        nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4))
        nt_copy = torch.empty_like(nt)
        nt_copy.copy_(nt)

        for nt_ub, nt_copy_ub in zip(nt.unbind(), nt_copy):
            self.assertEqual(nt_ub, nt_copy_ub)

        nt_error = torch.nested.nested_tensor([torch.tensor([0, 0])])
        self.assertRaisesRegex(
            RuntimeError,
            "copy_ only supports tensors that are the same size for Nested implementations",
            lambda: nt_error.copy_(nt),
        )

        if torch.cuda.is_available():
            nt = random_nt(torch.device("cuda"), torch.float32, ntensors, (4, 4))
            nt_copy = torch.empty_like(nt, device=torch.device("cpu"))
            nt_copy.copy_(nt, non_blocking=True)
            torch.cuda.current_stream(torch.cuda.current_device()).synchronize()
            for nt_ub, nt_copy_ub in zip(nt.unbind(), nt_copy):
                self.assertEqual(nt_ub, nt_copy_ub)

            nt_copy = torch.empty_like(nt, device=torch.device("cpu"))
            nt_copy.copy_(nt, non_blocking=False)
            for nt_ub, nt_copy_ub in zip(nt.unbind(), nt_copy):
                self.assertEqual(nt_ub, nt_copy_ub)

    def test_fill_(self):
        ntensors = 4
        nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4))
        nt.fill_(10.0)
        for nt_ub in nt.unbind():
            t = torch.empty_like(nt_ub)
            t.fill_(10.0)
            self.assertEqual(nt_ub, t)

        fill_tensor = torch.tensor([11.0])
        self.assertRaisesRegex(
            RuntimeError,
            "fill_ only supports 0-dimension value tensor",
            lambda: nt.fill_(fill_tensor),
        )

        nt.fill_(fill_tensor[0])
        for nt_ub in nt.unbind():
            t = torch.empty_like(nt_ub)
            t.fill_(11.0)
            self.assertEqual(nt_ub, t)

    def test_zero_(self):
        ntensors = 4
        nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4))
        nt.zero_()
        for nt_ub in nt.unbind():
            t = torch.empty_like(nt_ub)
            t.fill_(0.0)
            self.assertEqual(nt_ub, t)

    @parametrize(
        "func",
        [torch.ones_like, torch.zeros_like, torch.randn_like],
        name_fn=lambda f: f.__name__,
    )
    def test_like_functions(self, func):
        ntensors = 4
        nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4))
        torch.manual_seed(1)
        nt_like = func(nt)

        torch.manual_seed(1)
        for nt_ub in nt_like.unbind():
            t_like = func(nt_ub)
            self.assertEqual(nt_ub, t_like)

    def test_cat(self):
        # dim=0 success case
        # No constraints on ragged structures matching.
        x = random_nt_from_dims([5, None, 10])
        y = random_nt_from_dims([3, 4, None])
        output = torch.cat([x, y], dim=0)
        for out_component, xy_component in zip(
            output.unbind(), itertools.chain(x.unbind(), y.unbind())
        ):
            self.assertEqual(out_component, xy_component)

        # dim=-1 success case
        # shape (B, *, D)
        x = random_nt_from_dims([5, None, 10])
        # shape (B, *, D'); same structure as x but dim=-1 differs
        y = random_nt_from_similar(x, dims=[-1, -1, 8])
        # should be shape (B, *, D + D') when supported
        output = torch.cat([x, y], dim=-1)
        for out_component, x_component, y_component in zip(
            output.unbind(), x.unbind(), y.unbind()
        ):
            self.assertEqual(
                out_component, torch.cat([x_component, y_component], dim=-1)
            )

        # dim between 0 and -1 success case
        x = random_nt_from_dims([5, None, 2, 3])
        # same structure as x but dim=2 differs
        y = random_nt_from_similar(x, dims=[-1, -1, 4, -1])
        output = torch.cat([x, y], dim=2)
        for out_component, x_component, y_component in zip(
            output.unbind(), x.unbind(), y.unbind()
        ):
            self.assertEqual(
                out_component, torch.cat([x_component, y_component], dim=1)
            )

        # error case: mixed NT / dense inputs
        x = random_nt_from_dims([5, None, 2])
        y = torch.randn(5, 3, 2)
        with self.assertRaisesRegex(
            RuntimeError, "expected each tensor in given list to be nested"
        ):
            torch.cat([x, y], dim=-1)

        # error case: NTs with different dims
        x = random_nt_from_dims([5, None, 2])
        y = random_nt_from_dims([5, None, 2, 3])
        with self.assertRaisesRegex(
            RuntimeError,
            "expected all nested tensors to have matching ragged structures outside of the concatenated dim",
        ):
            torch.cat([x, y], dim=-1)

        # error case: non-contiguous NT
        x, y = random_nt_noncontiguous_pair((2, 3, 4), dtype=torch.float32)
        # transpose to put ragged dim next to batch dim
        x, y = x.transpose(-2, -1), y.transpose(-2, -1)
        with self.assertRaisesRegex(
            RuntimeError, "only contiguous nested tensors are supported"
        ):
            torch.cat([x, y], dim=-1)

        # error case: multiple ragged dims in inputs
        x = random_nt_from_dims([5, None, None, 2])
        y = random_nt_from_similar(x)
        with self.assertRaisesRegex(
            RuntimeError,
            "only nested tensors with a single ragged dim next to the batch dim are supported",
        ):
            torch.cat([x, y], dim=-1)

        # error case: ragged dim not next to batch dim
        x = random_nt_from_dims([5, 2, None])
        y = random_nt_from_similar(x)
        with self.assertRaisesRegex(
            RuntimeError,
            "only nested tensors with a single ragged dim next to the batch dim are supported",
        ):
            torch.cat([x, y], dim=1)

        # error case: NTs with different batch sizes
        x = random_nt_from_dims([5, None, 2])
        y = random_nt_from_dims([3, None, 2])
        with self.assertRaisesRegex(
            RuntimeError,
            "expected all nested tensors to have matching ragged structures outside of the concatenated dim",
        ):
            torch.cat([x, y], dim=-1)

        # error case: NTs with different ragged structures
        x = torch.nested.nested_tensor(
            [
                torch.randn(2, 6),
                torch.randn(4, 6),
                torch.randn(5, 6),
            ]
        )
        y = torch.nested.nested_tensor(
            [
                torch.randn(5, 6),
                torch.randn(4, 6),
                torch.randn(2, 6),
            ]
        )
        with self.assertRaisesRegex(
            RuntimeError,
            "expected all nested tensors to have matching ragged structures outside of the concatenated dim",
        ):
            torch.cat([x, y], dim=-1)


@markDynamoStrictTest
class TestNestedTensorDeviceType(NestedTensorTestCase):
    # Helper function to generate a pair of random nested tensors;
    # the two nested tensors have the same shapes.
    def random_nt_pair(self, device, dtype, num_tensors, max_dims):
        ts1 = []
        ts2 = []
        for _ in range(num_tensors):
            tensor_dims = tuple(
                [
                    torch.randint(low=0, high=max_dim, size=(1,)).item()
                    for max_dim in max_dims
                ]
            )
            t1 = torch.randn(tensor_dims, device=device, dtype=dtype)
            t2 = torch.randn(tensor_dims, device=device, dtype=dtype)
            ts1.append(t1)
            ts2.append(t2)
        return (
            torch.nested.nested_tensor(ts1, device=device, dtype=dtype),
            torch.nested.nested_tensor(ts2, device=device, dtype=dtype),
        )
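    # Illustrative usage of random_nt_pair (a sketch only; not executed directly):
    #
    #   nt1, nt2 = self.random_nt_pair(device, dtype, 4, (4, 4))
    #   # nt1 and nt2 share the same nested shapes but hold independent random values,
    #   # which makes component-wise checks of binary ops straightforward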

    @dtypes(*floating_types_and_half())
    def test_detach(self, device, dtype):
        a = torch.randn(2, 4, device=device, dtype=dtype, requires_grad=False)
        b = torch.randn(5, 4, device=device, dtype=dtype, requires_grad=False)
        x = torch.nested.nested_tensor([a, b], requires_grad=True)

        x_detach = x.detach()

        z = x_detach * 4
        self.assertFalse(x_detach.requires_grad)
        self.assertFalse(z.requires_grad)

        a = torch.randn(2, 4, device=device, dtype=dtype, requires_grad=True)
        b = torch.randn(5, 4, device=device, dtype=dtype, requires_grad=True)
        x = torch.nested.as_nested_tensor([a, b])

        y = x * 2
        y = y.detach()
        self.assertFalse(y.requires_grad)
        self.assertIsNone(y.grad_fn)

        z = x + y
        torch.nested.to_padded_tensor(z, 0).sum().backward()
        # This is an incorrect gradient, but we assume that's what the user
        # wanted. detach() is an advanced option.
        self.assertEqual(a.grad, torch.ones(2, 4, device=device, dtype=dtype))
        self.assertEqual(b.grad, torch.ones(5, 4, device=device, dtype=dtype))

    @dtypes(torch.float, torch.float16, torch.double)
    def test_unbind_noncontiguous(self, device, dtype):
        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
            (2, 3, 6, 7), device, dtype
        )
        ub_contiguous = nt_contiguous.unbind()
        ub_noncontiguous = nt_noncontiguous.unbind()
        self.assertEqual(len(ub_contiguous), len(ub_noncontiguous))
        n = len(ub_contiguous)
        for i in range(n):
            self.assertEqual(ub_contiguous[i], ub_noncontiguous[i])

    @dtypes(torch.float)
    @skipMeta
    def test_to_then_from_padded_tensor_no_transform0213(self, device, dtype):
        t = torch.randn(4, 4, 4, device=device, dtype=dtype)
        ts = list(torch.unbind(t))
        ts[0] = ts[0][:-1]
        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
        padded = torch.nested.to_padded_tensor(nt, 0)

        nt_to = torch._nested_from_padded_and_nested_example(padded, nt)

        for t1, t2 in zip(nt.unbind(), nt_to.unbind()):
            self.assertEqual(t1, t2)
        self.assertEqual(nt.device, nt_to.device)

    @dtypes(torch.float)
    @dtypesIfCUDA(torch.float, torch.half)
    @skipMeta
    @torch.inference_mode()
    def test_layer_norm(self, device, dtype):
        def _test(size):
            # Simple shapes test
            t0 = torch.randn(2, size, device=device, dtype=dtype, requires_grad=False)
            t1 = torch.randn(2, size, device=device, dtype=dtype, requires_grad=False)
            ts = [t0, t1, t0, t1]
            nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
            layer_norm = torch.nn.LayerNorm(size, device=device, dtype=dtype)
            nt_result = layer_norm(nt)
            for nt_subresult, t in zip(nt_result.unbind(), ts):
                t_result = layer_norm(t.reshape(1, -1, size).squeeze(0))
                self.assertEqual(nt_subresult, t_result)

            # More complex nt test with different lengths for each tensor
            t0 = torch.randn(4, size, device=device, dtype=dtype, requires_grad=False)
            t1 = torch.randn(10, size, device=device, dtype=dtype, requires_grad=False)
            t2 = torch.randn(7, size, device=device, dtype=dtype, requires_grad=False)
            ts = [t0, t1, t2, t0, t2]
            nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
            layer_norm = torch.nn.LayerNorm(size, device=device, dtype=dtype)
            nt_result = layer_norm(nt)
            for nt_subresult, t in zip(nt_result.unbind(), ts):
                t_result = layer_norm(t.reshape(1, -1, size).squeeze(0))
                self.assertEqual(nt_subresult, t_result)

            if size <= 128:
                # Test with multidimensional tensors after irregular dim
                # (run only with smaller dimensions to ensure fast execution)
                t0 = torch.randn(
                    4, size, size, 4, device=device, dtype=dtype, requires_grad=False
                )
                t1 = torch.randn(
                    10, size, size, 4, device=device, dtype=dtype, requires_grad=False
                )
                t2 = torch.randn(
                    7, size, size, 4, device=device, dtype=dtype, requires_grad=False
                )
                ts = [t0, t1, t2, t0, t2]
                nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
                layer_norm = torch.nn.LayerNorm(
                    (size, size, 4), device=device, dtype=dtype
                )
                nt_result = layer_norm(nt)
                for nt_subresult, t in zip(nt_result.unbind(), ts):
                    t_result = layer_norm(t.reshape(1, -1, size, size, 4).squeeze(0))
                    self.assertEqual(nt_subresult, t_result)

                # Test where the normalizing dimensions are not all of the
                # dimensions after the ragged dim
                layer_norm = torch.nn.LayerNorm((size, 4), device=device, dtype=dtype)
                nt_result = layer_norm(nt)
                for nt_subresult, t in zip(nt_result.unbind(), ts):
                    t_result = layer_norm(t.reshape(1, -1, size, size, 4).squeeze(0))
                    self.assertEqual(nt_subresult, t_result)

        for size in (1024, 1023, 513, 512, 256, 128, 2, 4, 32):
            _test(size)

    @dtypes(torch.float)
    @dtypesIfCUDA(torch.float, torch.half)
    @skipMeta
    @torch.inference_mode()
    def test_layer_norm_breaking(self, device, dtype):
        size = 128
        t0 = torch.randn(
            4, size, size, 4, device=device, dtype=dtype, requires_grad=False
        )
        t1 = torch.randn(
            10, size, size, 4, device=device, dtype=dtype, requires_grad=False
        )
        t2 = torch.randn(
            7, size, size, 4, device=device, dtype=dtype, requires_grad=False
        )
        ts = [t0, t1, t2, t0, t2]
        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
        layer_norm = torch.nn.LayerNorm((4, size, size, 4), device=device, dtype=dtype)
        self.assertRaisesRegex(
            RuntimeError,
            "normalized_shape extends into irregular dimensions for the nested tensor",
            lambda: layer_norm(nt),
        )
        layer_norm = torch.nn.LayerNorm((size + 1, size, 4), device=device, dtype=dtype)
        self.assertRaisesRegex(
            RuntimeError,
            "The shape at dimension 0",
            lambda: layer_norm(nt),
        )

    @parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name)
    def test_embedding(self, device, layout):
        inputs = [
            torch.randint(100, (L,), device=device, dtype=torch.int64)
            for L in torch.randint(5, 50, (8,))
        ]
        x = torch.nested.nested_tensor(
            inputs, device=device, dtype=torch.int64, layout=layout
        )
        emb = torch.nn.Embedding(100, 8, device=device)
        y = emb(x)

        @torch._dynamo.disable
        def check(inputs, y):
            ys = y.unbind()
            for i, inp in enumerate(inputs):
                self.assertEqual(emb(inp), ys[i])

        check(inputs, y)

    @skipMeta
    @torch.inference_mode()
    @dtypes(*floating_types_and_half())
    def test_masked_fill(self, device, dtype):
        # nested tensor * nested tensor
        (nt, mask) = self.random_nt_pair(device, dtype, 4, (4, 4))
        mask = torch.nested.nested_tensor([m < 0 for m in mask.unbind()])
        ref = torch.nested.nested_tensor(
            [t.masked_fill(m, 0) for (t, m) in zip(nt.unbind(), mask.unbind())]
        )
        out = nt.masked_fill(mask, 0)
        self.assertEqual(ref, out)

    @dtypes(torch.float, torch.float16)
    def test_to_padded_tensor_simple(self, device, dtype):
        t = torch.randn(4, 4, 4, device=device, dtype=dtype)
        ts = list(torch.unbind(t))
        ts[0] = ts[0][:-1]
        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
        for padding_value in (0, 1):
            padded = torch.nested.to_padded_tensor(nt, padding_value)

            correct_output = t.clone()
            if padding_value == 0:
                correct_output[0][-1] = torch.zeros_like(correct_output[0][-1])
            else:
                correct_output[0][-1] = torch.ones_like(correct_output[0][-1])

            self.assertEqual(padded, correct_output)
            self.assertEqual(padded.device, torch.device(device))
            self.assertEqual(padded.dtype, dtype)

    @dtypes(torch.float, torch.float16)
    def test_to_padded_tensor_output_size(self, device, dtype):
        t = torch.randn(4, 4, 4, device=device, dtype=dtype)
        output_size = (4, 6, 5)
        ts = list(torch.unbind(t))
        ts[0] = ts[0][:-1]
        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
        for padding_value in (0, 1):
            padded = torch.nested.to_padded_tensor(
                nt, padding_value, output_size=output_size
            )
            correct_output = (
                torch.ones(output_size, device=device, dtype=dtype) * padding_value
            )
            correct_output[:4, :4, :4] = t.clone()
            if padding_value == 0:
                correct_output[0][3] = torch.zeros_like(correct_output[0][3])
            else:
                correct_output[0][3] = torch.ones_like(correct_output[0][3])

            self.assertEqual(padded, correct_output)
            self.assertEqual(padded.device, torch.device(device))
            self.assertEqual(padded.dtype, dtype)

    @dtypes(torch.float, torch.float16, torch.double)
    def test_to_padded_tensor_dim2(self, device, dtype):
        ts = [
            torch.randn(160, device=device, dtype=dtype),
            torch.randn(1240, device=device, dtype=dtype),
            torch.randn(2400, device=device, dtype=dtype),
        ]
        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
        pad = 42
        correct_output = []
        for t in ts:
            next_output = torch.ones_like(ts[2]) * pad
            correct_output.append(next_output)
            next_output[: t.size(0)].copy_(t)
        correct_output = torch.stack(correct_output)
        padded = torch.nested.to_padded_tensor(nt, pad)
        self.assertEqual(padded, correct_output)

    @dtypes(torch.float, torch.float16, torch.double)
    def test_to_padded_tensor_dim3(self, device, dtype):
        ts = [
            torch.randn(16, 21, device=device, dtype=dtype),
            torch.randn(24, 32, device=device, dtype=dtype),
            torch.randn(40, 53, device=device, dtype=dtype),
        ]
        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
        pad = 42
        correct_output = []
        for t in ts:
            next_output = torch.ones_like(ts[2]) * pad
            correct_output.append(next_output)
            next_output[: t.size(0), : t.size(1)].copy_(t)
        correct_output = torch.stack(correct_output)
        padded = torch.nested.to_padded_tensor(nt, pad)
        self.assertEqual(padded, correct_output)

    @dtypes(torch.float, torch.float16, torch.double)
    def test_to_padded_tensor_dim4(self, device, dtype):
        ts = [
            torch.randn(16, 21, 13, device=device, dtype=dtype),
            torch.randn(24, 32, 14, device=device, dtype=dtype),
            torch.randn(40, 53, 16, device=device, dtype=dtype),
        ]
        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
        pad = 42
        correct_output = []
        for t in ts:
            next_output = torch.ones_like(ts[2]) * pad
            correct_output.append(next_output)
            next_output[: t.size(0), : t.size(1), : t.size(2)].copy_(t)
        correct_output = torch.stack(correct_output)
        padded = torch.nested.to_padded_tensor(nt, pad)
        self.assertEqual(padded, correct_output)

    # TODO: test noncontiguous to_padded_tensor
    # For now this tests the functionality of noncontiguous_to_padded_tensor
    # and the error message of to_padded_tensor
    # since to_padded_tensor does not support noncontiguous buffer yet
    @dtypes(torch.float, torch.float16, torch.double)
    @torch.inference_mode()
    def test_to_padded_tensor_noncontiguous(self, device, dtype):
        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
            (2, 3, 6, 7), device, dtype
        )
        # test noncontiguous_to_padded_tensor functionality
        self.assertEqual(
            torch.nested.to_padded_tensor(nt_contiguous, 0.0),
            noncontiguous_to_padded_tensor(nt_noncontiguous),
        )
        # test to_padded_tensor error message
        self.assertRaisesRegex(
            RuntimeError,
            r"for now to_padded_tensor only supports contiguous nested tensor",
            lambda: torch.nested.to_padded_tensor(nt_noncontiguous, 0.0),
        )

    @skipMeta
    def test_device_checks(self, device):
        nt = torch.nested.nested_tensor([], device=device)
        is_cuda = "cuda" in str(device)
        self.assertEqual(nt.is_cuda, is_cuda)

    @dtypes(torch.float, torch.float16, torch.double)
    def test_nested_tensor_indexing(self, device, dtype):
        # edge case: empty nested tensor
        nt0 = torch.nested.nested_tensor([])
        self.assertRaises(IndexError, lambda: nt0[0])
        # normal case
        x0 = torch.randn((2, 5), device=device, dtype=dtype)
        x1 = torch.randn((3, 4), device=device, dtype=dtype)
        nt = torch.nested.nested_tensor([x0, x1])
        # single index: only support integer in the batch dimension
        self.assertEqual(nt[0], x0)
        self.assertEqual(nt[-1], x1)
        self.assertRaises(IndexError, lambda: nt[2])
        self.assertRaises(IndexError, lambda: nt[-3])
        self.assertRaises(NotImplementedError, lambda: nt[:])
        self.assertEqual(nt[...], nt)
        # tuple of indices: only support integer in the batch dimension
        #                 + all possible indexing in the original tensor dimensions
        self.assertEqual(nt[0, 0, 0], x0[0, 0])
        self.assertEqual(nt[0, 1, :], x0[1, :])
        self.assertEqual(nt[1, ...], x1)
        self.assertRaises(IndexError, lambda: nt[1, 4, 2])
        self.assertRaises(NotImplementedError, lambda: nt[:, 1, 1])
        # test select on non-batch dimensions
        self.assertEqual(nt.select(1, 0)[0], x0.select(0, 0))
        self.assertEqual(nt.select(1, 0)[1], x1.select(0, 0))
        self.assertRaises(IndexError, lambda: nt.select(1, 3))
        self.assertEqual(nt.select(2, 0)[0], x0.select(1, 0))
        self.assertEqual(nt.select(2, 0)[1], x1.select(1, 0))
        self.assertRaises(IndexError, lambda: nt.select(2, 5))
        # make sure indexing returns a view
        nt[0].fill_(100.0)
        answer = torch.tensor(100.0, device=device, dtype=dtype).expand((2, 5))
        self.assertEqual(nt[0], answer)
        nt[1, 1, :].fill_(200.0)
        answer = torch.tensor(200.0, device=device, dtype=dtype).expand(4)
        self.assertEqual(nt[1, 1, :], answer)

        # Test that indexing works when requires_grad_(True)
        # previously this was failing because the backward kernel for select.int uses .sizes()
        nt = torch.nested.nested_tensor([x0, x1]).requires_grad_(True)
        self.assertEqual(nt[0], x0)
        self.assertEqual(nt[-1], x1)
        grad_x0 = torch.randn((2, 5), device=device, dtype=dtype)
        nt[0].backward(grad_x0)
        expected_grad = torch.nested.nested_tensor(
            [grad_x0, torch.zeros((3, 4), device=device, dtype=dtype)]
        )
        self.assertEqual(nt.grad, expected_grad)

    @parametrize(
        "func",
        [
            subtest(torch.nn.functional.relu, name="relu"),
            subtest(torch.nn.functional.relu_, name="relu_"),
            subtest(torch.nn.functional.gelu, name="gelu"),
            subtest(torch._C._nn.gelu_, name="gelu_"),
            subtest(torch.tanh, name="tanh"),
            subtest(torch.tanh_, name="tanh_"),
            subtest(torch.neg, name="neg"),
            subtest(torch.nn.functional.silu, name="silu"),
            subtest(partial(torch.nn.functional.silu, inplace=True), name="silu_"),
            subtest(torch.abs, name="abs"),
            subtest(torch.abs_, name="abs_"),
            subtest(torch.sgn, name="sgn"),
            subtest(torch.logical_not, name="logical_not"),
            subtest(torch.sin, name="sin"),
            subtest(torch.cos, name="cos"),
        ],
    )
    def test_activations(self, device, func):
        nt, nt_noncontiguous = random_nt_noncontiguous_pair(
            (2, 3, 6, 7), device=device, dtype=torch.float32
        )
        nested_result = func(nt)
        self.assertTrue(nested_result.is_nested)
        for t, t_res in zip(nt.unbind(), nested_result.unbind()):
            self.assertEqual(func(t), t_res)
        self.assertRaisesRegex(
            RuntimeError,
            "NestedTensor must be contiguous to get buffer.",
            lambda: func(nt_noncontiguous),
        )

    @parametrize("func", [subtest(torch.ge, name="ge"), subtest(torch.eq, name="eq")])
    def test_binary_ops_with_scalar(self, device, func):
        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
            (2, 3, 6, 7), device=device, dtype=torch.float32
        )
        scalar = 0.0

        # should work regardless of contiguity
        for nt in (nt_contiguous, nt_noncontiguous):
            nested_result = func(nt, scalar)
            self.assertTrue(nested_result.is_nested)
            for t, t_res in zip(nt.unbind(), nested_result.unbind()):
                self.assertEqual(func(t, scalar), t_res)

    @dtypes(*floating_types_and_half())
    def test_nested_tensor_chunk(self, device, dtype):
        # Transformer use case
        a = torch.randn(3, 3 * 4, device=device, dtype=dtype)
        b = torch.randn(2, 3 * 4, device=device, dtype=dtype)
        c = torch.randn(1, 3 * 4, device=device, dtype=dtype)
        a_chunks = a.chunk(3, dim=-1)
        b_chunks = b.chunk(3, dim=-1)
        c_chunks = c.chunk(3, dim=-1)

        a_nt = [a_chunks[0], b_chunks[0], c_chunks[0]]
        b_nt = [a_chunks[1], b_chunks[1], c_chunks[1]]
        c_nt = [a_chunks[2], b_chunks[2], c_chunks[2]]

        nt = torch.nested.nested_tensor([a, b, c])
        chunked = nt.chunk(3, dim=-1)

        self.assertEqual(chunked[0], torch.nested.nested_tensor(a_nt))
        self.assertEqual(chunked[1], torch.nested.nested_tensor(b_nt))
        self.assertEqual(chunked[2], torch.nested.nested_tensor(c_nt))

        for chunk in chunked:
            self.assertFalse(chunk.is_contiguous())

        # Failure chunking on ragged dimensions
        self.assertRaisesRegex(
            RuntimeError,
            "Chunk for nested tensors is currently only supported for the last dimension.",
            lambda: torch.chunk(nt, 5, dim=1),
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Chunk for nested tensors is currently only supported for the last dimension.",
            lambda: torch.chunk(nt, 5, dim=0),
        )

        # Failure on non-contiguous nt
        _, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype)
        self.assertRaisesRegex(
            RuntimeError,
            "chunk expects `self` to be contiguous.",
            lambda: torch.chunk(nt_noncontiguous, 5, dim=-1),
        )

        # Failure when calling non divisible n_chunks
        self.assertRaisesRegex(
            RuntimeError,
            "Chunk for nested tensors is only supported for "
            "nested tensors with trailing dimension divisible by chunks.",
            lambda: torch.chunk(nt, 5, dim=-1),
        )

        # Failure when calling backward on a chunk
        a = torch.randn(3, 3 * 4, device=device, dtype=dtype, requires_grad=True)
        b = torch.randn(2, 3 * 4, device=device, dtype=dtype, requires_grad=True)
        nt_grad = torch.nested.as_nested_tensor([a, b])
        chunked = torch.chunk(nt_grad, 2, dim=-1)
        self.assertRaisesRegex(
            RuntimeError,
            "Nested Strided Tensor doesn't support chunk backward.",
            lambda: chunked[0].backward(chunked[0].clone()),
        )

    @dtypes(*floating_types_and_half())
    def test_nested_tensor_split_with_sizes(self, device, dtype):
        a = torch.randn(3, 20, device=device, dtype=dtype)
        b = torch.randn(2, 20, device=device, dtype=dtype)
        c = torch.randn(1, 20, device=device, dtype=dtype)

        split_sizes = [4, 6, 10]
        a_splits = a.split_with_sizes(split_sizes, dim=-1)
        b_splits = b.split_with_sizes(split_sizes, dim=-1)
        c_splits = c.split_with_sizes(split_sizes, dim=-1)

        nt = torch.nested.nested_tensor([a, b, c])
        nt_splits = nt.split_with_sizes(split_sizes, dim=-1)

        for i, nt_split in enumerate(nt_splits):
            self.assertEqual(
                nt_split,
                torch.nested.nested_tensor([a_splits[i], b_splits[i], c_splits[i]]),
            )
            dense_strides = torch.stack(
                [
                    torch.tensor(a_splits[i].stride()),
                    torch.tensor(b_splits[i].stride()),
                    torch.tensor(c_splits[i].stride()),
                ]
            )
            self.assertEqual(nt_split._nested_tensor_strides(), dense_strides)
            self.assertFalse(nt_split.is_contiguous())

        # Failure calling on ragged dimensions
        self.assertRaisesRegex(
            RuntimeError,
            "split_with_sizes for nested tensors is currently only supported for the last dimension.",
            lambda: torch.split_with_sizes(nt, split_sizes, dim=1),
        )

        # Failure calling on non-last dimension
        self.assertRaisesRegex(
            RuntimeError,
            "split_with_sizes for nested tensors is currently only supported for the last dimension.",
            lambda: torch.split_with_sizes(nt, split_sizes, dim=0),
        )

        # Failure on non-contiguous nt
        _, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype)
        self.assertRaisesRegex(
            RuntimeError,
            "split_with_sizes expects `self` to be contiguous.",
            lambda: torch.split_with_sizes(nt_noncontiguous, split_sizes, dim=-1),
        )

        # Failure when calling with split_sizes that don't cover the full dim size
        bad_split_sizes = [4, 6, 9]  # don't add up to 20
        self.assertRaisesRegex(
            RuntimeError,
            "split_with_sizes expects split_sizes to sum exactly to 20",
            lambda: torch.split_with_sizes(nt, bad_split_sizes, dim=-1),
        )

    @dtypes(torch.float, torch.float16, torch.double)
    @torch.inference_mode()
    def test_nested_tensor_indexing_noncontiguous(self, device, dtype):
        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
            (2, 3, 6, 7), device, dtype
        )
        self.assertEqual(nt_contiguous.size(0), nt_noncontiguous.size(0))
        n = nt_contiguous.size(0)
        for i in range(n):
            self.assertEqual(nt_contiguous[i], nt_noncontiguous[i])

    @dtypes(torch.float, torch.float16)
    @skipMeta
    @torch.inference_mode()
    @parametrize("transpose", [True, False])
    def test_nested_tensor_add(self, device, dtype, transpose):
        if transpose:
            a = torch.randn(2, 2, 2, device=device, dtype=dtype)
            b = torch.rand(2, 2, 2, device=device, dtype=dtype)
            c = a.transpose(-1, -2).contiguous()
            d = b.transpose(-1, -2).contiguous()
            nt1 = torch.nested.nested_tensor([a, b, a, b])
            nt2 = torch.nested.nested_tensor([c, d, c, d]).transpose(-1, -2)
        else:
            (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
        ref = torch.nested.nested_tensor(
            [t1 + t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]
        )
        out = nt1 + nt2
1424        self.assertEqual(ref, out)
1425
1426    @dtypes(torch.float, torch.float16)
1427    @skipMeta
1428    @torch.inference_mode()
1429    @parametrize("transpose", [True, False])
1430    def test_nested_tensor_sub(self, device, dtype, transpose):
1431        if transpose:
1432            a = torch.randn(2, 2, 2, device=device, dtype=dtype)
1433            b = torch.rand(2, 2, 2, device=device, dtype=dtype)
1434            c = a.transpose(-1, -2).contiguous()
1435            d = b.transpose(-1, -2).contiguous()
1436            nt1 = torch.nested.nested_tensor([a, b, a, b])
1437            nt2 = torch.nested.nested_tensor([c, d, c, d]).transpose(-1, -2)
1438        else:
1439            (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
1440        ref = torch.nested.nested_tensor(
1441            [t1 - t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]
1442        )
1443        out = nt1 - nt2
1444        self.assertEqual(ref, out)
1445
1446    @onlyCUDA
1447    @dtypes(torch.float, torch.float16)
1448    @torch.inference_mode()
1449    @parametrize("embedding_dim", [8, 128, 256, 384])
1450    def test_nested_tensor_dense_elementwise(self, device, dtype, embedding_dim):
1451        def _test_add_mul(nt, t):
1452            ref_add = torch.nested.nested_tensor(
1453                [t1 + t2 for (t1, t2) in zip(nt.unbind(), t.unbind())]
1454            )
1455            ref_mul = torch.nested.nested_tensor(
1456                [t1 * t2 for (t1, t2) in zip(nt.unbind(), t.unbind())]
1457            )
1458            self.assertEqual(nt.add(t), ref_add)
1459            self.assertEqual(nt.mul(t), ref_mul)
1460
1461        batch_size = 32
1462        seq_lens = torch.randint(low=0, high=10, size=(batch_size,))
1463
1464        # [B, *, D], [B, 1, D] case
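        # the dense operand broadcasts component-wise: its size-1 dim expands to each
        # constituent's ragged length during add/mul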
1465        ts = [torch.randn((seq_len, embedding_dim)) for seq_len in seq_lens]
1466        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
1467        t = torch.randn((batch_size, 1, embedding_dim), device=device, dtype=dtype)
1468        _test_add_mul(nt, t)
1469
1470        # [B, *], [B, 1] case
1471        ts = [torch.randn(seq_len) for seq_len in seq_lens]
1472        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
1473        t = torch.randn((batch_size, 1), device=device, dtype=dtype)
1474        _test_add_mul(nt, t)
1475
1476    @dtypes(torch.float, torch.float16)
1477    @skipMeta
1478    @torch.inference_mode()
1479    def test_nested_tensor_mul(self, device, dtype):
1480        # nested tensor * nested tensor
1481        (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
1482        ref = torch.nested.nested_tensor(
1483            [t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]
1484        )
1485        out = nt1 * nt2
1486        self.assertEqual(ref, out)
1487        # nested tensor * scalar
1488        number = 10.0
1489        scalar = torch.tensor(number).to(dtype).to(device)
1490        ref = torch.nested.nested_tensor([t * number for t in nt1.unbind()])
1491        out_number0 = nt1 * number
1492        out_number1 = number * nt1
1493        out_scalar0 = nt1 * scalar
1494        out_scalar1 = scalar * nt1
1495        self.assertEqual(out_number0, ref)
1496        self.assertEqual(out_number1, ref)
1497        self.assertEqual(out_scalar0, ref)
1498        self.assertEqual(out_scalar1, ref)
1499        # error case: numel == 1 but dim > 0
1500        vector = torch.tensor([number]).to(dtype).to(device)
1501        self.assertRaisesRegex(
1502            RuntimeError,
1503            "Expected both self and other to be nested, but got a nested self and non-nested other",
1504            lambda: nt1.mul(vector),
1505        )
1506        self.assertRaisesRegex(
1507            RuntimeError,
1508            "Expected both self and other to be nested, but got a non-nested self and nested other",
1509            lambda: vector.mul(nt1),
1510        )
1511
1512    @dtypes(torch.float, torch.float16)
1513    @skipMeta
1514    @torch.inference_mode()
1515    def test_nested_tensor_div(self, device, dtype):
1516        nt, nt2 = self.random_nt_pair(device, dtype, 4, (4, 4))
1517        scale = 4.0
1518        ref = torch.nested.nested_tensor([t / scale for t in nt.unbind()])
1519        out = nt / 4.0
1520        self.assertEqual(ref, out)
1521        ref_transposed = ref.transpose(1, 2)
1522        out = nt.transpose(1, 2) / 4.0
1523        self.assertEqual(ref_transposed, out)
1524
1525        ref = torch.nested.nested_tensor(
1526            [t / t2 for (t, t2) in zip(nt.unbind(), nt2.unbind())]
1527        )
1528        out = nt / nt2
1529        self.assertEqual(ref, out)
1530
1531        out = nt.transpose(1, 2) / nt2.transpose(1, 2)
1532        self.assertEqual(ref.transpose(1, 2), out)
1533
1534        nt_transpose_copy = torch.nested.nested_tensor(
1535            [t.transpose(0, 1) for t in nt.unbind()]
1536        )
1537
1538        self.assertRaisesRegex(
1539            RuntimeError,
1540            "div requires strides to match when given NestedTensors",
1541            lambda: nt_transpose_copy.transpose(1, 2) / nt2,
1542        )
1543
1544        nt = torch.nested.nested_tensor(
1545            [torch.randn(i, 4) for i in [3, 4, 5]], device=device, dtype=dtype
1546        )
1547        nt_chunks = nt.chunk(2, -1)
1548        self.assertRaisesRegex(
1549            RuntimeError,
1550            "div requires offsets to match when given NestedTensors",
1551            lambda: nt_chunks[0] / nt_chunks[1],
1552        )
1553
1554    @dtypes(torch.float, torch.float16)
1555    @skipMeta
1556    @torch.inference_mode()
1557    def test_nested_tensor_add_in_place(self, device, dtype):
1558        (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
1559        ref = torch.nested.nested_tensor(
1560            [t1 + t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]
1561        )
1562        nt1 += nt2
1563        self.assertEqual(ref, nt1)
1564
1565    @dtypes(torch.float, torch.float16)
1566    @skipMeta
1567    @torch.inference_mode()
1568    def test_nested_tensor_mul_in_place(self, device, dtype):
1569        # nested tensor * nested tensor
1570        (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
1571        ref = torch.nested.nested_tensor(
1572            [t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]
1573        )
1574        nt1 *= nt2
1575        self.assertEqual(ref, nt1)
1576        # nested tensor * scalar
1577        number = 10.0
1578        scalar = torch.tensor(number).to(dtype).to(device)
1579        ref = torch.nested.nested_tensor([t * number for t in nt1.unbind()])
1580        out_number = nt1.clone()
1581        out_number *= number
1582        out_scalar = nt1.clone()
1583        out_scalar *= scalar
1584        self.assertEqual(out_number, ref)
1585        self.assertEqual(out_scalar, ref)
1586        self.assertRaisesRegex(
1587            RuntimeError,
1588            r"output with shape \[.*\] doesn't match the broadcast shape \[.*\]",
1589            lambda: scalar.mul_(nt1),
1590        )
1591        # error case: numel == 1 but dim > 0
1592        vector = torch.tensor([number]).to(dtype).to(device)
1593        self.assertRaisesRegex(
1594            RuntimeError,
1595            "Expected both self and other to be nested, but got a nested self and non-nested other",
1596            lambda: nt1.mul_(vector),
1597        )
1598        self.assertRaisesRegex(
1599            RuntimeError,
1600            "Expected both self and other to be nested, but got a non-nested self and nested other",
1601            lambda: vector.mul_(nt1),
1602        )
1603
1604    @onlyCPU
1605    @skipMeta
1606    @dtypes(torch.float)
1607    def test_nested_tensor_sum_dim(self, device, dtype):
1608        params = ((2, (1, 1)), (4, (4, 4)), (10, (3, 5, 7)))
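        # each entry is an (ntensors, max_sizes) pair; the helper below always reduces
        # over the last dim with keepdim=True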
1609
1610        def test_sum(device, dtype, ntensors, max_sizes, dim, keepdim=True):
1611            nt = random_nt(device, dtype, ntensors, max_sizes, require_non_empty=False)
1612            nt2 = nt.clone()
1613            ub2 = nt2.unbind()
1614            nt.requires_grad_(True)
1615            [t.requires_grad_(True) for t in ub2]
1616            nt_sum = nt.sum(dim=dim, keepdim=keepdim)
1617            ub2_sum = [t.sum(-1, keepdim=keepdim) for t in ub2]
1618            self.assertEqual(nt_sum, torch.nested.nested_tensor(ub2_sum))
1619
1620            # test backward
1621            # generate gradient tensor that has the same size as the output
1622            size = nt_sum._nested_tensor_size()
1623            gt2 = []
1624            for i in range(ntensors):
1625                gt2.append(torch.randn(size[i].tolist(), device=device, dtype=dtype))
1626            gt = torch.nested.nested_tensor(gt2).clone()
1627            nt_sum.backward(gt)
1628            for t2, g2 in zip(ub2_sum, gt2):
1629                t2.backward(g2)
1630            self.assertEqual(nt.grad, torch.nested.nested_tensor([t.grad for t in ub2]))
1631            return
1632
1633        for ntensors, max_sizes in params:
1634            test_sum(device, dtype, ntensors, max_sizes, len(max_sizes))
1635
1636        # Test error inputs
1637        with self.assertRaisesRegex(
1638            RuntimeError, "NestedTensor can only be reduced across the last"
1639        ):
1640            torch.nested.nested_tensor(
1641                [torch.tensor([3, 4, 5]), torch.tensor([1, 2])]
1642            ).sum(0, keepdim=True)
1643
1644        with self.assertRaisesRegex(
1645            RuntimeError, "NestedTensor only allows reduction of a single"
1646        ):
1647            torch.nested.nested_tensor(
1648                [torch.tensor([[3, 4, 5]]), torch.tensor([[1, 2]])]
1649            ).sum([0, 1], keepdim=True)
1650
1651        with self.assertRaisesRegex(
1652            RuntimeError, "NestedTensor always requires keepdim=True for now."
1653        ):
1654            torch.nested.nested_tensor(
1655                [torch.tensor([3, 4, 5]), torch.tensor([1, 2])]
1656            ).sum(-1)
1657
1658    @dtypes(torch.float, torch.float16)
1659    def test_contiguous(self, device, dtype):
1660        # Since we don't have access to the buffer in Python, what we are testing for is
1661        # harder to show directly. When we call chunk on a consistent dim of an NT
1662        # with chunk_size > 1, the resulting tensors are views of the original NT
1663        # whose numel is now less than the size of the buffer. Clone previously
1664        # created a new NT with a buffer that was the same size as the
1665        # original.
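        # Illustrative sketch of the shapes involved (comments only, not executed):
        #   base  = torch.nested.nested_tensor([torch.randn(2, 20), torch.randn(4, 20)])
        #   piece = base.chunk(5, dim=-1)[0]        # non-contiguous view into base's buffer
        #   piece.is_contiguous()                   # False
        #   piece.contiguous().is_contiguous()      # True; should only need a piece-sized buffer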
1666        nt_contiguous = torch.nested.nested_tensor(
1667            [
1668                torch.randn(2, 20, device=device, dtype=dtype),
1669                torch.randn(4, 20, device=device, dtype=dtype),
1670            ]
1671        )
1672        # Split up the last dimension which has a consistent size of 20 into 5 chunks
1673        chunks = nt_contiguous.chunk(5, dim=-1)
1674
1675        # Check that the chunks become contiguous after calling contiguous()
1676        for chunk in chunks:
1677            self.assertFalse(chunk.is_contiguous())
1678            self.assertTrue(chunk.contiguous().is_contiguous())
1679
1680    @dtypes(torch.float, torch.float16)
1681    @skipMeta
1682    def test_clone(self, device, dtype):
1683        nt1 = random_nt(device, dtype, 4, (4, 4), (1, 1))
1684        nt2 = nt1.clone()
1685        # Verify the values match
1686        self.assertEqual(nt1, nt2)
1687        # Verify modifying nt2 doesn't affect nt1
1688        nt2.mul_(nt1)
1689        ub1 = nt1.unbind()
1690        ub2 = nt2.unbind()
1691        for i in range(len(ub1)):
1692            self.assertNotEqual(ub1[i], ub2[i])
1693
1694        nt1.clone(memory_format=torch.preserve_format)
1695        msg = "Nested tensor clone supports Preserve and Contiguous memory formats, called clone with memory format: ChannelsLast"
1696        with self.assertRaisesRegex(RuntimeError, msg):
1697            nt1.clone(memory_format=torch.channels_last)
1698
1699    # cannot test torch.float16 because: RuntimeError: "bernoulli_scalar_cpu_" not implemented for 'Half'
1700    @decorateIf(xfailIfTorchDynamo, lambda params: params["layout"] == torch.jagged)
1701    @dtypes(torch.float, torch.double)
1702    @parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name)
1703    def test_dropout(self, device, dtype, layout):
1704        # edge case: empty nested tensor
1705        # TODO: support empty NT in jagged layout
1706        if layout == torch.strided:
1707            nt0 = torch.nested.nested_tensor([], layout=layout)
1708            y = torch.nn.functional.dropout(nt0, 0.5)
1709            self.assertEqual(nt0, y)
1710        # normal nested tensor
1711        ntensors = 4
1712        if layout == torch.jagged:
1713            nt = random_nt(device, dtype, ntensors, (4, 4), (0, 3), layout=layout)
1714        else:
1715            nt = random_nt(device, dtype, ntensors, (4, 4), layout=layout)
1716        # edge case: invalid dropout
1717        self.assertRaises(ValueError, lambda: torch.nn.Dropout(-0.1))
1718        self.assertRaises(ValueError, lambda: torch.nn.Dropout(1.1))
1719        self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, -0.1))
1720        self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, 1.1))
1721        # edge case: no dropout
1722        dropouter = torch.nn.Dropout(0.0)
1723        y0 = dropouter(nt)
1724        y1 = torch.nn.functional.dropout(nt, 0.0)
1725        self.assertEqual(nt, y0)
1726        self.assertEqual(nt, y1)
1727        # edge case: all dropout
1728        dropouter = torch.nn.Dropout(1.0)
1729        y0 = dropouter(nt)
1730        y1 = torch.nn.functional.dropout(nt, 1.0)
1731        nt0 = torch.zeros_like(nt)
1732        self.assertEqual(nt0, y0)
1733        self.assertEqual(nt0, y1)
1734        # normal case: normal dropout
1735        p = 0.2
1736        y = torch.nn.functional.dropout(nt, p)
1737        expect = nt.clone()
1738        if layout == torch.jagged:
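            # reconstruct the expected output from y itself: zeros mark dropped entries,
            # and surviving entries should equal the original values scaled by 1 / (1 - p)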
1739            expect = torch.where(y == 0.0, y, nt)
1740            expect /= 1.0 - p
1741            self.assertEqual(y, expect)
1742        else:
1743            expect = nt.clone()
1744            for i in range(ntensors):
1745                actual_tensor = y[i].view(-1)
1746                expect_tensor = expect[i].view(-1)
1747                for j in range(actual_tensor.shape[0]):
1748                    if actual_tensor[j].item() == 0.0:
1749                        expect_tensor[j] = 0.0
1750                    else:
1751                        expect_tensor[j] /= 1.0 - p
1752            self.assertEqual(y, expect)
1753        with freeze_rng_state():
1754            dropouter = torch.nn.Dropout(p)
1755            y0 = dropouter(nt)
1756        with freeze_rng_state():
1757            y1 = torch.nn.functional.dropout(nt, p)
1758        self.assertEqual(y0, y1)
1759
1760    @dtypes(torch.float, torch.double)
1761    def test_dropout_noncontiguous(self, device, dtype):
1762        ntensors = 4
1763        nt0 = random_nt(device, dtype, ntensors, (4, 4))
1764        nt1 = nt0.transpose(-1, -2)
1765        p = 0.3
1766        with freeze_rng_state():
1767            dropouter = torch.nn.Dropout(p)
1768            y0 = dropouter(nt0)
1769        with freeze_rng_state():
1770            y1 = torch.nn.functional.dropout(nt1, p).transpose(-1, -2)
1771        self.assertEqual(y0, y1)
1772
1773    # cannot test torch.float16 because: RuntimeError: "softmax_kernel_impl" not implemented for 'Half'
1774    @dtypes(torch.float, torch.double)
1775    def test_softmax(self, device, dtype):
1776        # normal nested tensor
1777        ntensors = 4
1778        nt = random_nt(device, dtype, ntensors, (4, 4))
1779        # error case: softmax across nested dimension
1780        self.assertRaisesRegex(
1781            RuntimeError,
1782            "Cannot apply softmax across nested dimension 0",
1783            lambda: torch.nn.functional.softmax(nt, 0),
1784        )
1785        self.assertRaisesRegex(
1786            RuntimeError,
1787            "Cannot apply softmax across nested dimension 0",
1788            lambda: torch.nn.functional.softmax(nt, -3),
1789        )
1790        # error case: dimension out of range
1791        self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt, 3))
1792        self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt, -4))
1793        # normal case: should be equivalent to padding with -inf
1794        softmaxer = torch.nn.Softmax(1)
1795        y0 = softmaxer(nt)
1796        y1 = torch.nn.functional.softmax(nt, 1)
1797        self.assertEqual(y0, y1)
1798        pt = torch.nested.to_padded_tensor(nt, float("-inf"))
1799        # if an entire slice is padded, then softmax will return 0.0 / 0.0 = nan
1800        # however, physically speaking that should be 0.0
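        # e.g. a fully padded row [-inf, -inf, ...] softmaxes to [nan, nan, ...]
        # (exp(-inf) / sum(exp(-inf)) = 0 / 0), which nan_to_num_ maps back to 0.0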
1801        expect = torch.nn.functional.softmax(pt, 1).nan_to_num_(0.0)
1802        self.assertEqual(torch.nested.to_padded_tensor(y0, 0.0), expect)
1803        # edge case: empty nested tensor
1804        nt0 = torch.nested.nested_tensor([])
1805        y = torch.nn.functional.softmax(nt0, 1)
1806        self.assertEqual(nt0, y)
1807        # edge case: nesting scalars
1808        nt1 = torch.nested.nested_tensor([torch.tensor(0.0), torch.tensor(1.0)])
1809        self.assertRaises(RuntimeError, lambda: torch.nn.functional.softmax(nt1, 0))
1810        self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt1, 1))
1811
1812    @dtypes(torch.float, torch.double)
1813    @torch.inference_mode()
1814    def test_softmax_noncontiguous(self, device, dtype):
1815        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
1816            (2, 3, 6, 7), device, dtype
1817        )
1818        self.assertEqual(
1819            torch.nn.functional.softmax(nt_contiguous, -1),
1820            torch.nn.functional.softmax(nt_noncontiguous, -1),
1821        )
1822
1823    def _test_bmm(self, device, dtype):
1824        # error case: not 3D tensors
1825        nt0 = torch.nested.nested_tensor([], device=device, dtype=dtype)
1826        nt1 = torch.nested.nested_tensor(
1827            [torch.randn(2), torch.randn(3)], device=device, dtype=dtype
1828        )
1829        nt2 = torch.nested.nested_tensor(
1830            [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
1831        )
1832        self.assertRaisesRegex(
1833            RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt0)
1834        )
1835        self.assertRaisesRegex(
1836            RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt1)
1837        )
1838        self.assertRaisesRegex(
1839            RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt2)
1840        )
1841        self.assertRaisesRegex(
1842            RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt0)
1843        )
1844        self.assertRaisesRegex(
1845            RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt1)
1846        )
1847        self.assertRaisesRegex(
1848            RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt2)
1849        )
1850        self.assertRaisesRegex(
1851            RuntimeError, "batch2 must be a 3D tensor", lambda: nt2.bmm(nt0)
1852        )
1853        self.assertRaisesRegex(
1854            RuntimeError, "batch2 must be a 3D tensor", lambda: nt2.bmm(nt1)
1855        )
1856        # error case: incompatible batch size
1857        nt0 = torch.nested.nested_tensor(
1858            [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
1859        )
1860        nt1 = torch.nested.nested_tensor(
1861            [torch.randn((4, 6)), torch.randn((4, 5)), torch.randn((4, 7))],
1862            device=device,
1863            dtype=dtype,
1864        )
1865        self.assertRaisesRegex(
1866            RuntimeError,
1867            "Expected size for the 1st dimension of batch2 tensor to be: 2 but got: 3.",
1868            lambda: nt0.bmm(nt1),
1869        )
1870        self.assertRaisesRegex(
1871            RuntimeError,
1872            "Expected size for the 1st dimension of batch2 tensor to be: 3 but got: 2.",
1873            lambda: nt1.bmm(nt0),
1874        )
1875        # error case: underlying matrices cannot be multiplied
1876        nt0 = torch.nested.nested_tensor(
1877            [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
1878        )
1879        self.assertRaisesRegex(
1880            RuntimeError,
1881            r"0-th nested matrices in batch cannot be multiplied \(2x4 and 2x4\)",
1882            lambda: nt0.bmm(nt0),
1883        )
1884        # normal nested tensor
1885        nt0 = torch.nested.nested_tensor(
1886            [torch.randn((2, 4)), torch.randn((3, 7))], device=device, dtype=dtype
1887        )
1888        nt1 = torch.nested.nested_tensor(
1889            [torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype
1890        )
1891        actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
1892        expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(
1893            torch.nested.to_padded_tensor(nt1, 0.0)
1894        )
1895        if dtype == torch.float16:
1896            self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
1897        else:
1898            self.assertEqual(actual, expect)
1899
1900        # nested tensor bmm normal tensor
1901        nt0 = torch.nested.nested_tensor(
1902            [torch.randn((2, 7)), torch.randn((3, 7))], device=device, dtype=dtype
1903        )
1904        nt1 = torch.rand(2, 7, 5, dtype=dtype, device=device)
1905        actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
1906        expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(nt1)
1907        if dtype == torch.float16:
1908            self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
1909        else:
1910            self.assertEqual(actual, expect)
1911
1912        # nested tensor bmm normal tensor with non-contiguous view
1913        nt1 = torch.rand(2, 5, 7, dtype=dtype, device=device)
1914        nt1 = nt1.transpose(1, 2)
1915        actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
1916        expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(nt1)
1917        if dtype == torch.float16:
1918            self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
1919        else:
1920            self.assertEqual(actual, expect)
1921
1922        # normal tensor bmm nested tensor
1923        nt0 = torch.rand(2, 5, 7, dtype=dtype, device=device)
1924        nt1 = torch.nested.nested_tensor(
1925            [torch.randn((7, 6)), torch.randn((7, 5))], device=device, dtype=dtype
1926        )
1927        actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
1928        expect = nt0.bmm(torch.nested.to_padded_tensor(nt1, 0.0))
1929        if dtype == torch.float16:
1930            self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
1931        else:
1932            self.assertEqual(actual, expect)
1933
1934        # test tensorcore path
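        # (inner dims are multiples of 8, the alignment half-precision GEMMs typically
        # need to dispatch to tensor core kernels on CUDA)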
1935        nt0 = torch.nested.nested_tensor(
1936            [torch.randn((2, 8)), torch.randn((3, 16))], device=device, dtype=dtype
1937        )
1938        nt1 = torch.nested.nested_tensor(
1939            [torch.randn((8, 8)), torch.randn((16, 8))], device=device, dtype=dtype
1940        )
1941        actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
1942        expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(
1943            torch.nested.to_padded_tensor(nt1, 0.0)
1944        )
1945        if dtype == torch.float16:
1946            self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
1947        else:
1948            self.assertEqual(actual, expect)
1949
1950    @onlyCUDA
1951    @dtypes(torch.float, torch.double, torch.float16)
1952    def test_bmm_cuda(self, device, dtype):
1953        self._test_bmm(device, dtype)
1954
1955    @onlyCPU
1956    # cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
1957    @dtypes(torch.float, torch.double)
1958    def test_bmm_cpu(self, device, dtype):
1959        self._test_bmm(device, dtype)
1960
1961    # cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
1962    @dtypes(torch.float, torch.double)
1963    def test_bmm_noncontiguous(self, device, dtype):
1964        nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair(
1965            (2, 3), device, dtype
1966        )
1967        nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair(
1968            (6, 7), device, dtype
1969        )
1970        self.assertEqual(
1971            nt0_contiguous.transpose(-1, -2).bmm(nt1_contiguous),
1972            nt0_noncontiguous.transpose(-1, -2).bmm(nt1_noncontiguous),
1973        )
1974
1975    @dtypes(torch.float, torch.double)
1976    def test_matmul_with_bmm_path(self, device, dtype):
1977        def unbind_rebind_matmul(nt1, nt2):
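            # reference implementation: matmul each pair of constituent tensors and
            # re-nest the results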
1978            t1s = nt1.unbind()
1979            t2s = nt2.unbind()
1980            out_ts = [t1.matmul(t2) for t1, t2 in zip(t1s, t2s)]
1981            return torch.nested.nested_tensor(out_ts)
1982
1983        # [N, n_head, *, head_dim], [N, n_head, head_dim, *]
1984        Ns = [1, 2, 5]
1985        n_heads = np.random.randint(2, 5)
1986        head_dim = 3
1987        t1s = []
1988        t2s = []
1989        for N in Ns:
1990            for _ in range(N):
1991                seq_len1 = np.random.randint(2, 5)
1992                seq_len2 = np.random.randint(2, 5)
1993                t1s.append(torch.randn(n_heads, seq_len1, head_dim))
1994                t2s.append(torch.randn(n_heads, head_dim, seq_len2))
1995            nt1 = torch.nested.nested_tensor(t1s, device=device, dtype=dtype)
1996            nt2 = torch.nested.nested_tensor(t2s, device=device, dtype=dtype)
1997            self.assertEqual(torch.matmul(nt1, nt2), unbind_rebind_matmul(nt1, nt2))
1998
1999        # test with noncontiguous
2000        t3s = []
2001        t4s = []
2002        for _ in range(N):
2003            seq_len = np.random.randint(2, 5)
2004            t3s.append(torch.randn(seq_len, n_heads, head_dim))
2005            t4s.append(torch.randn(seq_len, n_heads, head_dim))
2006        nt3 = torch.nested.nested_tensor(t3s, device=device, dtype=dtype).transpose(
2007            1, 2
2008        )
2009        nt4 = (
2010            torch.nested.nested_tensor(t4s, device=device, dtype=dtype)
2011            .transpose(1, 2)
2012            .transpose(2, 3)
2013        )
2014        self.assertEqual(torch.matmul(nt3, nt4), unbind_rebind_matmul(nt3, nt4))
2015
2016    # cannot test torch.float16 because: RuntimeError: "bmm" not implemented for 'Half'
2017    @dtypes(torch.float, torch.double)
2018    def test_matmul(self, device, dtype):
2019        # error case: one is nested but the other is not
2020        nt = torch.nested.nested_tensor(
2021            [torch.randn(2), torch.randn(3)], device=device, dtype=dtype
2022        )
2023        t = torch.randn(4, device=device, dtype=dtype)
2024        self.assertRaisesRegex(
2025            RuntimeError,
2026            "Expected both to be nested, but got a nested self and non-nested other",
2027            lambda: torch.matmul(nt, t),
2028        )
2029        self.assertRaisesRegex(
2030            RuntimeError,
2031            "Expected both to be nested, but got a non-nested self and nested other",
2032            lambda: torch.matmul(t, nt),
2033        )
2034        # error case: not 3+D tensors
2035        nt0 = torch.nested.nested_tensor([], device=device, dtype=dtype)
2036        nt1 = torch.nested.nested_tensor(
2037            [torch.randn(2), torch.randn(3)], device=device, dtype=dtype
2038        )
2039        nt2 = torch.nested.nested_tensor(
2040            [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
2041        )
2042        self.assertRaisesRegex(
2043            RuntimeError,
2044            r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
2045            lambda: torch.matmul(nt0, nt0),
2046        )
2047        self.assertRaisesRegex(
2048            RuntimeError,
2049            r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
2050            lambda: torch.matmul(nt0, nt1),
2051        )
2052        self.assertRaisesRegex(
2053            RuntimeError,
2054            r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
2055            lambda: torch.matmul(nt0, nt2),
2056        )
2057        self.assertRaisesRegex(
2058            RuntimeError,
2059            r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
2060            lambda: torch.matmul(nt1, nt0),
2061        )
2062        self.assertRaisesRegex(
2063            RuntimeError,
2064            r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
2065            lambda: torch.matmul(nt1, nt1),
2066        )
2067        self.assertRaisesRegex(
2068            RuntimeError,
2069            r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
2070            lambda: torch.matmul(nt1, nt2),
2071        )
2072        self.assertRaisesRegex(
2073            RuntimeError,
2074            r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 2nd input has rank: [0-9]+",
2075            lambda: torch.matmul(nt2, nt0),
2076        )
2077        self.assertRaisesRegex(
2078            RuntimeError,
2079            r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 2nd input has rank: [0-9]+",
2080            lambda: torch.matmul(nt2, nt1),
2081        )
2082        # error case: incompatible batch size
2083        nt0 = torch.nested.nested_tensor(
2084            [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
2085        )
2086        nt1 = torch.nested.nested_tensor(
2087            [torch.randn((4, 6)), torch.randn((4, 5)), torch.randn((4, 7))],
2088            device=device,
2089            dtype=dtype,
2090        )
2091        self.assertRaisesRegex(
2092            RuntimeError,
2093            r"matmul: Expected size for the 1st dimension of 2nd input tensor to be: [0-9]+ but got: [0-9]+.",
2094            lambda: torch.matmul(nt0, nt1),
2095        )
2096        self.assertRaisesRegex(
2097            RuntimeError,
2098            r"matmul: Expected size for the 1st dimension of 2nd input tensor to be: [0-9]+ but got: [0-9]+.",
2099            lambda: torch.matmul(nt1, nt0),
2100        )
2101        # error case: incompatible batch sizes that should not broadcast
2102        nt0 = torch.nested.nested_tensor(
2103            [torch.randn((2, 2, 4)), torch.randn((2, 3, 4))], device=device, dtype=dtype
2104        )
2105        nt1 = torch.nested.nested_tensor(
2106            [torch.randn((3, 4, 6)), torch.randn((3, 4, 5))], device=device, dtype=dtype
2107        )
2108        self.assertRaisesRegex(
2109            RuntimeError,
2110            "matmul(): For nested tensors, batch dimensions must have the same sizes,",
2111            lambda: torch.matmul(nt0, nt1),
2112        )
2113        # error case: incompatible batch sizes that should technically broadcast
2114        nt0 = torch.nested.nested_tensor(
2115            [torch.randn((2, 2, 4)), torch.randn((1, 3, 4))], device=device, dtype=dtype
2116        )
2117        nt1 = torch.nested.nested_tensor(
2118            [torch.randn((1, 4, 6)), torch.randn((3, 4, 5))], device=device, dtype=dtype
2119        )
2120        self.assertRaisesRegex(
2121            RuntimeError,
2122            "matmul(): For nested tensors, batch dimensions must have the same sizes,",
2123            lambda: torch.matmul(nt0, nt1),
2124        )
2125        # error case: underlying matrices cannot be multiplied
2126        nt0 = torch.nested.nested_tensor(
2127            [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
2128        )
2129        self.assertRaisesRegex(
2130            RuntimeError,
2131            "matmul(): Nested tensors cannot be matrix multiplied",
2132            lambda: torch.matmul(nt0, nt0),
2133        )
2134        # normal nested tensor: 3D
2135        nt0 = torch.nested.nested_tensor(
2136            [torch.randn((2, 4)), torch.randn((3, 7))], device=device, dtype=dtype
2137        )
2138        nt1 = torch.nested.nested_tensor(
2139            [torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype
2140        )
2141        actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0)
2142        expect = torch.matmul(
2143            torch.nested.to_padded_tensor(nt0, 0.0),
2144            torch.nested.to_padded_tensor(nt1, 0.0),
2145        )
2146        self.assertEqual(actual, expect)
2147        # normal nested tensor: 4D (with testing for batch_size=1)
2148        nt0 = torch.nested.nested_tensor(
2149            [torch.randn((1, 2, 4)), torch.randn((8, 3, 7))], device=device, dtype=dtype
2150        )
2151        nt1 = torch.nested.nested_tensor(
2152            [torch.randn((1, 4, 6)), torch.randn((8, 7, 5))], device=device, dtype=dtype
2153        )
2154        actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0)
2155        expect = torch.matmul(
2156            torch.nested.to_padded_tensor(nt0, 0.0),
2157            torch.nested.to_padded_tensor(nt1, 0.0),
2158        )
2159        self.assertEqual(actual, expect)
2160        # normal nested tensor: 5D
2161        nt0 = torch.nested.nested_tensor(
2162            [torch.randn((8, 9, 2, 4)), torch.randn((8, 9, 3, 7))],
2163            device=device,
2164            dtype=dtype,
2165        )
2166        nt1 = torch.nested.nested_tensor(
2167            [torch.randn((8, 9, 4, 6)), torch.randn((8, 9, 7, 5))],
2168            device=device,
2169            dtype=dtype,
2170        )
2171        actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0)
2172        expect = torch.matmul(
2173            torch.nested.to_padded_tensor(nt0, 0.0),
2174            torch.nested.to_padded_tensor(nt1, 0.0),
2175        )
2176        self.assertEqual(actual, expect)
2177
2178    # only supported on CUDA for now
2179    @dtypes(torch.float, torch.double)
2180    def test_matmul_nt_with_broadcasted_t(self, device, dtype):
2181        # NT (B, *, C, D) with T (D, E) broadcasting case
2182        nt = random_nt_from_dims([3, None, 4, 5], device=device, dtype=dtype)
2183        t = torch.randn(5, 6, device=device, dtype=dtype)
2184        output = torch.matmul(nt, t)
2185
2186        # should be equivalent to matmul-ing each component with the dense tensor
2187        self.assertEqual(nt.size(0), output.size(0))
2188        for component, out_component in zip(nt, output):
2189            self.assertEqual(out_component, torch.matmul(component, t))
2190
2191    # cannot test torch.float16 because: RuntimeError: "bmm" not implemented for 'Half'
2192    @dtypes(torch.float, torch.double)
2193    def test_matmul_noncontiguous(self, device, dtype):
2194        nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair(
2195            (2, 3), device, dtype
2196        )
2197        nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair(
2198            (6, 7), device, dtype
2199        )
2200        self.assertEqual(
2201            torch.matmul(nt0_contiguous.transpose(-1, -2), nt1_contiguous),
2202            torch.matmul(nt0_noncontiguous.transpose(-1, -2), nt1_noncontiguous),
2203        )
2204
2205    @dtypes(torch.float, torch.double)
2206    def test_linear(self, device, dtype):
2207        a = torch.randn(1, 2, device=device, dtype=dtype)
2208        b = torch.randn(2, 2, device=device, dtype=dtype)
2209        c = torch.randn(3, 2, device=device, dtype=dtype)
2210        nt = torch.nested.nested_tensor([a, b, c])
2211
2212        weight = torch.randn(2, 2, device=device, dtype=dtype)
2213        bias = torch.randn(2, device=device, dtype=dtype)
2214        # success case
2215        torch.functional.F.linear(nt, weight, bias)
2216
2217        # invalid nested tensor dimension
2218        msg = r"Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 2. Dense tensor dim: 2"
2219        nt1 = torch.nested.nested_tensor(
2220            [
2221                torch.randn(1, device=device, dtype=dtype),
2222                torch.randn(2, device=device, dtype=dtype),
2223            ]
2224        )
2225        with self.assertRaisesRegex(RuntimeError, msg):
2226            torch.functional.F.linear(nt1, weight, bias)
2227
2228        # invalid weight shape
2229        msg = r"Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 3. Dense tensor dim: 3"
2230        weight1 = torch.randn(2, 2, 3, device=device, dtype=dtype)
2231        with self.assertRaisesRegex(RuntimeError, msg):
2232            torch.functional.F.linear(nt, weight1, bias)
2233
2234        # inconsistent last dim of nested tensor
2235        msg = r"Expected all tensors in nested tensor to have the same trailing dimension, instead last dimension equals:"
2236        nt2 = torch.nested.nested_tensor(
2237            [
2238                torch.randn(1, 2, device=device, dtype=dtype),
2239                torch.randn(2, 3, device=device, dtype=dtype),
2240            ]
2241        )
2242        with self.assertRaisesRegex(RuntimeError, msg):
2243            torch.functional.F.linear(nt2, weight, bias)
2244
2245        # Mismatch of nested tensor last dim and weight dimension
2246        weight2 = torch.randn(2, 4, device=device, dtype=dtype)
2247        msg = (
2248            r"Shape mismatch for NestedTensor Linear: Expected input's \(a nested tensor\) 'last_dim'"
2249            r" to equal 'weight.size\(1\), but got: last_dim = 2, and weight.size\(1\) = 4"
2250        )
2251        with self.assertRaisesRegex(RuntimeError, msg):
2252            torch.functional.F.linear(nt, weight2, bias)
2253
2254        # Nested tensor input and nested weight
2255        nt_weight = nt.clone()
2256        msg = r"Linear does not support nested weight when input is a nested tensor."
2257        with self.assertRaisesRegex(RuntimeError, msg):
2258            torch.functional.F.linear(nt, nt_weight, bias)
2259
2260    # TODO: test noncontiguous linear
2261    # For now this tests the error message of linear
2262    # since linear does not support noncontiguous buffer yet
2263    @dtypes(torch.float, torch.double)
2264    def test_linear_noncontiguous(self, device, dtype):
2265        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
2266            (2, 3, 6, 7), device, dtype
2267        )
2268        weight = torch.randn((8, 5), device=device, dtype=dtype)
2269        self.assertRaisesRegex(
2270            RuntimeError,
2271            r"for now linear only supports contiguous nested tensor",
2272            lambda: torch.nn.functional.linear(nt_noncontiguous, weight),
2273        )
2274
2275    @dtypes(torch.float, torch.float16, torch.double)
2276    def test_to_padded_tensor_zero_numel_errors(self, device, dtype):
2277        ts = [torch.ones(1, 0), torch.ones(0, 0)]
2278        nt = torch.nested.nested_tensor(
2279            ts, device=device, dtype=dtype, layout=torch.strided
2280        )
2281        self.assertRaisesRegex(
2282            RuntimeError,
2283            r"at least one constituent tensor should have non-zero numel",
2284            lambda: torch.nested.to_padded_tensor(nt, 0.0),
2285        )
2286
2287    @dtypes(torch.float, torch.float16, torch.double)
2288    def test_transpose(self, device, dtype):
2289        nt = random_nt(device, dtype, 4, (4, 4))
2290        # error case: transpose nested dimension
2291        self.assertRaisesRegex(
2292            RuntimeError,
2293            "Nested tensor dimension 0 cannot be transposed",
2294            lambda: nt.transpose(0, 1),
2295        )
2296        self.assertRaisesRegex(
2297            RuntimeError,
2298            "Nested tensor dimension 0 cannot be transposed",
2299            lambda: nt.transpose(1, -3),
2300        )
2301        # error case: dimension out of range
2302        self.assertRaises(IndexError, lambda: nt.transpose(1, 3))
2303        self.assertRaises(IndexError, lambda: nt.transpose(-4, -1))
2304        # normal case
2305        ntT = nt.transpose(-1, -2)
2306        ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
2307        pt = torch.nested.to_padded_tensor(nt, 0.0)
2308        ptT = pt.transpose(-1, -2)
2309        self.assertEqual(ptT, ptT_from_ntT)
2310
2311    @dtypes(torch.float, torch.float16, torch.double)
2312    def test_squeeze_unsqueeze(self, device, dtype):
2313        a = torch.arange(6).reshape(2, 3)
2314        b = torch.arange(15).reshape(5, 3)
2315        nt = torch.nested.nested_tensor([a, b], device=device, dtype=dtype)
2316        # error case: squeeze no dimension
2317        self.assertRaisesRegex(
2318            RuntimeError,
2319            "For nested tensors, squeeze without the dim argument",
2320            lambda: nt.squeeze(),
2321        )
2322        # error case: squeeze nested dimension
2323        self.assertRaisesRegex(
2324            RuntimeError,
2325            "For nested tensors, squeezing dimension 0",
2326            lambda: nt.squeeze(0),
2327        )
2328        # error case: dimension out of range
2329        self.assertRaises(IndexError, lambda: nt.squeeze(3))
2330        # error case: squeeze nested tensor of singleton tensors
2331        c = torch.ones(1)
2332        nt_singleton = torch.nested.nested_tensor([c, c], device=device, dtype=dtype)
2333        self.assertRaisesRegex(
2334            RuntimeError,
2335            "For nested tensors, squeezing a nested tensor of singleton",
2336            lambda: nt_singleton.squeeze(1),
2337        )
2338
2339        # squeezing a dim which does not have size 1 should be a no-op
2340        nt2 = nt.squeeze(-1)
2341        self.assertEqual(nt, nt2)
2342
2343        # test cases that should work
2344        nt_sizes = nt._nested_tensor_size()
2345        nt_strides = nt._nested_tensor_strides()
2346        for i in range(-2, 4):
2347            if i == 0:
2348                # cannot unsqueeze batch dim
2349                continue
2350            nt_unsqueezed = nt.unsqueeze(i)
2351            # negative dim will correspond to unsqueeze() applied at dim = dim + nt.dim() + 1
2352            wrapped_i = i + nt.dim() + 1 if i < 0 else i
2353            # the column index into the nt size tensor requires subtracting 1 to ignore the batch dim
2354            size_idx = wrapped_i - 1
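            # e.g. for this 3-dim nt, i == -1 wraps to 3 (the new trailing dim), so
            # size_idx == 2 selects the last column of the per-component size tensor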
2355            self.assertEqual(
2356                nt_unsqueezed._nested_tensor_size()[:, size_idx],
2357                torch.ones(2, dtype=torch.long),
2358            )
2359            unsqueezed_stride = nt_unsqueezed._nested_tensor_strides()[:, size_idx]
2360            if i == nt.ndim or i == -1:
2361                self.assertEqual(unsqueezed_stride, torch.ones(2, dtype=torch.long))
2362            else:
2363                stride_col_after = nt_strides[:, size_idx]
2364                size_col_after = nt_sizes[:, size_idx]
2365                self.assertEqual(unsqueezed_stride, stride_col_after * size_col_after)
2366            nt_squeezed = nt_unsqueezed.squeeze(i)
2367            self.assertEqual(nt_squeezed, nt)
2368            self.assertEqual(nt_squeezed._nested_tensor_size(), nt_sizes)
2369            self.assertEqual(nt_squeezed._nested_tensor_strides(), nt_strides)
2370
2371    @dtypes(torch.float, torch.float16, torch.double)
2372    def test_transpose_inference_mode_interaction(self, device, dtype):
2373        nt = random_nt(device, dtype, 4, (4, 4))
2374        # Construct in default mode and transpose while in inference mode
2375        with torch.inference_mode():
2376            ntT = nt.transpose(-1, -2)
2377            ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
2378            pt = torch.nested.to_padded_tensor(nt, 0.0)
2379            ptT = pt.transpose(-1, -2)
2380            self.assertEqual(ptT, ptT_from_ntT)
2381
2382        # Construct and transpose while in inference mode
2383        with torch.inference_mode():
2384            nt = random_nt(device, dtype, 4, (4, 4))
2385            ntT = nt.transpose(-1, -2)
2386            ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
2387            pt = torch.nested.to_padded_tensor(nt, 0.0)
2388            ptT = pt.transpose(-1, -2)
2389            self.assertEqual(ptT, ptT_from_ntT)
2390
2391    @dtypes(torch.float, torch.float16, torch.double)
2392    def test_view(self, device, dtype):
2393        nt = random_nt(device, dtype, 4, (4, 4))
2394        # error case: empty shape
2395        self.assertRaisesRegex(
2396            RuntimeError,
2397            r"shape '\[\]' is invalid for a nested tensor",
2398            lambda: nt.view(()),
2399        )
2400        # error case: empty nested tensor
2401        nt_empty = torch.nested.nested_tensor([])
2402        self.assertRaisesRegex(
2403            RuntimeError,
2404            "empty nested tensor cannot be reshaped",
2405            lambda: nt_empty.view(-1),
2406        )
2407        # error case: -1 for batch size
2408        self.assertRaisesRegex(
2409            RuntimeError,
2410            r"view: For now nested view cannot change or infer the implicit batch dimension",
2411            lambda: nt.view(-1, 2, 3),
2412        )
2413        self.assertRaisesRegex(
2414            RuntimeError,
2415            r"shape '\[.*\]' is invalid for input of size [0-9]+",
2416            lambda: nt.view(4, 2, 3),
2417        )
2418        # normal case
2419        x0 = torch.randn((2, 20), device=device, dtype=dtype)
2420        x1 = torch.randn((3, 20), device=device, dtype=dtype)
2421        nt = torch.nested.nested_tensor([x0, x1])
2422        pt = torch.nested.to_padded_tensor(nt, 0.0)
2423        # error case, trying to reshape batch dim to a legit shape
2424        self.assertRaisesRegex(
2425            RuntimeError,
2426            r"For now nested view cannot change or infer the implicit batch dimension",
2427            lambda: nt.transpose(-1, -2).view(40, -1),
2428        )
2429        # inherit only the ragged dimension
2430        # (2, 20) -> (2, 5, 4)
2431        # (3, 20) -> (3, 5, 4)
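        # the -1 at dim 1 preserves each component's ragged length, while the remaining
        # dims reshape the consistent trailing dim of 20 into 5 x 4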
2432        nt1 = nt.view(2, -1, 5, 4)
2433        # padded equivalent: (2, 3, 20) -> (2, 3, 5, 4)
2434        pt1 = pt.view(2, -1, 5, 4)
2435        self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1)
2436
2437        # more than one -1 (even for "old" dims), should fail
2438        # this attempts to do (2, (2, 3), 5, 4) -> (2, (2, 3), 5, 2, 2)
2439        # but we ban "inherit old behavior" for >1 dimension
2440        self.assertRaisesRegex(
2441            RuntimeError,
2442            r"only one dimension can be inferred",
2443            lambda: nt1.view(2, -1, -1, 2, 2),
2444        )
2445
2446    @dtypes(torch.float, torch.float16, torch.double)
2447    def test_view_inference_mode_interaction(self, device, dtype):
2448        # Construct in default mode and view while in inference mode
2449        nt = torch.nested.nested_tensor(
2450            [torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype
2451        )
2452        with torch.inference_mode():
2453            ntT = nt.view(2, -1, 4, 5)
2454            ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
2455            pt = torch.nested.to_padded_tensor(nt, 0.0)
2456            ptT = pt.view(2, -1, 4, 5)
2457            self.assertEqual(ptT, ptT_from_ntT)
2458        # Construct and view while in inference mode
2459        with torch.inference_mode():
2460            nt = torch.nested.nested_tensor(
2461                [torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype
2462            )
2463            ntT = nt.view(2, -1, 4, 5)
2464            ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
2465            pt = torch.nested.to_padded_tensor(nt, 0.0)
2466            ptT = pt.view(2, -1, 4, 5)
2467            self.assertEqual(ptT, ptT_from_ntT)
2468
2469    @dtypes(torch.float, torch.float16, torch.double)
2470    def test_reshape(self, device, dtype):
2471        nt = random_nt(device, dtype, 4, (4, 4))
2472        # error case: empty shape
2473        self.assertRaisesRegex(
2474            RuntimeError,
2475            r"shape '\[\]' is invalid for a nested tensor",
2476            lambda: nt.reshape(()),
2477        )
2478        # error case: empty nested tensor
2479        nt_empty = torch.nested.nested_tensor([])
2480        self.assertRaisesRegex(
2481            RuntimeError,
2482            "empty nested tensor cannot be reshaped",
2483            lambda: nt_empty.reshape(-1),
2484        )
2485        # error case: -1 for batch size
2486        self.assertRaisesRegex(
2487            RuntimeError,
2488            r"reshape: For now nested reshape cannot change or infer the implicit batch dimension",
2489            lambda: nt.reshape(-1, 2, 3),
2490        )
2491        self.assertRaisesRegex(
2492            RuntimeError,
2493            r"shape '\[.*\]' is invalid for input of size [0-9]+",
2494            lambda: nt.reshape(4, 2, 3),
2495        )
2496        # normal case
2497        x0 = torch.randn((2, 20), device=device, dtype=dtype)
2498        x1 = torch.randn((3, 20), device=device, dtype=dtype)
2499        nt = torch.nested.nested_tensor([x0, x1])  # (2, (2, 3), 20)
2500        pt = torch.nested.to_padded_tensor(nt, 0.0)
2501        # error case, trying to reshape batch dim to a legit shape
2502        self.assertRaisesRegex(
2503            RuntimeError,
2504            r"reshape: For now nested reshape cannot change or infer the implicit batch dimension",
2505            lambda: nt.transpose(-1, -2).reshape(40, -1),
2506        )
2507        # inherit only the ragged dimension
2508        # (2, 20) -> (2, 5, 4)
2509        # (3, 20) -> (3, 5, 4)
2510        nt1 = nt.reshape(2, -1, 5, 4)
2511        # padded equivalent: (2, 3, 20) -> (2, 3, 5, 4)
2512        pt1 = pt.reshape(2, -1, 5, 4)
2513        self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1)
2514
2515        # more than one -1 (even for "old" dims), should fail
2516        # this attempts to do (2, (2, 3), 5, 4) -> (2, (2, 3), 5, 2, 2)
2517        # but we ban "inherit old behavior" for >1 dimension
2518        self.assertRaisesRegex(
2519            RuntimeError,
2520            r"only one dimension can be inferred",
2521            lambda: nt1.reshape(2, -1, -1, 2, 2),
2522        )
2523
2524    def test_nested_masked_select(self, device):
2525        t = torch.randn([3, 3], device=device)
2526        mask = torch.tensor([False], device=device)
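        # a broadcasted all-False mask selects nothing: each of the 3 rows contributes
        # 0 elements, so the offsets stay flat at [0, 0, 0, 0]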
2527
2528        njt = torch.nested.masked_select(t, mask)
2529        self.assertEqual(njt.values(), torch.tensor([], device=device))
2530        self.assertEqual(njt.offsets(), torch.tensor([0, 0, 0, 0], device=device))
2531
2532        mask = torch.tensor([[False], [False], [True]], device=device)
2533        njt = torch.nested.masked_select(t, mask)
2534        self.assertEqual(njt.values(), t[-1], atol=0.1, rtol=0.1)
2535        self.assertEqual(njt.offsets(), torch.tensor([0, 0, 0, 3], device=device))
2536
2537        mask = torch.tensor(
2538            [[False, False, True], [True, False, True], [False, False, True]],
2539            device=device,
2540        )
2541        njt = torch.nested.masked_select(t, mask)
2542        self.assertEqual(njt.values(), t.masked_select(mask))
2543        self.assertEqual(njt.offsets(), torch.tensor([0, 1, 3, 4], device=device))
2544
2545        t = torch.randn([2, 3, 3, 1], device=device)
2546        mask = torch.tensor(
2547            [
2548                [
2549                    [[True], [False], [True]],
2550                    [[True], [False], [True]],
2551                    [[True], [False], [True]],
2552                ],
2553                [
2554                    [[False], [True], [True]],
2555                    [[False], [True], [True]],
2556                    [[True], [True], [True]],
2557                ],
2558            ],
2559            device=device,
2560        )
2561        njt = torch.nested.masked_select(t, mask)
2562        self.assertEqual(njt.values(), t.masked_select(mask))
2563        self.assertEqual(
2564            njt.offsets(),
2565            torch.tensor(
2566                [0, 1, 1, 2, 3, 3, 4, 5, 5, 6, 6, 7, 8, 8, 9, 10, 11, 12, 13],
2567                device=device,
2568            ),
2569        )
2570
2571    @dtypes(torch.float, torch.float16, torch.double)
2572    def test_narrow(self, device, dtype):
2573        nt = random_nt_from_dims([5, None, None, None], device=device, dtype=dtype)
2574
2575        # narrow on dim=0 from start to end
2576        bounds = [(0, 5), (0, 3), (1, 2), (1, 5), (2, 4)]
2577        for start, end in bounds:
2578            length = end - start
2579            narrowed = nt.narrow(dim=0, start=start, length=length)
2580            # ensure output is a view
2581            self.assertTrue(narrowed._base is nt)
2582            for nc, c in zip(narrowed.unbind(), nt.unbind()[start:end]):
2583                self.assertEqual(nc, c)
2584
2585        # dim != 0 is not supported
2586        for dim in range(1, nt.dim()):
2587            with self.assertRaisesRegex(
2588                RuntimeError, "only dim=0 supported for nested tensors"
2589            ):
2590                nt.narrow(dim=dim, start=0, length=1)
2591
2592        # error case: non-contiguous NT
2593        _, nt_noncont = random_nt_noncontiguous_pair((2, 3, 4))
2594        with self.assertRaisesRegex(
2595            RuntimeError, "only contiguous nested tensors supported"
2596        ):
2597            nt_noncont.narrow(dim=0, start=0, length=1)
2598
2599    @parametrize("input_dim", [3, 4])
2600    def test_scaled_dot_product_attention(self, device, input_dim):
2601        def rand_tensor(*shape):
2602            return torch.randn(shape, device=device)
2603
2604        E = 8
2605        if input_dim == 3:
2606            # Shape: (N, L, E); ragged L
2607            query = torch.nested.nested_tensor(
2608                [rand_tensor(2, E), rand_tensor(3, E), rand_tensor(4, E)]
2609            )
2610
2611            # Shape: (N, S, E); ragged S
2612            key = torch.nested.nested_tensor(
2613                [rand_tensor(3, E), rand_tensor(4, E), rand_tensor(5, E)]
2614            )
2615            value = torch.nested.nested_tensor(
2616                [rand_tensor(3, E), rand_tensor(4, E), rand_tensor(5, E)]
2617            )
2618        elif input_dim == 4:
2619            # In the 4D case both L and S are ragged
2620            # Shape: (N, N', L, E); ragged N' and L
2621            query = torch.nested.nested_tensor(
2622                [rand_tensor(2, 2, E), rand_tensor(3, 3, E), rand_tensor(4, 4, E)]
2623            )
2624            # Shape: (N, N', S, E); ragged N' and S
2625            key = torch.nested.nested_tensor(
2626                [rand_tensor(2, 3, E), rand_tensor(3, 4, E), rand_tensor(4, 5, E)]
2627            )
2628            value = torch.nested.nested_tensor(
2629                [rand_tensor(2, 3, E), rand_tensor(3, 4, E), rand_tensor(4, 5, E)]
2630            )
2631        else:
2632            self.fail(f"Invalid input_dim {input_dim} encountered in SDP test")
2633
2634        def rand_mask(size):
2635            return torch.randint(0, 2, size=size, dtype=torch.bool, device=device)
2636
2637        # Shape: (N, L, S); ragged L and S matching above
2638        attn_mask = torch.nested.nested_tensor(
2639            [rand_mask((2, 3)), rand_mask((3, 4)), rand_mask((4, 5))]
2640        )
2641
2642        dropout_p = 0.0  # no dropout for reproducibility
2643
2644        # Success case: no attn_mask set and is_causal=False.
2645        actual = torch.nn.functional.scaled_dot_product_attention(
2646            query, key, value, attn_mask=None, is_causal=False, dropout_p=dropout_p
2647        )
2648
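        # Reference: unbind the nested inputs, run SDPA on each component with an
        # added batch dim of 1, and re-nest the results; nested SDPA should match
        # this component-wise computation.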
2649        expected_outputs = []
2650        for q, k, v in zip(query.unbind(), key.unbind(), value.unbind()):
2651            output = torch.nn.functional.scaled_dot_product_attention(
2652                q.unsqueeze(0),
2653                k.unsqueeze(0),
2654                v.unsqueeze(0),
2655                attn_mask=None,
2656                dropout_p=dropout_p,
2657            )
2658            expected_outputs.append(output.squeeze(0))
2659        expected_output_nested = torch.nested.nested_tensor(expected_outputs)
2660        self.assertEqual(actual, expected_output_nested)
2661
2662        # Error case: explicit attn_mask set.
2663        with self.assertRaisesRegex(
2664            RuntimeError, "not supported when an explicit attn_mask is set"
2665        ):
2666            torch.nn.functional.scaled_dot_product_attention(
2667                query, key, value, attn_mask=attn_mask, dropout_p=dropout_p
2668            )
2669
2670        # Error case: is_causal=True.
2671        with self.assertRaisesRegex(RuntimeError, "not supported when is_causal=True"):
2672            torch.nn.functional.scaled_dot_product_attention(
2673                query, key, value, dropout_p=dropout_p, is_causal=True
2674            )
2675
2676    @dtypes(torch.float, torch.float16, torch.double)
2677    def test_empty_like(self, device, dtype):
2678        ntensors = 4
2679        nt = random_nt(device, dtype, ntensors, (4, 4))
2680
2681        # Create empty on same device as original nested tensor
2682        nt_empty = torch.empty_like(nt)
2683        assert nt.is_same_size(nt_empty)
2684        self.assertEqual(nt.dtype, nt_empty.dtype)
2685        self.assertEqual(nt.device, nt_empty.device)
2686        self.assertEqual(nt.layout, nt_empty.layout)
2687
2688        if torch.cuda.is_available():
2689            if device == "cpu":
2690                nt_cuda = torch.empty_like(nt, device="cuda")
2691                self.assertEqual(torch.device("cuda").type, nt_cuda.device.type)
2692            else:
2693                nt_cpu = torch.empty_like(nt, device="cpu")
2694                self.assertEqual(torch.device("cpu").type, nt_cpu.device.type)
2695
2696        # Check changing dtype of empty_like nested tensor output
2697        dtype_set = {torch.float, torch.float16, torch.double}
2698        for other_dtype in dtype_set - {dtype}:
2699            nt_empty_other_dtype = torch.empty_like(nt, dtype=other_dtype)
2700            self.assertEqual(nt.dtype, dtype)
2701            self.assertEqual(nt_empty_other_dtype.dtype, other_dtype)
2702            self.assertEqual(nt.device, nt_empty_other_dtype.device)
2703            self.assertEqual(nt.layout, nt_empty_other_dtype.layout)
2704
2705        # Create tensor for autograd
2706        nt_empty_req_grad = torch.empty_like(nt, requires_grad=True)
2707        self.assertEqual(nt_empty_req_grad.requires_grad, True)
2708
2709        # Test that empty_like on a noncontiguous nested tensor does not fail
2710        nt_cont, nt_noncont = random_nt_noncontiguous_pair((2, 3, 6, 7))
2711        nt_empty = torch.empty_like(nt_cont)
2712        assert nt_cont.is_same_size(nt_empty)
2713        nt_empty_non_contig = torch.empty_like(nt_noncont)
2714        assert nt_noncont.is_same_size(nt_empty_non_contig)
2715
2716        # Test the contiguous memory format option
2717        nt_empty_contig = torch.empty_like(
2718            nt_cont, memory_format=torch.contiguous_format
2719        )
2720        assert nt_cont.is_same_size(nt_empty_contig)
2721        assert nt_empty_contig.is_contiguous()
2722
2723        nt_empty_non_contig = torch.empty_like(
2724            nt_noncont, memory_format=torch.contiguous_format
2725        )
2726        assert nt_noncont.is_same_size(nt_empty_non_contig)
2727        assert nt_empty_non_contig.is_contiguous()
2728
2729        # Test other memory formats fail
2730        self.assertRaises(
2731            RuntimeError,
2732            lambda: torch.empty_like(nt_cont, memory_format=torch.channels_last),
2733        )
2734        self.assertRaises(
2735            RuntimeError,
2736            lambda: torch.empty_like(nt_noncont, memory_format=torch.channels_last),
2737        )
2738        self.assertRaises(
2739            RuntimeError,
2740            lambda: torch.empty_like(nt_cont, memory_format=torch.channels_last_3d),
2741        )
2742        self.assertRaises(
2743            RuntimeError,
2744            lambda: torch.empty_like(nt_noncont, memory_format=torch.channels_last_3d),
2745        )
2746
2747
2748@markDynamoStrictTest
2749class TestNestedTensorAutograd(NestedTensorTestCase):
2750    # Note [Gradcheck args check_batched_grad=False]: the common_utils testing version of gradcheck
2751    # includes the default parameters used for testing ops with gradcheck. However, nested tensor
2752    # does not support the stack op, so we turn check_batched_grad off for these tests.
2753    def _create_leaf_nested_tensor_from_list(self, tensor_device, requires_grad=False):
2754        return torch.nested.nested_tensor(
2755            [torch.randn(1, 2), torch.randn(7, 8)],
2756            requires_grad=requires_grad,
2757            device=tensor_device,
2758        )
2759
2760    def _create_nested_tensor_from_list(self, tensor_device, requires_grad=False):
2761        return torch.nested.as_nested_tensor(
2762            [
2763                torch.randn(1, 2, requires_grad=requires_grad),
2764                torch.randn(7, 8, requires_grad=requires_grad),
2765            ],
2766            device=tensor_device,
2767        )
2768
2769    def _create_nested_tensor_from_mask(self, tensor_device, requires_grad=False):
2770        data = torch.randn(2, 3, 4, requires_grad=requires_grad, device=tensor_device)
2771        mask = torch.ones_like(data[:, :, 0]).bool()
2772        return torch._nested_tensor_from_mask(data, mask)
2773
2774    def test_as_nested_tensor_propagates_gradients(self, device):
2775        a = torch.arange(3, dtype=torch.float, device=device)
2776        b = torch.arange(5, dtype=torch.float, device=device)
2777        nt = torch.nested.as_nested_tensor([a, b])
2778        # tensors with requires_grad=False are leaves
2779        self.assertTrue(nt.is_leaf)
2780        self.assertTrue(not nt.requires_grad)
2781
2782        a = torch.arange(3, dtype=torch.float, requires_grad=True, device=device)
2783        b = torch.arange(5, dtype=torch.float, requires_grad=True, device=device)
2784        nt2 = torch.nested.as_nested_tensor([a, b])
2785        fake_grad = torch.nested.nested_tensor(
2786            [torch.ones_like(a), torch.zeros_like(b)], device=device
2787        )
2788        nt2.backward(fake_grad)
2789        self.assertEqual(a.grad, fake_grad[0])
2790        self.assertEqual(b.grad, fake_grad[1])
2791
2792    def test_nested_tensor_generates_leaf(self, device):
2793        a = torch.arange(3, dtype=torch.float, requires_grad=True, device=device)
2794        b = torch.arange(5, dtype=torch.float, requires_grad=True, device=device)
2795
2796        nt = torch.nested.nested_tensor([a, b], requires_grad=False)
2797        self.assertTrue(nt.is_leaf)
2798        self.assertTrue(not nt.requires_grad)
2799
2800        nt2 = torch.nested.nested_tensor([a, b], requires_grad=True)
2801        self.assertTrue(nt2.is_leaf)
2802        self.assertTrue(nt2.requires_grad)
2803
2804        fake_grad = torch.nested.nested_tensor(
2805            [torch.ones_like(a), torch.zeros_like(b)], device=device
2806        )
2807        nt2.backward(fake_grad)
2808        self.assertEqual(nt2.grad, fake_grad)
2809        self.assertEqual(a.grad, None)
2810        self.assertEqual(b.grad, None)
2811
2812    def test_set_requires_grad_from_list(self, device):
2813        nt = self._create_nested_tensor_from_list(device)
2814        nt.requires_grad_()
2815        assert nt.requires_grad
2816
2817    def test_set_requires_grad_from_mask(self, device):
2818        nt = self._create_nested_tensor_from_mask(device)
2819        nt.requires_grad_()
2820        assert nt.requires_grad
2821
2822    def test_backward_for_add_op(self, device):
2823        nt_1 = self._create_nested_tensor_from_mask(device)
2824        nt_2 = self._create_nested_tensor_from_mask(device)
2825
2826        nt_1.requires_grad_()
2827        c = nt_1 + nt_2
2828
2829        assert nt_1.requires_grad
2830        assert c.requires_grad
2831        grad_output = self._create_nested_tensor_from_mask(device)
2832        c.backward(grad_output)
2833
2834        # Grad check doesn't work with nested tensors yet.
2835        # d/dnt_1 (nt_1 + nt_2) = 1 * grad_output
2836        self.assertEqual(nt_1.grad, grad_output)
2837
2838    def test_backward_for_sub_op(self, device):
2839        nt_1 = self._create_nested_tensor_from_mask(device)
2840        nt_2 = self._create_nested_tensor_from_mask(device)
2841
2842        nt_1.requires_grad_()
2843        nt_2.requires_grad_()
2844        c = nt_1 - nt_2
2845
2846        assert nt_1.requires_grad
2847        assert nt_2.requires_grad
2848        assert c.requires_grad
2849        grad_output = self._create_nested_tensor_from_mask(device)
2850        c.backward(grad_output)
2851
2852        self.assertEqual(nt_1.grad, grad_output)
2853        self.assertEqual(nt_2.grad, -1 * grad_output)
2854
2855    def test_backward_sub_strided(self, device):
2856        a = torch.nested.nested_tensor(
2857            [torch.randn(9, 2, 4), torch.randn(12, 2, 4)],
2858            requires_grad=True,
2859            device=device,
2860        )
2861        b = torch.nested.nested_tensor(
2862            [torch.randn(9, 4, 2), torch.randn(12, 4, 2)],
2863            requires_grad=True,
2864            device=device,
2865        )
2866        c = a - b.transpose(-1, -2)
2867        grad_output = c.clone()
2868        c.backward(grad_output)
2869        self.assertEqual(a.grad, grad_output)
2870        self.assertEqual(b.grad, -1 * grad_output.transpose(-1, -2))
2871
2872    def test_backward_add_strided(self, device):
2873        a = torch.nested.nested_tensor(
2874            [torch.randn(9, 2, 4), torch.randn(12, 2, 4)],
2875            requires_grad=True,
2876            device=device,
2877        )
2878        b = torch.nested.nested_tensor(
2879            [torch.randn(9, 4, 2), torch.randn(12, 4, 2)],
2880            requires_grad=True,
2881            device=device,
2882        )
2883        c = a + b.transpose(-1, -2)
2884        grad_output = c.clone()
2885        c.backward(grad_output)
2886        self.assertEqual(a.grad, grad_output)
2887        self.assertEqual(b.grad, grad_output.transpose(-1, -2))
2888
2889    # Test Factory Functions
2890    def test_nested_tensor_to_padded_tensor(self, device):
2891        for padding_val in [0, 1]:
2892            nt = self._create_leaf_nested_tensor_from_list(
2893                tensor_device=device, requires_grad=True
2894            )
2895
2896            out = torch.nested.to_padded_tensor(nt, padding_val)
2897            grad_output = torch.ones(out.shape, device=device)
2898            out.backward(grad_output)
2899
2900            self.assertEqual(
2901                nt.grad,
2902                torch.nested.nested_tensor(
2903                    [torch.ones(1, 2), torch.ones(7, 8)], device=device
2904                ),
2905            )
2906
2907    def test_nested_tensor_from_mask_and_to_padded(self, device):
2908        N, L, D = 2, 4, 4
2909        mask = torch.ones(N, L, device=device)
2910        for i in range(1, N):
2911            end = torch.randint(1, L - 1, (1,), device=device)
2912            mask[i, end:] = 0
2913
2914        mask[0, :] = 1
2915        mask = mask.bool()
2916
2917        data = torch.randn(
2918            N, L, D, requires_grad=True, dtype=torch.float64, device=device
2919        )
2920
2921        def grad_test_func(inpt):
2922            nt = torch._nested_tensor_from_mask(inpt, mask)
2923            # This implicitly tests to_padded_tensor grads
2924            return torch.nested.to_padded_tensor(nt, 0)
2925
2926        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
2927
2928    def test_nested_tensor_from_padded(self, device):
2929        nested_size = torch.tensor([[1, 2], [2, 2]])
2930        padded_tensor = torch.randn(2, 2, 2, dtype=torch.float64, device=device)
2931        padded_tensor[0, 1, :] = 0
2932        padded_tensor.requires_grad_()
2933
2934        def grad_test_func(tensor, nested_size):
2935            nt = torch._nested_from_padded(
2936                tensor, nested_size, fuse_transform_0213=False
2937            )
2938            # This implicitly tests to_padded_tensor grads
2939            return torch.nested.to_padded_tensor(nt, 0)
2940
2941        data = (padded_tensor, nested_size)
2942        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
2943
2944    def test_nested_tensor_from_padded_fused(self, device):
2945        nested_size = torch.tensor([[1, 8], [2, 8]])
2946        padded_tensor = torch.randn(2, 2, 2, 4, dtype=torch.float64, device=device)
2947        padded_tensor[0, 1, :] = 0
2948        padded_tensor.requires_grad_()
2949
2950        def grad_test_func(tensor, nested_size):
2951            nt = torch._nested_from_padded(
2952                tensor, nested_size, fuse_transform_0213=True
2953            )
2954            # This implicitly tests to_padded_tensor grads
2955            return torch.nested.to_padded_tensor(nt, 0)
2956
2957        data = (padded_tensor, nested_size)
2958        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
2959
2960    def test_nested_tensor_from_list(self, device):
2961        a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device)
2962        b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device)
2963        c = torch.randn(10, 2, requires_grad=True, dtype=torch.float64, device=device)
2964
2965        def grad_test_func(a, b, c):
2966            c = torch.nested.as_nested_tensor([a, b, c])
2967            # This implicitly tests to_padded_tensor grads
2968            return torch.nested.to_padded_tensor(c, 0)
2969
2970        data = (a, b, c)
2971        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
2972
2973    @parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name)
2974    def test_dropout_backward(self, layout):
2975        if layout == torch.jagged:
2976            nt = torch.nested.nested_tensor(
2977                [torch.randn((2, 5)), torch.randn((3, 5))],
2978                requires_grad=True,
2979                layout=layout,
2980            )
2981        else:
2982            nt = torch.nested.nested_tensor(
2983                [torch.randn((2, 5)), torch.randn((3, 4))],
2984                requires_grad=True,
2985                layout=layout,
2986            )
2987        p = 0.2
2988        y = torch.nn.functional.dropout(nt, p)
2989        y.backward(nt.clone().detach())
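        # Why nt.grad should equal y: dropout scales kept elements by 1 / (1 - p) and
        # zeroes the same elements in forward and backward. With grad_output equal to
        # the detached input, nt.grad = nt * mask / (1 - p), which is exactly y.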
2990        self.assertEqual(nt.grad, y)
2991
2992    def test_nested_tensor_bmm_gradcheck(self, device):
2993        a = torch.randn(2, 6, requires_grad=True, dtype=torch.float64, device=device)
2994        b = torch.randn(3, 6, requires_grad=True, dtype=torch.float64, device=device)
2995        c = torch.randn(6, 4, requires_grad=True, dtype=torch.float64, device=device)
2996        d = torch.randn(6, 5, requires_grad=True, dtype=torch.float64, device=device)
2997
2998        def grad_test_func(a, b, c, d):
2999            nt0 = torch.nested.as_nested_tensor([a, b])
3000            nt1 = torch.nested.as_nested_tensor([c, d])
3001            result = nt0.bmm(nt1)
3002            return torch.nested.to_padded_tensor(result, 0.0)
3003
3004        data = (a, b, c, d)
3005        assert torch.autograd.gradcheck(grad_test_func, inputs=data)
3006
3007    def test_nested_tensor_bmm_backward(self, device):
3008        nt0 = torch.nested.nested_tensor(
3009            [torch.randn((2, 6)), torch.randn((3, 6))],
3010            requires_grad=True,
3011            device=device,
3012        )
3013        nt1 = torch.nested.nested_tensor(
3014            [torch.randn((6, 4)), torch.randn((6, 5))],
3015            requires_grad=True,
3016            device=device,
3017        )
3018        with torch.no_grad():
3019            pt0 = torch.nested.to_padded_tensor(nt0, 0.0).requires_grad_(True)
3020            pt1 = torch.nested.to_padded_tensor(nt1, 0.0).requires_grad_(True)
3021
3022        ynt = nt0.bmm(nt1)
3023        ypt = pt0.bmm(pt1)
3024        ynt.backward(ynt.clone())
3025        ypt.backward(ypt.clone())
3026
3027        self.assertEqual(torch.nested.to_padded_tensor(nt0.grad, 0.0), pt0.grad)
3028        self.assertEqual(torch.nested.to_padded_tensor(nt1.grad, 0.0), pt1.grad)
3029
3030    def test_nested_tensor_matmul_gradcheck(self, device):
3031        a = torch.randn(2, 6, requires_grad=True, dtype=torch.float64, device=device)
3032        b = torch.randn(3, 6, requires_grad=True, dtype=torch.float64, device=device)
3033        c = torch.randn(6, 4, requires_grad=True, dtype=torch.float64, device=device)
3034        d = torch.randn(6, 5, requires_grad=True, dtype=torch.float64, device=device)
3035
3036        def grad_test_func(a, b, c, d):
3037            nt0 = torch.nested.as_nested_tensor([a, b])
3038            nt1 = torch.nested.as_nested_tensor([c, d])
3039            result = torch.matmul(nt0, nt1)
3040            return torch.nested.to_padded_tensor(result, 0.0)
3041
3042        data = (a, b, c, d)
3043        assert torch.autograd.gradcheck(grad_test_func, inputs=data)
3044
3045    def test_nested_tensor_matmul_backward(self, device):
3046        nt0 = torch.nested.nested_tensor(
3047            [torch.randn((7, 2, 6)), torch.randn((7, 3, 6))],
3048            requires_grad=True,
3049            device=device,
3050        )
3051        nt1 = torch.nested.nested_tensor(
3052            [torch.randn((7, 6, 4)), torch.randn((7, 6, 5))],
3053            requires_grad=True,
3054            device=device,
3055        )
3056        with torch.no_grad():
3057            pt0 = torch.nested.to_padded_tensor(nt0, 0.0).requires_grad_(True)
3058            pt1 = torch.nested.to_padded_tensor(nt1, 0.0).requires_grad_(True)
3059
3060        ynt = torch.matmul(nt0, nt1)
3061        ypt = torch.matmul(pt0, pt1)
3062        ynt.backward(ynt.clone())
3063        ypt.backward(ypt.clone())
3064
3065        self.assertEqual(torch.nested.to_padded_tensor(nt0.grad, 0.0), pt0.grad)
3066        self.assertEqual(torch.nested.to_padded_tensor(nt1.grad, 0.0), pt1.grad)
3067
3068    def test_nested_tensor_transpose_gradcheck(self, device):
3069        a = torch.randn(2, 5, requires_grad=True, device=device)
3070        b = torch.randn(3, 4, requires_grad=True, device=device)
3071
3072        def grad_test_func(a, b):
3073            nt = torch.nested.as_nested_tensor([a, b])
3074            result = nt.transpose(-2, -1).transpose(-2, -1)
3075            return torch.nested.to_padded_tensor(result, 0.0)
3076
3077        data = (a, b)
3078        assert torch.autograd.gradcheck(grad_test_func, inputs=data, eps=1e-3)
3079
3080    def test_nested_tensor_transpose_backward(self, device):
3081        nt = torch.nested.nested_tensor(
3082            [torch.randn((2, 5)), torch.randn((3, 4))],
3083            requires_grad=True,
3084            device=device,
3085        )
3086        with torch.no_grad():
3087            pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)
3088
3089        ynt = nt.transpose(-2, -1)
3090        ypt = pt.transpose(-2, -1)
3091        ynt.backward(ynt.clone())
3092        ypt.backward(ypt.clone())
3093
3094        self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)
3095
3096    def test_nested_tensor_reshape_gradcheck(self, device):
3097        a = torch.randn(2, 6, requires_grad=True, device=device)
3098        b = torch.randn(3, 6, requires_grad=True, device=device)
3099
3100        def grad_test_func(a, b):
3101            nt = torch.nested.as_nested_tensor([a, b])
3102            result = nt.reshape(2, -1, 2, 3)
3103            return torch.nested.to_padded_tensor(result, 0.0)
3104
3105        data = (a, b)
3106        assert torch.autograd.gradcheck(grad_test_func, inputs=data, eps=1e-3)
3107
3108    def test_nested_tensor_reshape_backward(self):
3109        nt = torch.nested.nested_tensor(
3110            [torch.randn((2, 6)), torch.randn((3, 6))], requires_grad=True
3111        )
3112        with torch.no_grad():
3113            pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)
3114
3115        ynt = nt.reshape(2, -1, 2, 3)
3116        ypt = pt.reshape(2, -1, 2, 3)
3117        ynt.backward(ynt.clone())
3118        ypt.backward(ypt.clone())
3119
3120        self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)
3121
3122    def test_nested_tensor_squeeze_backward(self, device):
3123        nt = torch.nested.nested_tensor(
3124            [torch.randn((2, 6, 1)), torch.randn((3, 6, 1))],
3125            requires_grad=True,
3126            device=device,
3127        )
3128        with torch.no_grad():
3129            pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)
3130
3131        ynt = nt.squeeze(-1)
3132        ypt = pt.squeeze(-1)
3133        ynt.backward(ynt.clone())
3134        ypt.backward(ypt.clone())
3135
3136        self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)
3137
3138    def test_nested_tensor_squeeze_gradcheck(self, device):
3139        a = torch.randn(
3140            (2, 6, 1), dtype=torch.float64, requires_grad=True, device=device
3141        )
3142        b = torch.randn(
3143            (3, 6, 1), dtype=torch.float64, requires_grad=True, device=device
3144        )
3145
3146        def grad_test_func(a, b):
3147            nt = torch.nested.as_nested_tensor([a, b])
3148            result = nt.squeeze(-1)
3149            return torch.nested.to_padded_tensor(result, 0.0)
3150
3151        assert torch.autograd.gradcheck(grad_test_func, inputs=(a, b), eps=1e-3)
3152
3153    def test_nested_tensor_unsqueeze_backward(self, device):
3154        nt = torch.nested.nested_tensor(
3155            [torch.randn((2, 6)), torch.randn((3, 6))],
3156            requires_grad=True,
3157            device=device,
3158        )
3159        with torch.no_grad():
3160            pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)
3161
3162        ynt = nt.unsqueeze(2)
3163        ypt = pt.unsqueeze(2)
3164        ynt.backward(ynt.clone())
3165        ypt.backward(ypt.clone())
3166
3167        self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)
3168
3169    def test_nested_tensor_unsqueeze_gradcheck(self, device):
3170        a = torch.randn((2, 6), dtype=torch.float64, requires_grad=True, device=device)
3171        b = torch.randn((3, 6), dtype=torch.float64, requires_grad=True, device=device)
3172
3173        def grad_test_func(a, b):
3174            nt = torch.nested.as_nested_tensor([a, b])
3175            result = nt.unsqueeze(-1)
3176            return torch.nested.to_padded_tensor(result, 0.0)
3177
3178        assert torch.autograd.gradcheck(grad_test_func, inputs=(a, b), eps=1e-3)
3179
3180    def test_nested_tensor_linear(self, device):
3181        a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device)
3182        b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device)
3183        c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device)
3184
3185        weight = torch.randn(
3186            2, 2, requires_grad=True, dtype=torch.float64, device=device
3187        )
3188        bias = torch.randn(2, requires_grad=True, dtype=torch.float64, device=device)
3189
3190        def grad_test_func(a, b, c, weight, bias=None):
3191            nt = torch.nested.as_nested_tensor([a, b, c])
3192            # This implicitly tests to_padded_tensor grads
3193            d = torch.functional.F.linear(nt, weight, bias)
3194            return torch.nested.to_padded_tensor(d, 0)
3195
3196        data = (a, b, c, weight, bias)
3197        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3198
3199        # Test linear with no bias added
3200        data = (a, b, c, weight)
3201        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3202
3203    def test_nested_tensor_linear_plus_transpose(self, device):
3204        a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device)
3205        b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device)
3206        c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device)
3207
3208        weight = torch.randn(
3209            2, 2, requires_grad=True, dtype=torch.float64, device=device
3210        )
3211        bias = torch.randn(2, requires_grad=True, dtype=torch.float64, device=device)
3212
3213        def grad_test_func(a, b, c, weight, bias=None):
3214            nt = torch.nested.as_nested_tensor([a, b, c])
3215            # This implicitly tests to_padded_tensor grads
3216            d = torch.functional.F.linear(nt, weight, bias)
3217            d = d.transpose(-1, -2).contiguous()
3218            return torch.nested.to_padded_tensor(d, 0)
3219
3220        data = (a, b, c, weight, bias)
3221        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3222
3223        # Test linear with no bias added
3224        data = (a, b, c, weight)
3225        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3226
3227    def test_nested_tensor_softmax(self, device):
3228        a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device)
3229        b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device)
3230        c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device)
3231
3232        def grad_test_func(a, b, c, dim):
3233            nt = torch.nested.as_nested_tensor([a, b, c])
3234            # This implicitly tests to_padded_tensor grads
3235            d = torch.functional.F.softmax(nt, dim=dim)
3236            return torch.nested.to_padded_tensor(d, 0)
3237
3238        # softmax over last dim
3239        data = (a, b, c, -1)
3240        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3241
3242    def test_nested_tensor_linear_backward(self, device):
3243        a = torch.randn(1, 2, requires_grad=False, device=device)
3244        b = torch.randn(2, 2, requires_grad=False, device=device)
3245        c = torch.randn(3, 2, requires_grad=False, device=device)
3246
3247        weight = torch.randn(2, 2, requires_grad=True, device=device)
3248        bias = torch.randn(2, requires_grad=True, device=device)
3249        nt = torch.nested.as_nested_tensor([a, b, c], device=device)
3250
3251        out = torch.functional.F.linear(nt, weight, bias)
3252
3253        out.backward(out.clone())
3254
3255        assert weight.grad is not None
3256        assert bias.grad is not None
3257
3258        assert a.grad is None
3259        assert b.grad is None
3260        assert c.grad is None
3261
3262    def test_values_grad_with_broadcast(self, device):
3263        a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3264        b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3265        c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3266
3267        def grad_test_func(a, b, c):
3268            nt = torch.nested.as_nested_tensor([a, b, c])
3269            buffer = nt.values()
3270            return buffer.sum()
3271
3272        data = (a, b, c)
3273        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3274
3275    def test_to_buffer_series_ops_grad_with_broadcast(self, device):
3276        a = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device)
3277        b = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device)
3278        c = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device)
3279
3280        def grad_test_func(a, b, c):
3281            nt = torch.nested.as_nested_tensor([a, b, c])
3282            buffer = nt.values()
3283            buffer = buffer * 2
3284            return buffer.exp()
3285
3286        data = (a, b, c)
3287        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3288
3289    def test_unbind_flow_through(self, device):
3290        a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3291        b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3292        c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3293
3294        def grad_test_func(a, b, c):
3295            nt = torch.nested.as_nested_tensor([a, b, c])
3296            ntT = nt.transpose(-1, -2)
3297            unbound = ntT.unbind()
3298            d = unbound[0]
3299            d = torch.pow(d, 2)
3300            return d
3301
3302        data = (a, b, c)
3303        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3304
3305    def test_split_with_sizes_flow_through(self, device):
3306        a = torch.randn(2, 5, requires_grad=True, dtype=torch.float64, device=device)
3307        b = torch.randn(3, 5, requires_grad=True, dtype=torch.float64, device=device)
3308        c = torch.randn(4, 5, requires_grad=True, dtype=torch.float64, device=device)
3309
3310        def grad_test_func(a, b, c):
3311            nt = torch.nested.as_nested_tensor([a, b, c])
3312            splits = nt.split_with_sizes([2, 3], dim=-1)
3313            unbound = splits[1].unbind()
3314            d = unbound[0]
3315            d = torch.pow(d, 2)
3316            return d
3317
3318        data = (a, b, c)
3319        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3320
3321    def test_indexing_backward(self, device):
3322        x0 = torch.randn((2, 5))
3323        x1 = torch.randn((3, 4))
3324        nt = torch.nested.nested_tensor([x0, x1], device=device, requires_grad=True)
3325        self.assertEqual(nt[0], x0)
3326        self.assertEqual(nt[-1], x1)
3327        grad_x0 = torch.randn((2, 5), device=device)
3328        nt[0].backward(grad_x0)
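        # Only the indexed component (nt[0]) participates in the backward pass, so the
        # expected gradient holds grad_x0 in the first component and zeros in the second.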
3329        expected_grad = torch.nested.nested_tensor(
3330            [grad_x0, torch.zeros((3, 4), device=device)]
3331        )
3332        self.assertEqual(nt.grad, expected_grad)
3333
3334    def test_masked_fill_backward(self, device):
3335        a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3336        b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3337        c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3338
3339        def grad_test_func(a, b, c):
3340            nt = torch.nested.as_nested_tensor([a, b, c])
3341            mask = nt.detach().clone().to(bool)
3342            out = nt.masked_fill(mask, 0)
3343            out = torch.nested.to_padded_tensor(out, 0)
3344            return out
3345
3346        data = (a, b, c)
3347        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3348
3349    def test_gelu_backward(self, device):
3350        a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3351        b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3352        c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3353
3354        def grad_test_func(a, b, c):
3355            nt = torch.nested.as_nested_tensor([a, b, c])
3356            nt_gelu = torch.nn.functional.gelu(nt)
3357            return torch.nested.to_padded_tensor(nt_gelu, 0)
3358
3359        data = (a, b, c)
3360        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3361
3362    def test_relu_backward(self, device):
3363        a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3364        b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3365        c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3366
3367        def grad_test_func(a, b, c):
3368            nt = torch.nested.as_nested_tensor([a, b, c])
3369            nt_relu = torch.nn.functional.relu(nt)
3370            return torch.nested.to_padded_tensor(nt_relu, 0)
3371
3372        data = (a, b, c)
3373        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3374
3375    def test_selu_backward(self, device):
3376        a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3377        b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3378        c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3379
3380        def grad_test_func(a, b, c):
3381            nt = torch.nested.as_nested_tensor([a, b, c])
3382            nt_silu = torch.nn.functional.silu(nt)
3383            return torch.nested.to_padded_tensor(nt_silu, 0)
3384
3385        data = (a, b, c)
3386        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3387
3388    def test_abs_backward(self, device):
3389        a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3390        b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3391        c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3392
3393        def grad_test_func(a, b, c):
3394            nt = torch.nested.as_nested_tensor([a, b, c])
3395            nt_abs = torch.abs(nt)
3396            return torch.nested.to_padded_tensor(nt_abs, 0)
3397
3398        data = (a, b, c)
3399        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3400
3401    # Previously would error when input NT doesn't require grad
3402    # NotImplementedError: Cannot access storage of UndefinedTensorImpl
3403    def test_layer_norm_backward_edge_case(self, device):
3404        size = 4
3405        a = torch.randn(
3406            1, 2, size, requires_grad=False, dtype=torch.float64, device=device
3407        )
3408        nt = torch.nested.nested_tensor([a])
3409        nt_layer_norm = torch.nn.LayerNorm(
3410            nt.size(-1), device=device, dtype=torch.float64
3411        )
3412        out = nt_layer_norm(nt)
3413        out.backward(out.clone())
3414
3415    def test_accumulate_grad_different_strides(self, device):
3416        a = torch.rand(1, 4, 2, requires_grad=True, dtype=torch.float64, device=device)
3417        b = torch.rand(1, 8, 2, requires_grad=True, dtype=torch.float64, device=device)
3418
3419        def grad_test_func(a, b):
3420            nt_1 = torch.nested.as_nested_tensor([a, b])
3421            nt_2 = nt_1.clone()
3422            out = torch.nn.functional.scaled_dot_product_attention(nt_1, nt_2, nt_2)
3423            return torch.nested.to_padded_tensor(out, 0)
3424
3425        data = (a, b)
3426        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3427
3428    # https://github.com/pytorch/pytorch/issues/95562
3429    @skipIfSlowGradcheckEnv
3430    @parametrize("size", [1024, 1023, 513, 512, 256, 128, 32, 4, 2])
3431    def test_layer_norm_backward(self, device, size):
3432        a = torch.randn(
3433            1, 2, size, requires_grad=True, dtype=torch.float64, device=device
3434        )
3435        b = torch.randn(
3436            2, 2, size, requires_grad=True, dtype=torch.float64, device=device
3437        )
3438        c = torch.randn(
3439            3, 2, size, requires_grad=True, dtype=torch.float64, device=device
3440        )
3441
3442        def grad_test_func(a, b, c):
3443            nt = torch.nested.as_nested_tensor([a, b, c])
3444            layer_norm = torch.nn.LayerNorm(
3445                nt.size(-1), device=device, dtype=torch.float64
3446            )
3447            nt_layer_norm = layer_norm(nt)
3448            return torch.nested.to_padded_tensor(nt_layer_norm, 0)
3449
3450        data = (a, b, c)
3451        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3452
3453    # https://github.com/pytorch/pytorch/issues/95562
3454    @skipIfSlowGradcheckEnv
3455    # Could either mark slow or reduce size
3456    @parametrize("size", [128, 32, 4, 2])
3457    def test_layer_norm_backward_5d(self, device, size):
3458        a = torch.randn(
3459            4, size, size, 4, requires_grad=True, dtype=torch.float64, device=device
3460        )
3461        b = torch.randn(
3462            7, size, size, 4, requires_grad=True, dtype=torch.float64, device=device
3463        )
3464        c = torch.randn(
3465            10, size, size, 4, requires_grad=True, dtype=torch.float64, device=device
3466        )
3467
3468        def grad_test_func(a, b, c):
3469            nt = torch.nested.as_nested_tensor([a, b, c])
3470            layer_norm = torch.nn.LayerNorm(
3471                (size, size, nt.size(-1)), device=device, dtype=torch.float64
3472            )
3473            nt_layer_norm = layer_norm(nt)
3474            return torch.nested.to_padded_tensor(nt_layer_norm, 0)
3475
3476        data = (a, b, c)
3477        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3478
3479
3480# Found in torch/testing/_comparison.py
3481default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float32: 1e-5}
3482default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float32: 1.3e-6}
3483
3484
3485def get_rtol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
3486    deviation = true_value - computed_value
3487    deviation = torch.abs(deviation / true_value)
3488    # Fill in the nans with the default rtol
3489    torch.nan_to_num_(deviation, nan=default_rtol[computed_value.dtype])
3490    return deviation.max().item()
3491
3492
3493def get_atol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
3494    deviation = true_value - computed_value
3495    atol = torch.abs(deviation).max().item()
3496    return atol
3497
3498
3499def get_tolerances(
3500    true_value: torch.Tensor,
3501    computed_value: torch.Tensor,
3502    fudge_factor: Optional[float] = None,
3503) -> Tuple[float, float]:
3504    """Returns the absolute and relative tolerances for comparing two tensors."""
3505    fudge_factor = fudge_factor if fudge_factor is not None else 1.0
3506    atol = get_atol(true_value, computed_value)
3507    rtol = get_rtol(true_value, computed_value)
3508
3509    atol = fudge_factor * max(atol, default_atol[computed_value.dtype])
3510    rtol = fudge_factor * max(rtol, default_rtol[computed_value.dtype])
3511    # torch.isclose() has weird behavior with very large rtol values; see:
3512    # https://github.com/pytorch/pytorch/issues/102400
3513    if rtol > 1e30:
3514        rtol = default_rtol[computed_value.dtype]
3515    return atol, rtol
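# Usage sketch for the helpers above (illustrative values, not taken from the tests):
# for float32 inputs with an observed max absolute deviation of 3e-6 and
# fudge_factor=2.0, get_tolerances returns atol = 2.0 * max(3e-6, 1e-5) = 2e-5;
# i.e. the observed error and the dtype default are combined via max() and then
# scaled by the fudge factor, and rtol is handled the same way.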
3516
3517
3518# We could probably parametrize existing tests instead of having a separate
3519# test class as we begin to support more ops. Also maybe rewrite with OpInfos.
3520@markDynamoStrictTest
3521class TestNestedTensorSubclass(NestedTensorTestCase):
3522    # TODO: consolidate with the below
3523    def _get_list_for_jagged_tensor(self, nested_size, device, requires_grad=True):
3524        Ds = nested_size[1:]
3525        out = []
3526        for s in nested_size[0]:
3527            out.append(
3528                torch.randn(
3529                    s,
3530                    *Ds,
3531                    requires_grad=requires_grad,
3532                    device=device,
3533                    dtype=torch.float64,
3534                )
3535            )
3536        return out
3537
3538    def _get_example_tensor_lists(
3539        self,
3540        include_list_of_lists=True,
3541        include_requires_grad=True,
3542        include_inner_dim_size_1=False,
3543        include_2d_tensor=False,
3544    ):
3545        def _make_tensor(
3546            *shape, include_requires_grad=include_requires_grad, requires_grad=True
3547        ):
3548            return torch.randn(
3549                *shape,
3550                requires_grad=(requires_grad if include_requires_grad else False),
3551            )
3552
3553        # Purposefully introduce mixed requires_grad settings for the components
3554        # when include_requires_grad=True.
3555        example_lists = [
3556            # (B, *, D) with B=4
3557            [
3558                _make_tensor(2, 5),
3559                _make_tensor(3, 5, requires_grad=False),
3560                _make_tensor(4, 5, requires_grad=False),
3561                _make_tensor(6, 5),
3562            ],
3563            # (B, *, D_0, D_1) with B=5
3564            [
3565                _make_tensor(2, 5, 6),
3566                _make_tensor(3, 5, 6),
3567                _make_tensor(4, 5, 6, requires_grad=False),
3568                _make_tensor(5, 5, 6),
3569                _make_tensor(6, 5, 6),
3570            ],
3571            # (B, *, D_0, D_1, D_2) with B=6
3572            [
3573                _make_tensor(2, 5, 6, 7),
3574                _make_tensor(3, 5, 6, 7),
3575                _make_tensor(4, 5, 6, 7, requires_grad=False),
3576                _make_tensor(5, 5, 6, 7),
3577                _make_tensor(6, 5, 6, 7),
3578                _make_tensor(7, 5, 6, 7),
3579            ],
3580        ]
3581
3582        if include_list_of_lists:
3583            example_lists.append(
3584                # (B, *, D) with B=3 in list form
3585                [
3586                    _make_tensor(2, 5, requires_grad=False).tolist(),
3587                    _make_tensor(3, 5).tolist(),
3588                    _make_tensor(4, 5).tolist(),
3589                ]
3590            )
3591
3592        if include_inner_dim_size_1:
3593            example_lists.append(
3594                [
3595                    _make_tensor(2, 1),
3596                    _make_tensor(3, 1, requires_grad=False),
3597                    _make_tensor(4, 1, requires_grad=False),
3598                    _make_tensor(6, 1),
3599                ]  # (B, *, 1)
3600            )
3601            example_lists.append(
3602                [
3603                    _make_tensor(2, 5, 1),
3604                    _make_tensor(3, 5, 1, requires_grad=False),
3605                    _make_tensor(4, 5, 1, requires_grad=False),
3606                    _make_tensor(6, 5, 1),
3607                ]  # (B, *, 5, 1)
3608            )
3609
3610        if include_2d_tensor:
3611            example_lists.append(
3612                [
3613                    _make_tensor(2),
3614                    _make_tensor(3, requires_grad=False),
3615                    _make_tensor(4, requires_grad=False),
3616                    _make_tensor(6),
3617                ]  # (B, *)
3618            )
3619
3620        return example_lists
3621
3622    def test_tensor_attributes(self, device):
3623        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
3624        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
3625        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
3626        nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
3627        _offsets = nt.offsets()
3628
3629        for op in (
3630            torch.ops.aten.is_non_overlapping_and_dense.default,
3631            torch.ops.aten.sym_size.default,
3632            torch.ops.aten.dim.default,
3633            torch.ops.aten.numel.default,
3634            torch.ops.aten.sym_numel.default,
3635            torch.ops.aten.sym_stride.default,
3636            torch.ops.aten.sym_storage_offset.default,
3637        ):
3638            op(nt)
3639
3640        with self.assertRaisesRegex(
3641            RuntimeError, "directly calling torch.ops.aten.size"
3642        ):
3643            torch.ops.aten.size.default(nt)
3644
3645        nested_int = torch.nested._internal.nested_tensor.get_tensor_symint(
3646            _offsets, coeff=1
3647        )
3648        self.assertEqual(nt.size(), (3, nested_int, 3))
3649        self.assertEqual(nt.shape, (3, nested_int, 3))
3650        self.assertEqual(nt.dim(), 3)
3651        self.assertEqual(nt.numel(), 27)
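        # Explanatory note: the ragged dimension is represented by a nested symbolic
        # int tied to the offsets tensor, and numel() sums over the components:
        # 2*3 + 3*3 + 4*3 = 27.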
3652
3653    @parametrize("nt_dim", [3, 4, 5])
3654    def test_linear(self, device, nt_dim):
3655        if nt_dim == 3:
3656            fixed_shape = (3,)
3657        elif nt_dim == 4:
3658            fixed_shape = (4, 3)
3659        elif nt_dim == 5:
3660            fixed_shape = (5, 4, 3)
3661
3662        a = torch.randn(
3663            2, *fixed_shape, requires_grad=True, dtype=torch.float64, device=device
3664        )
3665        b = torch.randn(
3666            3, *fixed_shape, requires_grad=True, dtype=torch.float64, device=device
3667        )
3668        c = torch.randn(
3669            4, *fixed_shape, requires_grad=True, dtype=torch.float64, device=device
3670        )
3671        weight = torch.randn(
3672            4, 3, requires_grad=True, dtype=torch.float64, device=device
3673        )
3674
3675        def grad_test_func(a, b, c, weight):
3676            nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
3677            out = torch.nn.functional.linear(nt, weight)
3678            return out.values()
3679
3680        gradcheck(grad_test_func, inputs=(a, b, c, weight), check_batched_grad=False)
3681
3682    def test_unary_pointwise(self, device):
3683        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
3684        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
3685        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
3686
3687        def grad_test_func(a, b, c):
3688            nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
3689            out = torch.nn.functional.silu(nt.sin().cos())
3690            return out.values()
3691
3692        gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)
3693
3694    def test_unary_pointwise_transposed_inputs(self, device):
3695        a, b, c = (
3696            torch.randn(
3697                i + 2, 5, requires_grad=True, dtype=torch.float64, device=device
3698            )
3699            for i in range(3)
3700        )
3701
3702        nt = torch.nested.nested_tensor(
3703            [a.detach(), b.detach(), c.detach()], layout=torch.jagged
3704        )
3705        nt_t = nt.transpose(1, 2)
3706        self.assertFalse(nt_t.is_contiguous())
3707        out = torch.nn.functional.silu(nt_t.sin().cos())
3708        self.assertEqual(
3709            out.is_contiguous(),
3710            torch.nn.functional.silu(b.transpose(-1, -2).sin().cos()).is_contiguous(),
3711        )
3712
3713        self.assertEqual(nt_t.shape, out.shape)
3714
3715        a, b, c = (
3716            torch.randn(
3717                i + 2, 5, requires_grad=True, dtype=torch.float64, device=device
3718            )
3719            for i in range(3)
3720        )
3721
3722        def grad_test_func(a, b, c):
3723            nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
3724            nt_t = nt.transpose(1, 2)
3725            out = torch.nn.functional.silu(nt_t.sin().cos())
3726            return out.values()
3727
3728        gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)
3729
3730    def test_binary_pointwise(self, device):
3731        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
3732        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
3733        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
3734
3735        # Incorrect usage: the shape check will fail if the offsets tensors are not
3736        #                  the same exact tensor object
3737        nt1 = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
3738        nt2 = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
3739
3740        self.assertRaisesRegex(
3741            RuntimeError,
3742            "cannot call binary pointwise function .* with inputs of shapes",
3743            lambda: nt1 * nt2,
3744        )
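        # Even though nt1 and nt2 hold the same values, each as_nested_tensor call
        # above creates its own offsets tensor, so the two NTs carry distinct symbolic
        # ragged sizes and the binary-op shape check rejects them.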
3745
3746        # Correct usage: chain the calls using the same offsets tensor object
3747        def grad_test_func(a, b, c):
3748            nt1 = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
3749            # TODO: Switch to public API that takes in (values, offsets) once it exists
3750            nt2, offsets = jagged_from_list([a, b, c], nt1.offsets())
3751            out = nt1 * nt2
3752            return out.values()
3753
3754        gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)
3755
3756    def test_binary_pointwise_transposed(self, device):
3757        a, b, c = (
3758            torch.randn(i + 2, 5, dtype=torch.float64, device=device) for i in range(3)
3759        )
3760
3761        nt1, offsets = jagged_from_list([a, b, c], None)
3762        nt2, offsets = jagged_from_list([a, b, c], offsets)
3763
3764        nt1_t = nt1.transpose(1, 2)
3765        nt2_t = nt2.transpose(1, 2)
3766
3767        # out = nt1_t * nt2_t
3768        # self.assertFalse(nt1_t.is_contiguous())
3769        # self.assertEqual(out.is_contiguous(), (b.transpose(-1, -2) * b.transpose(-1, -2)).is_contiguous())
3770        # self.assertEqual(out.shape, nt1_t.shape)
3771
3772        self.assertRaisesRegex(
3773            RuntimeError,
3774            "cannot call binary pointwise function mul.Tensor with inputs of shapes",
3775            lambda: nt1 * nt2_t,
3776        )
3777
3778        a, b, c = (
3779            torch.randn(
3780                i + 2, 5, requires_grad=True, dtype=torch.float64, device=device
3781            )
3782            for i in range(3)
3783        )
3784
3785        # Correct usage: chain the calls using the same offsets tensor object
3786        def grad_test_func(a, b, c):
3787            nt1, offsets = jagged_from_list([a, b, c], None)
3788            nt2, offsets = jagged_from_list([a, b, c], offsets)
3789            nt1_t = nt1.transpose(1, 2)
3790            nt2_t = nt2.transpose(1, 2)
3791            out = nt1_t * nt2_t
3792            return out.values()
3793
3794        gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)
3795
3796    def test_split(self, device):
3797        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
3798        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
3799        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
3800
3801        nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
3802        out = torch.split(nt, 2, -1)
3803        self.assertEqual(len(out), 2)
3804        self.assertEqualIgnoringNestedInts(
3805            out[0],
3806            torch.nested.as_nested_tensor(
3807                [a[:, 0:2], b[:, 0:2], c[:, 0:2]], layout=torch.jagged
3808            ),
3809        )
3810        self.assertEqualIgnoringNestedInts(
3811            out[1],
3812            torch.nested.as_nested_tensor(
3813                [a[:, 2:], b[:, 2:], c[:, 2:]], layout=torch.jagged
3814            ),
3815        )
3816
3817        with self.assertRaisesRegex(
3818            RuntimeError,
3819            r"split\(\): not supported for NestedTensor on dim=1",
3820        ):
3821            torch.split(nt, 2, 1)
3822
3823    def test_split_with_sizes(self, device):
3824        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
3825        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
3826        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
3827
3828        nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
3829        out = torch.split(nt, [1, 2], -1)
3830        self.assertEqual(len(out), 2)
3831        self.assertEqualIgnoringNestedInts(
3832            out[0],
3833            torch.nested.as_nested_tensor(
3834                [a[:, 0:1], b[:, 0:1], c[:, 0:1]], layout=torch.jagged
3835            ),
3836        )
3837        self.assertEqualIgnoringNestedInts(
3838            out[1],
3839            torch.nested.as_nested_tensor(
3840                [a[:, 1:], b[:, 1:], c[:, 1:]], layout=torch.jagged
3841            ),
3842        )
3843        with self.assertRaisesRegex(
3844            RuntimeError,
3845            r"split_with_sizes\(\): not supported for NestedTensor on dim=1",
3846        ):
3847            torch.split(nt, [1, 2], 1)
3848
3849    def test_softmax(self, device):
3850        nt = random_nt_from_dims(
3851            [3, None, 5],
3852            device=device,
3853            dtype=torch.float32,
3854            layout=torch.jagged,
3855            requires_grad=True,
3856        )
3857
3858        # operate on dim=2
3859        output = nt.softmax(dim=2)
3860
3861        @torch._dynamo.disable
3862        def _compare_to_ref(nt, output, dim):
3863            for in_component, out_component in zip(nt.unbind(), output.unbind()):
3864                self.assertEqual(in_component.softmax(dim=dim), out_component)
3865
3866        # dim=2 -> dim=1 after unbind
3867        _compare_to_ref(nt, output, dim=1)
3868
3869        # operate on dim=-1
3870        output2 = nt.softmax(dim=-1)
3871        torch._dynamo.disable(self.assertEqual)(output, output2)
3872        _compare_to_ref(nt, output2, dim=-1)
3873
3874        def grad_test_func(a, b):
3875            nt = torch.nested.as_nested_tensor([a, b], layout=torch.jagged)
3876            out = nt.softmax(dim=-1)
3877            return out.values()
3878
3879        a = torch.rand(4, 5, requires_grad=True, dtype=torch.float64, device=device)
3880        b = torch.rand(8, 5, requires_grad=True, dtype=torch.float64, device=device)
3881        gradcheck(grad_test_func, inputs=(a, b), check_batched_grad=False)
3882
3883    def test_views_inherit_ragged_dim(self, device):
3884        # view
3885        nt = random_nt_from_dims(
3886            [4, None, 8, 10], device=device, dtype=torch.float32, layout=torch.jagged
3887        )
3888        # inherit ragged dim via -1
3889        view = nt.view(4, -1, 80)
3890        self.assertEqual(nt.shape[1], view.shape[1])
3891        # inherit batch and ragged dims via -1
3892        view2 = nt.view(-1, -1, 80)
3893        self.assertEqual(nt.shape[:2], view2.shape[:2])
3894
3895        # expand
3896        nt = random_nt_from_dims(
3897            [3, None, 1], device=device, dtype=torch.float32, layout=torch.jagged
3898        )
3899        # inherit batch and ragged dims via -1
3900        view = nt.expand(-1, -1, 5)
3901        self.assertEqual(nt.shape[:2], view.shape[:2])
3902
3903    def test_view_ragged_idx_not_one(self, device):
3904        nt = random_nt_from_dims(
3905            [2, None, 20], device=device, dtype=torch.float32, layout=torch.jagged
3906        )
3907
3908        view_transposed = nt.transpose(1, 2).view(2, 20, nt.size(1))
        self.assertEqual((2, 20, nt.size(1)), view_transposed.size())
3910        self.assertEqual(view_transposed._base, nt._base)
3911
3912    def test_unsafe_view(self, device):
3913        nt = random_nt_from_dims(
3914            [4, None, 8, 10], device=device, dtype=torch.float32, layout=torch.jagged
3915        )
3916        # basic view
3917        view1 = torch.ops.aten._unsafe_view(nt, (4, -1, 80))
3918        self.assertEqual((4, nt.size(1), 80), tuple(view1.size()))
3919        # _unsafe_view differs from view in that the view information is not tracked
3920        self.assertTrue(view1._base is None)
3921
        # test _unsafe_view when ragged_idx != 1; currently only the identity view is supported
3923        nt_t = nt.transpose(1, 2)
3924        view2 = torch.ops.aten._unsafe_view(nt_t, (4, 8, nt.size(1), 10))
3925        self.assertEqual((4, 8, nt.size(1), 10), tuple(view2.size()))
3926        self.assertTrue(view2._base is None)
3927
3928    @xfailIfTorchDynamo
3929    @parametrize("requires_grad", [False, True])
3930    def test_reshape_decomp(self, device, requires_grad):
3931        # contiguous NT should result in view.
3932        nt = (
3933            random_nt_from_dims(
3934                [3, None, 10],
3935                device=device,
3936                dtype=torch.float32,
3937                layout=torch.jagged,
3938            )
3939            .detach()
3940            .requires_grad_(requires_grad)
3941        )
3942        view = nt.reshape(-1, -1, 5, 2)
3943        self.assertEqual(view.shape[:2], nt.shape[:2])
3944        self.assertTrue(view._is_view() and view._base is nt)
3945        # make sure gradients flow back
3946        if requires_grad:
3947            view.backward(torch.ones_like(view))
3948            self.assertEqual(nt.grad, torch.ones_like(nt))
3949
3950        # non-contiguous NT should result in contiguous copy
3951        nt = random_nt_from_dims(
3952            [3, None, 5, 2],
3953            device=device,
3954            dtype=torch.float32,
3955            layout=torch.jagged,
3956            requires_grad=requires_grad,
3957        )
3958        nt_noncontig = nt.transpose(-1, -2)
3959        self.assertFalse(nt_noncontig.is_contiguous())
3960        copy = nt_noncontig.reshape(-1, -1, 10)
3961        self.assertTrue(copy.is_contiguous())
3962        self.assertEqual(copy.shape[:2], nt.shape[:2])
3963        # make sure gradients flow back
3964        if requires_grad:
3965            copy.backward(torch.ones_like(copy))
3966            self.assertEqual(nt.grad, torch.ones_like(nt))
3967
3968    def test_flatten_decomp(self, device):
3969        nt = random_nt_from_dims(
3970            [3, None, 5, 2], device=device, dtype=torch.float32, layout=torch.jagged
3971        )
3972        flattened = nt.flatten(-2, -1)
3973        self.assertEqual(flattened.shape, nt.view(3, -1, 10).shape)
3974
3975        nt = random_nt_from_dims(
3976            [3, None, 5, 2, 6], device=device, dtype=torch.float32, layout=torch.jagged
3977        )
3978        flattened = nt.flatten(-3, -2)
3979        self.assertEqual(flattened.shape, nt.view(3, -1, 10, 6).shape)
3980
3981    def test_chunk(self, device):
        # non-NJT case
3983        t = torch.randn(10, 4, 5, requires_grad=True)
3984        t_list = t.chunk(3, dim=0)
3985        loss = t_list[0].sum() + t_list[2].sum()
3986        loss.backward()
3987
3988        # normal case
3989        D = 30
3990        B = 8
3991        nt = random_nt_from_dims(
3992            [B, None, D],
3993            device=device,
3994            dtype=torch.float32,
3995            layout=torch.jagged,
3996            requires_grad=True,
3997        )
3998        NUM_CHUNKS = 3
3999        chunks = nt.chunk(NUM_CHUNKS, dim=-1)
4000        self.assertEqual(len(chunks), NUM_CHUNKS)
4001        for i in range(NUM_CHUNKS):
4002            self.assertEqual(chunks[i].shape[-1], D // NUM_CHUNKS)
4003
4004        # test chunk_backward
4005        values = torch.randn(
4006            5, 11, dtype=torch.float64, device=device, requires_grad=True
4007        )
4008        offsets = torch.tensor([0, 2, 3, 5], device=device)
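        # offsets [0, 2, 3, 5] delimit 3 jagged components of lengths 2, 1, and 2
        # over the 5 rows of `values`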
4009
4010        def grad_test_func(values, offsets):
4011            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
4012            chunks = nt.chunk(3, dim=-1)
4013            return chunks[0].values().sum()
4014
4015        assert gradcheck(
4016            grad_test_func,
4017            inputs=(values, offsets),
4018            check_batched_grad=False,
4019        )
4020
4021        # chunk on batch dim
4022        chunks = nt.chunk(NUM_CHUNKS, dim=0)
4023        self.assertEqual(len(chunks), NUM_CHUNKS)
4024        chunk_size = math.ceil(B / NUM_CHUNKS)
4025        for i in range(NUM_CHUNKS):
4026            if i < NUM_CHUNKS - 1:
4027                self.assertEqual(chunks[i].shape[0], chunk_size)
4028            else:
4029                self.assertEqual(chunks[i].shape[0], B - chunk_size * (NUM_CHUNKS - 1))
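            # each chunk's offsets are rebased to start at 0, so compare against the
            # corresponding slice of the original offsets shifted by the chunk's start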
4030            offsets_expected = (
4031                nt._offsets[i * chunk_size + 1 : (i + 1) * chunk_size + 1]
4032                - nt._offsets[i * chunk_size]
4033            )
4034            self.assertEqual(chunks[i]._offsets[1:], offsets_expected)
4035        self.assertEqual(nt._values, torch.cat([x._values for x in chunks], dim=0))
4036
4037        with self.assertRaisesRegex(
4038            RuntimeError,
4039            "dim != 0 INTERNAL ASSERT FAILED .* Nested Tensor doesn't support chunk backward on dim=0 yet.",
4040        ):
4041            # doesn't support backward for chunk (dim=0) yet
4042            loss = (
4043                chunks[0].values().sum()
4044                + chunks[1].values().sum()
4045                + chunks[2].values().sum()
4046            )
4047            loss.backward()
4048
4049        # chunk on ragged dim not supported
4050        with self.assertRaisesRegex(
4051            RuntimeError, "chunk.* not supported for NestedTensor on dim=1"
4052        ):
4053            nt.chunk(2, dim=1)
4054
4055    def test_squeeze(self, device):
4056        B = 4
4057        D = 6
4058        # squeeze middle dim
4059        nt = random_nt_from_dims(
4060            [B, None, 1, D], device=device, dtype=torch.float32, layout=torch.jagged
4061        )
4062        j0 = nt.shape[1]
4063
4064        for dim_arg in [-2, 2]:
4065            out = nt.squeeze(dim_arg)
4066            self.assertEqual(out.shape, (B, j0, D))
4067            self.assertEqual(out.unsqueeze(-2), nt)
4068
4069        # squeeze last dim
4070        nt = random_nt_from_dims(
4071            [B, None, 1], device=device, dtype=torch.float32, layout=torch.jagged
4072        )
4073        j1 = nt.shape[1]
4074
4075        for dim_arg in [-1, 2]:
4076            out = nt.squeeze(dim_arg)
4077            self.assertEqual(out.shape, (B, j1))
4078            self.assertEqual(out.unsqueeze(-1), nt)
4079
4080        # squeeze on batch dim not supported
4081        with self.assertRaisesRegex(
4082            RuntimeError, "squeeze.* not supported for NestedTensor on dim=0"
4083        ):
4084            nt.squeeze(0)
4085
4086        # squeeze on ragged dim not supported
4087        with self.assertRaisesRegex(
4088            RuntimeError, "squeeze.* not supported for NestedTensor on dim=1"
4089        ):
4090            nt.squeeze(1)
4091
4092    def test_binary_pointwise_broadcasting(self, device):
4093        # (B, j0, 3, 4)
4094        ts = self._get_list_for_jagged_tensor(
4095            ((2, 3, 4), 3, 4), device, requires_grad=True
4096        )
4097        # (B, j0, ?, ?) + (?) -> (B, j0, ?, ?)
4098        # (B, j0, ?, ?) + (?, ?) -> (B, j0, ?, ?)
4099        # (B, j0, ?, ?) + (1, ?, ?) -> (B, j0, ?, ?)
4100        # Unsupported: (B, j0, ?, ?) + (1, 1, 1, ?, ?) -> (1, B, j0, ?, ?)
4101        t_sizes = (
4102            (4,),
4103            (1, 4),
4104            (3, 1),
4105            (1, 3, 1),
4106            (1, 1, 1, 4),
4107            # (1, 1, 1, 1, 4), (unsupported today)
4108        )
4109
4110        def grad_test_func(t, *ts):
4111            nt = torch.nested.as_nested_tensor(list(ts), layout=torch.jagged)
4112            out = nt + t
4113            return out.values()
4114
4115        for t_size in t_sizes:
4116            t = torch.rand(
4117                t_size, requires_grad=True, device=device, dtype=torch.float64
4118            )
4119            gradcheck(grad_test_func, inputs=(t, *ts), check_batched_grad=False)
4120
4121    def test_threshold_backward(self, device):
4122        ts1 = self._get_list_for_jagged_tensor(
4123            ((2, 3, 4), 16), device=device, requires_grad=False
4124        )
4125        ts2 = self._get_list_for_jagged_tensor(
4126            ((2, 3, 4), 16), device=device, requires_grad=False
4127        )
4128
4129        nt1, offsets = jagged_from_list(ts1, None)
4130        nt2, offsets = jagged_from_list(ts2, offsets)
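        # reuse nt1's offsets so nt1 and nt2 share the same ragged structure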
4131        buf1 = nt1.values().detach().clone()
4132        buf2 = nt2.values().detach().clone()
4133
4134        res_nt = torch.ops.aten.threshold_backward(nt1, nt2, 0.0)
4135        res_dense = torch.ops.aten.threshold_backward(buf1, buf2, 0.0)
4136
4137        self.assertEqual(res_dense, res_nt.values())
4138
4139    @dtypes(torch.float32)
4140    @parametrize(
4141        "func",
4142        [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
4143        name_fn=get_op_name,
4144    )
4145    @parametrize("keepdim", [False, True])
4146    @parametrize("requires_grad", [False, True])
4147    @parametrize("components_require_grad", [False, True])
4148    def test_jagged_op_different_output_shape_dim(
4149        self, device, dtype, keepdim, requires_grad, components_require_grad, func
4150    ):
4151        """
4152        Operator passes when reducing on valid reduction dimensions.
4153        This test is for operators which return an output tensor with a shape different from the input tensor.
4154        """
4155        if get_op_name(func) == "mean" and not keepdim:
4156            return
4157
4158        op_name = get_op_name(func)
4159
4160        ts = self._get_list_for_jagged_tensor(
4161            ((2, 3, 4), 3, 4), device=device, requires_grad=True
4162        )  # (B, j0, 3, 4)
4163
4164        # verify correctness of shapes (assuming that ragged_idx == 1)
4165        if op_name == "sum":
4166            reduce_dims = (
4167                ((0, 1), (3, 4), (1, 1, 3, 4), (0,)),  # batch, ragged
4168                ((2, 3), (3, None), (3, None, 1, 1), (1, 2)),  # non-batch, non-batch
4169                ((0, 1, 3), (3,), (1, 1, 3, 1), (0, 2)),  # batch, ragged, non-batch
4170                ((0, 1, 2), (4,), (1, 1, 1, 4), (0, 1)),  # batch, ragged, non-batch
4171                (
4172                    (0, 1, 2, 3),
4173                    (),
4174                    (1, 1, 1, 1),
4175                    (0, 1, 2),
4176                ),  # batch, ragged, non-batch, non-batch
4177                ((2,), (3, None, 4), (3, None, 1, 4), (1,)),  # non-batch
4178            )  # (dims, expected shape, expected keepdim shape, reduce_dim_expected), where j0 is represented as None
4179        elif op_name == "mean":
4180            reduce_dims = (
4181                ((2,), (3, None, 4), (3, None, 1, 4), (1,)),
4182                ((3,), (3, None, 3), (3, None, 3, 1), (2,)),
4183            )
4184
4185        for rd, ref_shape_no_keepdim, ref_shape_keepdim, _ in reduce_dims:
4186            nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged)
4187            out = func(nt, dim=rd, keepdim=keepdim)
4188            ref_shape = ref_shape_keepdim if keepdim else ref_shape_no_keepdim
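            # None entries mark the ragged dim, which appears as a SymInt in the shape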
            if not torch.compiler.is_compiling():  # if not running under torch dynamo
4190                self.assertEqual(len(out.shape), len(ref_shape))
4191                for o, r in zip(out.shape, ref_shape):
4192                    if r is not None:
4193                        self.assertEqual(o, r)
4194                    else:
4195                        self.assertTrue(isinstance(o, torch.SymInt))
4196
4197        # verify correctness of values
4198        tensor_lists = self._get_example_tensor_lists(
4199            include_list_of_lists=False,
4200            include_requires_grad=components_require_grad,
4201            include_inner_dim_size_1=True,
4202        )
4203        for tensor_list, reduce_dim_tuple in itertools.product(
4204            tensor_lists, reduce_dims
4205        ):
4206            nt = torch.nested.nested_tensor(
4207                tensor_list,
4208                device=device,
4209                dtype=dtype,
4210                layout=torch.jagged,
4211                requires_grad=requires_grad,
4212            )
4213
4214            reduce_dim, _, _, reduce_dim_expected = reduce_dim_tuple
4215
4216            if nt.dim() > reduce_dim[-1]:
4217                out_actual = func(nt, dim=reduce_dim, keepdim=keepdim)
4218                if nt._ragged_idx in reduce_dim:  # raggedness reduced away
4219                    out_expected = func(
4220                        nt.values(), dim=reduce_dim_expected, keepdim=keepdim
4221                    )
4222                    self.assertTrue(torch.allclose(out_actual, out_expected))
4223                else:  # raggedness preserved
4224                    out_expected = func(nt.values(), dim=reduce_dim_expected)
4225                    self.assertTrue(
4226                        torch.allclose(
4227                            out_actual.values().view(-1), out_expected.view(-1)
4228                        )
4229                    )
4230
4231    @dtypes(torch.float32)
4232    @parametrize("requires_grad", [False, True])
4233    @parametrize("components_require_grad", [False, True])
4234    def test_softmax_dim(
4235        self,
4236        device,
4237        dtype,
4238        requires_grad,
4239        components_require_grad,
4240    ):
4241        """
4242        Softmax passes when reducing on valid reduction dimensions.
4243        """
4244        ts = self._get_list_for_jagged_tensor(
4245            ((2, 3, 4), 3, 4), device=device, requires_grad=True
4246        )  # (B, j0, 3, 4)
4247
4248        output_shape = (3, None, 3, 4)
4249
4250        # verify correctness of shapes (assuming that ragged_idx == 1)
4251        reduce_dims = (
4252            (2, 1),
4253            (3, 2),
4254        )  # (reduction dimension, effective reduction dimension for baseline)
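        # (the baseline runs on nt.values(), which drops the batch dim, so each
        # reduction dim shifts down by one)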
4255
4256        for reduce_dim, _ in reduce_dims:
4257            nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged)
4258            out_actual = torch.nn.functional.softmax(nt, dim=reduce_dim)
4259            torch._dynamo.disable(self.assertEqual)(
4260                len(out_actual.shape), len(output_shape)
4261            )  # disable if running on dynamo
4262            for dim_actual, dim_expected in zip(out_actual.shape, output_shape):
4263                if dim_expected is not None:
4264                    self.assertEqual(dim_actual, dim_expected)
4265                else:
4266                    self.assertTrue(isinstance(dim_actual, torch.SymInt))
4267
4268        # verify correctness of values
4269        tensor_lists = self._get_example_tensor_lists(
4270            include_list_of_lists=False,
4271            include_requires_grad=components_require_grad,
4272            include_inner_dim_size_1=True,
4273        )
4274        for tensor_list, reduce_dim_tuple in itertools.product(
4275            tensor_lists, reduce_dims
4276        ):
4277            nt = torch.nested.nested_tensor(
4278                tensor_list,
4279                device=device,
4280                dtype=dtype,
4281                layout=torch.jagged,
4282                requires_grad=requires_grad,
4283            )
4284
4285            reduce_dim, reduce_dim_expected = reduce_dim_tuple
4286
4287            if nt.dim() > reduce_dim:
4288                out_actual = torch.nn.functional.softmax(
4289                    nt, dim=reduce_dim
4290                )  # nested tensor
4291                out_expected = torch.nn.functional.softmax(
4292                    nt.values(), dim=reduce_dim_expected
4293                )  # dense tensor of dimensions 1 less than out_actual
4294                self.assertTrue(
4295                    torch.allclose(out_actual.values().view(-1), out_expected.view(-1))
4296                )
4297
4298    @dtypes(torch.float32)
4299    @parametrize(
4300        "func",
4301        [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
4302        name_fn=get_op_name,
4303    )
4304    @parametrize("keepdim", [False, True])
4305    @parametrize("requires_grad", [False, True])
4306    @parametrize("components_require_grad", [False, True])
4307    def test_op_dim_reduce_ragged_idx_1_different_output_shape(
4308        self, device, dtype, keepdim, requires_grad, components_require_grad, func
4309    ):
4310        """
4311        Operator on NestedTensor passes when trying to reduce across ragged dimension, where ragged_idx == 1.
4312        This test is for operators which return an output tensor with a shape different from the input tensor.
4313        """
4314        if get_op_name(func) == "mean" and not keepdim:
4315            return
4316
4317        op_name = get_op_name(func)
4318
4319        tensor_lists = self._get_example_tensor_lists(
4320            include_list_of_lists=False,
4321            include_requires_grad=components_require_grad,
4322            include_inner_dim_size_1=True,  # (B, *, 1)
4323        )
4324        reduce_dim = (1,)  # ragged
4325
4326        for tensor_list in tensor_lists:
4327            nt = torch.nested.nested_tensor(
4328                tensor_list,
4329                device=device,
4330                dtype=dtype,
4331                layout=torch.jagged,
4332                requires_grad=requires_grad,
4333            )
4334
4335            out_actual = func(nt, dim=reduce_dim, keepdim=keepdim)
4336            out_expected = torch.cat(
4337                [func(t, dim=(reduce_dim[0] - 1)).unsqueeze(0) for t in nt.unbind()]
4338            )
4339
4340            self.assertFalse(
4341                out_actual.is_nested,
4342                f"{op_name}(): the result of reducing a nested tensor along the ragged dimension is a dense tensor",
4343            )  # output is a dense tensor
4344            self.assertTrue(torch.allclose(out_actual, out_expected))
4345
4346    @dtypes(torch.float32)
4347    @parametrize("requires_grad", [False, True])
4348    @parametrize("components_require_grad", [False, True])
4349    def test_softmax_dim_reduce_ragged_idx_1(
4350        self, device, dtype, requires_grad, components_require_grad
4351    ):
4352        """
4353        Softmax on NestedTensor passes when trying to reduce across ragged dimension, where ragged_idx == 1.
4354        """
4355        tensor_lists = self._get_example_tensor_lists(
4356            include_list_of_lists=False,
4357            include_requires_grad=components_require_grad,
4358            include_inner_dim_size_1=True,  # (B, *, 1)
4359            include_2d_tensor=True,  # (B, *)
4360        )
4361        reduce_dim = 1  # ragged
4362
4363        for tensor_list in tensor_lists:
4364            nt = torch.nested.nested_tensor(
4365                tensor_list,
4366                device=device,
4367                dtype=dtype,
4368                layout=torch.jagged,
4369                requires_grad=requires_grad,
4370            )
4371
4372            out_actual = torch.nn.functional.softmax(nt, dim=reduce_dim)
4373            out_expected = torch.cat(
4374                [
4375                    torch.nn.functional.softmax(t, dim=reduce_dim - 1)
4376                    for t in nt.unbind()
4377                ]
4378            )
4379
4380            self.assertTrue(
4381                out_actual.is_nested,
4382                "softmax(): the result of reducing a nested tensor along the ragged dimension is a nested tensor",
4383            )  # output is a nested tensor
4384            self.assertTrue(torch.allclose(out_actual.values(), out_expected))
4385
4386    @dtypes(torch.float32)
4387    @parametrize("requires_grad", [False, True])
4388    @parametrize("components_require_grad", [False, True])
4389    def test_softmax_reduce_batch_dim(
4390        self, device, dtype, requires_grad, components_require_grad
4391    ):
4392        """
4393        Softmax on NestedTensor fails when trying to reduce across batch dimension.
4394        """
4395        tensor_lists = self._get_example_tensor_lists(
4396            include_list_of_lists=False,
4397            include_requires_grad=components_require_grad,
4398            include_inner_dim_size_1=True,  # (B, *, 1)
4399        )
4400        reduce_dim = 0  # batch
4401
4402        for tensor_list in tensor_lists:
4403            nt = torch.nested.nested_tensor(
4404                tensor_list,
4405                device=device,
4406                dtype=dtype,
4407                layout=torch.jagged,
4408                requires_grad=requires_grad,
4409            )
4410
4411            with self.assertRaisesRegex(
4412                RuntimeError,
4413                "not supported when reducing across the batch dimension for NestedTensor",
4414            ):
4415                out = torch.nn.functional.softmax(nt, dim=reduce_dim)
4416
4417    @dtypes(torch.float32)
4418    @parametrize("requires_grad", [False, True])
4419    @parametrize("components_require_grad", [False, True])
4420    def test_layer_norm_reduce_ragged_idx_1(
4421        self, device, dtype, requires_grad, components_require_grad
4422    ):
4423        """
4424        Layer normalization on NestedTensor passes when trying to normalize across ragged dimension, where ragged_idx == 1.
4425        """
4426
4427        # requires_grad = False does not currently work with dynamo tests and throws this error:
4428        #   AssertionError: SymInts must use SymNodeVariable.
4429        #   If the underlying value is static, we will create a ConstantVariable and specialize.
4430        if torch._dynamo.is_compiling() and not requires_grad:
4431            return
4432
4433        tensor_lists = self._get_example_tensor_lists(
4434            include_list_of_lists=False,
4435            include_requires_grad=components_require_grad,
4436            include_inner_dim_size_1=True,  # (B, *, 1)
4437        )
4438
4439        for tensor_list in tensor_lists:
4440            nt = torch.nested.nested_tensor(
4441                tensor_list,
4442                device=device,
4443                dtype=dtype,
4444                layout=torch.jagged,
4445                requires_grad=requires_grad,
4446            )
4447
4448            if (
4449                nt.dim() >= 3
4450            ):  # layer norm only works for tensors with 3 or more dimensions
4451                normalized_shape = nt.shape[nt._ragged_idx :]
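                # normalize over the ragged dim and all inner dims, e.g. (j0, M) for a
                # (B, j0, M) nested tensor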
4452
4453                out_actual = torch.nn.functional.layer_norm(
4454                    nt, normalized_shape=normalized_shape
4455                )
4456                out_expected = torch.cat(
4457                    [
4458                        torch.nn.functional.layer_norm(t, normalized_shape=t.shape)
4459                        for t in nt.unbind()
4460                    ]
4461                )  # e.g. in 3D tensor (B, *, M), performs layer normalization on B 2D tensors (*, M)
4462
4463                self.assertTrue(
4464                    out_actual.is_nested,
4465                    "layer_norm(): the result of reducing a nested tensor along the ragged dimension is a nested tensor",
4466                )  # output is a nested tensor
4467                self.assertEqual(out_actual._values.shape, out_expected.shape)
4468                self.assertTrue(torch.allclose(out_actual.values(), out_expected))
4469
4470    @dtypes(torch.float32)
4471    @parametrize("requires_grad", [False, True])
4472    @parametrize("components_require_grad", [False, True])
4473    def test_layer_norm_2d_input(
4474        self,
4475        device,
4476        dtype,
4477        requires_grad,
4478        components_require_grad,
4479    ):
4480        """
4481        Layer normalization on NestedTensor fails when trying to operate on a 2-dimensional tensor
4482        """
4483        tensor_lists = self._get_example_tensor_lists(
4484            include_list_of_lists=False,
4485            include_requires_grad=components_require_grad,
4486            include_inner_dim_size_1=True,  # (B, *, 1)
4487            include_2d_tensor=True,  # (B, *)
4488        )
4489
4490        for tensor_list in tensor_lists:
4491            nt = torch.nested.nested_tensor(
4492                tensor_list,
4493                device=device,
4494                dtype=dtype,
4495                layout=torch.jagged,
4496                requires_grad=requires_grad,
4497            )
4498
4499            if nt.dim() <= 2:
4500                with self.assertRaisesRegex(
4501                    RuntimeError,
4502                    "not supported for NestedTensor objects with 2 or fewer dimensions",
4503                ):
4504                    out = torch.nn.functional.layer_norm(
4505                        nt, normalized_shape=(nt.shape[nt._ragged_idx],)
4506                    )
4507
4508    @dtypes(torch.float32)
4509    @parametrize("requires_grad", [False, True])
4510    @parametrize("components_require_grad", [False, True])
4511    def test_layer_norm_operate_on_batch_dim(
4512        self,
4513        device,
4514        dtype,
4515        requires_grad,
4516        components_require_grad,
4517    ):
4518        """
4519        Layer normalization on NestedTensor fails when trying to operate on the batch dimension
4520        """
4521        tensor_lists = self._get_example_tensor_lists(
4522            include_list_of_lists=False,
4523            include_requires_grad=components_require_grad,
4524            include_inner_dim_size_1=True,  # (B, *, 1)
4525            include_2d_tensor=True,  # (B, *)
4526        )
4527
4528        for tensor_list in tensor_lists:
4529            nt = torch.nested.nested_tensor(
4530                tensor_list,
4531                device=device,
4532                dtype=dtype,
4533                layout=torch.jagged,
4534                requires_grad=requires_grad,
4535            )
4536
            if nt.dim() > 2:  # layer norm is not supported for 2D nested tensors, so skip them here
4538                with self.assertRaisesRegex(
4539                    RuntimeError,
4540                    "not supported when normalizing over the batch dimension for NestedTensor",
4541                ):
4542                    out = torch.nn.functional.layer_norm(nt, normalized_shape=nt.shape)
4543
4544    @dtypes(torch.float32)
4545    @parametrize(
4546        "func",
4547        [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
4548        name_fn=get_op_name,
4549    )
4550    @parametrize(
4551        "transpose_offset", [1, 2]
4552    )  # [transpose consecutive dimensions, transpose nonconsecutive dimensions]
4553    @parametrize("keepdim", [False, True])
4554    @parametrize("requires_grad", [False, True])
4555    @parametrize("components_require_grad", [False, True])
4556    def test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape(
4557        self,
4558        device,
4559        dtype,
4560        keepdim,
4561        requires_grad,
4562        components_require_grad,
4563        func,
4564        transpose_offset,
4565    ):
4566        """
        Operator on NestedTensor passes when trying to reduce across a transposed ragged dimension, i.e. ragged_idx > 1.
4568        This test is for operators which return an output tensor with a shape different from the input tensor.
4569        """
4570        if get_op_name(func) == "mean" and not keepdim:
4571            return
4572
4573        op_name = get_op_name(func)
4574
4575        tensor_lists = self._get_example_tensor_lists(
4576            include_list_of_lists=False,
4577            include_requires_grad=components_require_grad,
4578            include_inner_dim_size_1=True,  # (B, *, 1)
4579            include_2d_tensor=True,  # (B, *)
4580        )
4581
4582        for tensor_list in tensor_lists:
4583            nt = torch.nested.nested_tensor(
4584                tensor_list,
4585                device=device,
4586                dtype=dtype,
4587                layout=torch.jagged,
4588                requires_grad=requires_grad,
4589            )
4590
4591            if nt.dim() > nt._ragged_idx + transpose_offset:
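                # transposing the ragged dim with a later dim yields ragged_idx > 1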
4592                nt_transposed = nt.transpose(
4593                    nt._ragged_idx, nt._ragged_idx + transpose_offset
4594                )
4595                reduce_dim = (nt_transposed._ragged_idx,)  # ragged
4596
4597                out_actual = func(nt_transposed, dim=reduce_dim, keepdim=keepdim)
4598                out_expected = torch.cat(
4599                    [
4600                        func(t, dim=(reduce_dim[0] - 1)).unsqueeze(0)
4601                        for t in nt_transposed.unbind()
4602                    ]
4603                )
4604
4605                self.assertFalse(
4606                    out_actual.is_nested,
4607                    f"{op_name}(): the result of reducing a nested tensor along the ragged dimension is a dense tensor",
4608                )  # output is a dense tensor
4609                self.assertTrue(torch.allclose(out_actual, out_expected, rtol=1e-4))
4610
4611    @dtypes(torch.float32)
4612    @parametrize(
4613        "transpose_offset", [1, 2]
4614    )  # [transpose consecutive dimensions, transpose nonconsecutive dimensions]
4615    @parametrize("requires_grad", [False, True])
4616    @parametrize("components_require_grad", [False, True])
4617    def test_softmax_dim_reduce_ragged_idx_greater_than_1_same_output_shape(
4618        self,
4619        device,
4620        dtype,
4621        requires_grad,
4622        components_require_grad,
4623        transpose_offset,
4624    ):
4625        """
        Softmax on NestedTensor fails when trying to reduce across a transposed ragged dimension, i.e. ragged_idx > 1.
4627        This test is for operators which return an output tensor with the same shape as the input tensor.
4628        """
4629        tensor_lists = self._get_example_tensor_lists(
4630            include_list_of_lists=False,
4631            include_requires_grad=components_require_grad,
4632            include_inner_dim_size_1=True,  # (B, *, 1)
4633        )
4634
4635        for tensor_list in tensor_lists:
4636            nt = torch.nested.nested_tensor(
4637                tensor_list,
4638                device=device,
4639                dtype=dtype,
4640                layout=torch.jagged,
4641                requires_grad=requires_grad,
4642            )
4643
4644            if nt.dim() > nt._ragged_idx + transpose_offset:
4645                nt_transposed = nt.transpose(
4646                    nt._ragged_idx, nt._ragged_idx + transpose_offset
4647                )
4648                reduce_dim = nt_transposed._ragged_idx  # ragged
4649
4650                with self.assertRaisesRegex(
4651                    RuntimeError,
4652                    "not supported when reducing along the ragged dimension for ragged_idx > 1 for NestedTensor",
4653                ):
4654                    out = torch.nn.functional.softmax(nt_transposed, dim=reduce_dim)
4655
4656    @dtypes(torch.float32)
4657    @parametrize(
4658        "func",
4659        [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
4660        name_fn=get_op_name,
4661    )
4662    @parametrize("keepdim", [False, True])
4663    @parametrize("requires_grad", [False, True])
4664    @parametrize("components_require_grad", [False, True])
4665    def test_op_dim_transpose_non_ragged_dim_different_output_shape(
4666        self, device, dtype, keepdim, requires_grad, components_require_grad, func
4667    ):
4668        """
4669        Operator passes when reducing transposed nested tensors on valid reduction dimensions.
4670        This test is for operators which return an output tensor with a shape different from the input tensor.
4671        """
4672        if get_op_name(func) == "mean" and not keepdim:
4673            return
4674
4675        # verify correctness of shapes (assuming that ragged_idx == 1)
4676        if get_op_name(func) == "sum":
4677            reduce_dims = (
4678                ((0, 1), (3, 4), (1, 1, 3, 4), (0,)),  # batch, ragged
4679                ((2, 3), (3, None), (3, None, 1, 1), (1, 2)),  # non-batch, non-batch
4680                ((0, 1, 3), (3,), (1, 1, 3, 1), (0, 2)),  # batch, ragged, non-batch
4681                ((0, 1, 2), (4,), (1, 1, 1, 4), (0, 1)),  # batch, ragged, non-batch
4682                (
4683                    (0, 1, 2, 3),
4684                    (),
4685                    (1, 1, 1, 1),
4686                    (0, 1, 2),
4687                ),  # batch, ragged, non-batch, non-batch
4688                ((2,), (3, None, 4), (3, None, 1, 4), (1,)),  # non-batch
4689            )  # (dims, expected shape, expected keepdim shape, reduce_dim_expected), where j0 is represented as None
4690        elif get_op_name(func) == "mean":
4691            reduce_dims = (
4692                ((2,), (3, None, 4), (3, None, 1, 4), (1,)),
4693                ((3,), (3, None, 3), (3, None, 3, 1), (2,)),
4694            )
4695
4696        # verify correctness of values
4697        tensor_lists = self._get_example_tensor_lists(
4698            include_list_of_lists=False,
4699            include_requires_grad=components_require_grad,
4700        )
4701        for tensor_list, reduce_dim_tuple in itertools.product(
4702            tensor_lists, reduce_dims
4703        ):
4704            nt = torch.nested.nested_tensor(
4705                tensor_list,
4706                device=device,
4707                dtype=dtype,
4708                layout=torch.jagged,
4709                requires_grad=requires_grad,
4710            ).transpose(-1, -2)
4711
4712            reduce_dim, _, _, reduce_dim_expected = reduce_dim_tuple
4713
4714            if nt.dim() > max(
4715                reduce_dim[-1], nt._ragged_idx + 2
4716            ):  # ensure that transposed dimensions are non-batch, non-ragged dimensions
4717                out_actual = func(nt, dim=reduce_dim, keepdim=keepdim)
4718                if nt._ragged_idx in reduce_dim:  # raggedness reduced away
4719                    out_expected = func(
4720                        nt.values(), dim=reduce_dim_expected, keepdim=keepdim
4721                    )
4722                    self.assertTrue(torch.allclose(out_actual, out_expected))
4723                else:  # raggedness preserved
4724                    out_expected = func(nt.values(), dim=reduce_dim_expected)
4725                    self.assertTrue(
4726                        torch.allclose(
4727                            out_actual.values().view(-1), out_expected.view(-1)
4728                        )
4729                    )
4730
4731    @dtypes(torch.float32)
4732    @parametrize("requires_grad", [False, True])
4733    @parametrize("components_require_grad", [False, True])
4734    def test_softmax_dim_transpose_non_ragged_dim(
4735        self,
4736        device,
4737        dtype,
4738        requires_grad,
4739        components_require_grad,
4740    ):
4741        """
4742        Softmax passes when reducing transposed nested tensors on valid reduction dimensions.
4743        This test is for operators which return an output tensor with the same shape as the input tensor.
4744        """
4745        # verify correctness of shapes (assuming that ragged_idx == 1)
4746        reduce_dims = (
4747            (2, 1),
4748            (3, 2),
4749        )  # (reduction dimension, effective reduction dimension for baseline)
4750
4751        # verify correctness of values
4752        tensor_lists = self._get_example_tensor_lists(
4753            include_list_of_lists=False,
4754            include_requires_grad=components_require_grad,
4755            include_inner_dim_size_1=True,  # (B, *, 1)
4756        )
4757        for tensor_list, reduce_dim_tuple in itertools.product(
4758            tensor_lists, reduce_dims
4759        ):
4760            nt = torch.nested.nested_tensor(
4761                tensor_list,
4762                device=device,
4763                dtype=dtype,
4764                layout=torch.jagged,
4765                requires_grad=requires_grad,
4766            ).transpose(-1, -2)
4767
4768            reduce_dim, reduce_dim_expected = reduce_dim_tuple
4769
4770            if nt.dim() > max(reduce_dim, nt._ragged_idx + 2):
4771                out_actual = torch.nn.functional.softmax(
4772                    nt, dim=reduce_dim
4773                )  # nested tensor
4774                out_expected = torch.nn.functional.softmax(
4775                    nt.values(), dim=reduce_dim_expected
4776                )  # dense tensor of dimensions 1 less than out_actual
4777
4778                self.assertTrue(
4779                    torch.allclose(out_actual.values().view(-1), out_expected.view(-1))
4780                )
4781
4782    @dtypes(torch.float32)
4783    @parametrize("keepdim", [False, True])
4784    @parametrize("requires_grad", [False, True])
4785    @parametrize("components_require_grad", [False, True])
4786    def test_sum_dim_reduce_ragged_and_non_batch(
4787        self,
4788        device,
4789        dtype,
4790        keepdim,
4791        requires_grad,
4792        components_require_grad,
4793    ):
4794        """
4795        Sum on NestedTensor fails when trying to reduce across ragged and non-batch dimensions
4796        """
4797        tensor_lists = self._get_example_tensor_lists(
4798            include_list_of_lists=False, include_requires_grad=components_require_grad
4799        )
4800        reduce_dims = (
4801            (1, 2),  # ragged, non-batch
4802            (1, 3),  # ragged, non-batch
4803        )
4804
4805        for tensor_list, reduce_dim in itertools.product(tensor_lists, reduce_dims):
4806            nt = torch.nested.nested_tensor(
4807                tensor_list,
4808                device=device,
4809                dtype=dtype,
4810                layout=torch.jagged,
4811                requires_grad=requires_grad,
4812            )
4813
4814            if nt.dim() > reduce_dim[-1]:
4815                with self.assertRaisesRegex(
4816                    RuntimeError,
4817                    "not supported along a ragged and non-batch dimension for NestedTensor",
4818                ):
4819                    out = torch.sum(nt, dim=reduce_dim, keepdim=keepdim)
4820
4821    @dtypes(torch.float32)
4822    @parametrize("keepdim", [False, True])
4823    @parametrize("requires_grad", [False, True])
4824    @parametrize("components_require_grad", [False, True])
4825    def test_sum_dim_reduce_batch_and_non_batch(
4826        self,
4827        device,
4828        dtype,
4829        keepdim,
4830        requires_grad,
4831        components_require_grad,
4832    ):
4833        """
4834        Sum on NestedTensor fails when trying to reduce across batch and non-batch dimensions
4835        """
4836        tensor_lists = self._get_example_tensor_lists(
4837            include_list_of_lists=False, include_requires_grad=components_require_grad
4838        )
4839        reduce_dims = (
4840            (0, 2),  # batch, non-batch
4841            (0, 3),  # batch, non-batch
4842        )
4843
4844        for tensor_list, reduce_dim in itertools.product(tensor_lists, reduce_dims):
4845            nt = torch.nested.nested_tensor(
4846                tensor_list,
4847                device=device,
4848                dtype=dtype,
4849                layout=torch.jagged,
4850                requires_grad=requires_grad,
4851            )
4852
4853            if nt.dim() > reduce_dim[-1]:
4854                with self.assertRaisesRegex(
4855                    RuntimeError,
4856                    "not supported along the batch dimension but not the ragged dimension for NestedTensor",
4857                ):
4858                    out = torch.sum(nt, dim=reduce_dim, keepdim=keepdim)
4859
4860    @dtypes(torch.float32)
4861    @parametrize(
4862        "func",
4863        [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
4864        name_fn=get_op_name,
4865    )
4866    @parametrize("keepdim", [False, True])
4867    @parametrize("requires_grad", [False, True])
4868    @parametrize("components_require_grad", [False, True])
4869    def test_op_dim_reduce_batch_only_different_output_shape(
4870        self, device, dtype, keepdim, requires_grad, components_require_grad, func
4871    ):
4872        """
4873        Operator on NestedTensor fails when trying to reduce across batch dimension
4874        """
4875        if get_op_name(func) == "mean" and not keepdim:
4876            return
4877
4878        tensor_lists = self._get_example_tensor_lists(
4879            include_list_of_lists=False, include_requires_grad=components_require_grad
4880        )
4881        reduce_dim = (0,)  # batch
4882
4883        for tensor_list in tensor_lists:
4884            nt = torch.nested.nested_tensor(
4885                tensor_list,
4886                device=device,
4887                dtype=dtype,
4888                layout=torch.jagged,
4889                requires_grad=requires_grad,
4890            )
4891
4892            with self.assertRaisesRegex(
4893                RuntimeError,
4894                "not supported along the batch dimension but not the ragged dimension for NestedTensor",
4895            ):
4896                out = func(nt, dim=reduce_dim, keepdim=keepdim)
4897
4898    @dtypes(torch.float32)
4899    @parametrize(
4900        "func",
4901        [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
4902        name_fn=get_op_name,
4903    )
4904    @parametrize("keepdim", [False, True])
4905    @parametrize("requires_grad", [False, True])
4906    @parametrize("components_require_grad", [False, True])
4907    def test_op_dim_with_lengths_different_output_shape(
4908        self,
4909        device,
4910        dtype,
4911        keepdim,
4912        requires_grad,
4913        components_require_grad,
4914        func,
4915    ):
4916        """
4917        Operator on NestedTensor fails when trying to reduce a nested tensor with lengths,
4918        i.e. a nested tensor with holes, if reducing on the ragged dimension.
4919        This test is for operators which return an output tensor with different shape than the input tensor.
4920        """
4921        if get_op_name(func) == "mean" and not keepdim:
4922            return
4923
4924        reduce_dims = ((1,), (2,), (2, 3))
4925
4926        lengths = torch.randint(5, 10, (20,), device=device)
4927        offsets = torch.zeros((21,), device=device, dtype=torch.int)
4928        torch.cumsum(lengths, dim=0, out=offsets[1:])
4929
4930        values = torch.randn(
4931            (offsets[-1].item(), 20),
4932            device=device,
4933            dtype=dtype,
4934            requires_grad=requires_grad,
4935        )
4936
4937        nt_with_holes = torch.nested.nested_tensor_from_jagged(
4938            values,
4939            offsets,
4940            lengths=offsets.diff() - 2,  # arbitrary subtraction to create holes
4941        )
4942
4943        for reduce_dim in reduce_dims:
4944            if nt_with_holes.dim() > reduce_dim[-1]:
4945                if nt_with_holes._ragged_idx in reduce_dim:
4946                    with self.assertRaisesRegex(
4947                        RuntimeError,
4948                        "not supported where lengths is not None "
4949                        + "if reducing across the ragged dimension for NestedTensor",
4950                    ):
4951                        out = func(nt_with_holes, dim=reduce_dim, keepdim=keepdim)
4952                else:
4953                    out = func(nt_with_holes, dim=reduce_dim, keepdim=keepdim)
4954
4955    @dtypes(torch.float32)
4956    @parametrize("requires_grad", [False, True])
4957    @parametrize("components_require_grad", [False, True])
4958    def test_softmax_dim_with_lengths(
4959        self,
4960        device,
4961        dtype,
4962        requires_grad,
4963        components_require_grad,
4964    ):
4965        """
4966        Softmax on NestedTensor fails when trying to reduce a nested tensor with lengths,
4967        i.e. a nested tensor with holes, if reducing on the ragged dimension.
4968        """
4969        reduce_dims = (1, 2, 3)
4970
4971        lengths = torch.randint(5, 10, (20,), device=device)
4972        offsets = torch.zeros((21,), device=device, dtype=torch.int)
4973        torch.cumsum(lengths, dim=0, out=offsets[1:])
4974
4975        values = torch.randn(
4976            (offsets[-1].item(), 20),
4977            device=device,
4978            dtype=dtype,
4979            requires_grad=requires_grad,
4980        )
4981
4982        nt_with_holes = torch.nested.nested_tensor_from_jagged(
4983            values,
4984            offsets,
4985            lengths=offsets.diff() - 2,  # arbitrary subtraction to create holes
4986        )
4987
4988        for reduce_dim in reduce_dims:
4989            if nt_with_holes.dim() > reduce_dim:
4990                if nt_with_holes._ragged_idx == reduce_dim:
4991                    with self.assertRaisesRegex(
4992                        RuntimeError,
4993                        "not supported where lengths is not None "
4994                        + "if reducing across the ragged dimension for NestedTensor",
4995                    ):
4996                        out = torch.nn.functional.softmax(nt_with_holes, dim=reduce_dim)
4997                else:
4998                    out = torch.nn.functional.softmax(nt_with_holes, dim=reduce_dim)
4999
5000    @skipIfTorchDynamo(
5001        "ragged_size = nt_with_holes.shape[nt_with_holes._ragged_idx] does not currently work "
5002        + "with dynamo tests and throws this error: `AssertionError: SymInts must use SymNodeVariable. "
5003        + "If the underlying value is static, we will create a ConstantVariable and specialize.`"
5004    )
5005    @dtypes(torch.float32)
5006    @parametrize("requires_grad", [False, True])
5007    @parametrize("components_require_grad", [False, True])
5008    def test_layer_norm_with_lengths(
5009        self,
5010        device,
5011        dtype,
5012        requires_grad,
5013        components_require_grad,
5014    ):
5015        """
5016        Layer normalization on NestedTensor fails when trying to operate on a nested tensor with lengths,
5017        i.e. a nested tensor with holes, if operating on the ragged dimension.
5018        """
5019
5020        # create components for nested tensor
5021        lengths = torch.randint(5, 10, (20,), device=device)
5022        offsets = torch.zeros((21,), device=device, dtype=torch.int)
5023        torch.cumsum(lengths, dim=0, out=offsets[1:])
5024        values = torch.randn(
5025            (offsets[-1].item(), 10, 30),
5026            device=device,
5027            dtype=dtype,
5028            requires_grad=requires_grad,
5029        )
5030
5031        nt_with_holes = torch.nested.nested_tensor_from_jagged(
5032            values,
5033            offsets,
5034            lengths=offsets.diff() - 2,  # arbitrary subtraction to create holes
5035        )
5036
5037        ragged_size = nt_with_holes.shape[nt_with_holes._ragged_idx]
5038
5039        normalized_shapes = (
5040            (10, 30),  # normalization on non-ragged dimension passes
5041            (ragged_size, 10, 30),  # normalization on ragged dimension fails
5042        )
5043
5044        for normalized_shape in normalized_shapes:
5045            if ragged_size in normalized_shape:
5046                with self.assertRaisesRegex(
5047                    RuntimeError,
5048                    "not supported where lengths is not None if operating on the ragged dimension for NestedTensor",
5049                ):
5050                    out = torch.nn.functional.layer_norm(
5051                        nt_with_holes, normalized_shape=normalized_shape
5052                    )
5053            else:
5054                out = torch.nn.functional.layer_norm(
5055                    nt_with_holes, normalized_shape=normalized_shape
5056                )
5057
5058    @dtypes(torch.float32)
5059    @parametrize("keepdim", [True])
5060    @parametrize("requires_grad", [False, True])
5061    @parametrize("components_require_grad", [False, True])
5062    def test_mean_dim_reduce_multiple_dims(
5063        self,
5064        device,
5065        dtype,
5066        keepdim,
5067        requires_grad,
5068        components_require_grad,
5069    ):
5070        """
5071        Mean on NestedTensor fails when trying to reduce across multiple dimensions
5072        """
5073        tensor_lists = self._get_example_tensor_lists(
5074            include_list_of_lists=False, include_requires_grad=components_require_grad
5075        )
5076        reduce_dims = ((0, 1), (2, 3), (2, 3, 4))
5077
5078        for tensor_list, reduce_dim in itertools.product(tensor_lists, reduce_dims):
5079            nt = torch.nested.nested_tensor(
5080                tensor_list,
5081                device=device,
5082                dtype=dtype,
5083                layout=torch.jagged,
5084                requires_grad=requires_grad,
5085            )
5086
5087            if nt.dim() > reduce_dim[-1]:
5088                with self.assertRaisesRegex(
5089                    RuntimeError,
5090                    "not supported across multiple dimensions for NestedTensor",
5091                ):
5092                    out = torch.mean(nt, dim=reduce_dim, keepdim=keepdim)
5093
5094    @dtypes(torch.float32)
5095    @parametrize("keepdim", [False, True])
5096    @parametrize("requires_grad", [False, True])
5097    @parametrize("components_require_grad", [False, True])
5098    def test_mean_dim_keepdim_False(
5099        self,
5100        device,
5101        dtype,
5102        keepdim,
5103        requires_grad,
5104        components_require_grad,
5105    ):
5106        """
5107        Mean on NestedTensor fails when keepdim=False
5108        """
5109        tensor_lists = self._get_example_tensor_lists(
5110            include_list_of_lists=False, include_requires_grad=components_require_grad
5111        )
5112        reduce_dims = ((1,), (2,), (3,))
5113
5114        for tensor_list, reduce_dim in itertools.product(tensor_lists, reduce_dims):
5115            nt = torch.nested.nested_tensor(
5116                tensor_list,
5117                device=device,
5118                dtype=dtype,
5119                layout=torch.jagged,
5120                requires_grad=requires_grad,
5121            )
5122
5123            if nt.dim() > reduce_dim[-1]:
5124                if not keepdim:
5125                    with self.assertRaisesRegex(
5126                        RuntimeError,
5127                        "not supported when keepdim=False for NestedTensor",
5128                    ):
5129                        out = torch.mean(nt, dim=reduce_dim, keepdim=keepdim)
5130                else:
5131                    out = torch.mean(nt, dim=reduce_dim, keepdim=keepdim)
5132
5133    @dtypes(torch.float, torch.double, torch.half)
5134    @parametrize("requires_grad", [False, True])
5135    @parametrize("weights_only", [False, True])
5136    def test_serialization(self, device, dtype, requires_grad, weights_only):
5137        def compare_metadata(nt1, nt2):
5138            self.assertEqual(nt1._nested_tensor_size(), nt2._nested_tensor_size())
5139            self.assertEqual(nt1._nested_tensor_strides(), nt2._nested_tensor_strides())
5140            self.assertEqual(
5141                nt1._nested_tensor_storage_offsets(),
5142                nt2._nested_tensor_storage_offsets(),
5143            )
5144
5145        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7))
5146        for a in [nt_contiguous, nt_noncontiguous]:
5147            buffer = io.BytesIO()
            torch.save(a, buffer)
5149            buffer.seek(0)
5150            b = torch.load(buffer, weights_only=weights_only)
5151            # should be both conceptually equal and metadata equivalent
5152            self.assertEqual(a, b)
5153            compare_metadata(a, b)
5154            # should be conceptually equal but not necessarily metadata equivalent
5155            self.assertEqual(b, nt_contiguous)
5156            self.assertEqual(b, nt_noncontiguous)
5157
5158    @unittest.skipIf(
5159        PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property"
5160    )
5161    @onlyCUDA
5162    def test_pin_memory(self, device):
5163        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7))
5164        for nt in [nt_contiguous, nt_noncontiguous]:
5165            self.assertFalse(nt.is_pinned())
5166            pinned = nt.pin_memory(device)
5167            self.assertTrue(pinned.is_pinned())
5168            self.assertEqual(nt, pinned)
5169            self.assertNotEqual(nt.data_ptr(), pinned.data_ptr())
5170            # test that pin_memory on already pinned tensor has no effect
5171            self.assertIs(pinned, pinned.pin_memory())
5172            self.assertEqual(pinned.data_ptr(), pinned.pin_memory().data_ptr())
5173
5174    @torch.compiler.disable
5175    def _validate_nt(
5176        self,
5177        nt,
5178        device,
5179        dtype,
5180        layout,
5181        requires_grad,
5182        dim,
5183        batch_size,
5184        contiguous,
5185        cached_min_seqlen=None,
5186        cached_max_seqlen=None,
5187        base=None,
5188        ref_nt=None,
5189    ):
5190        # Validate a bunch of properties after NT construction.
5191        device = torch.device(device)
5192        self.assertEqual(nt.dim(), dim)
5193        self.assertEqual(nt.device, device)
5194        self.assertEqual(nt.dtype, dtype)
5195        self.assertEqual(nt.layout, layout)
5196        self.assertEqual(nt.requires_grad, requires_grad)
5197        self.assertEqual(nt.is_contiguous(), contiguous)
5198
5199        if layout == torch.jagged:
5200            self.assertEqual(nt._values.device, device)
5201            self.assertEqual(nt._offsets.device, device)
5202            self.assertEqual(nt.shape[0], batch_size)
5203            self.assertTrue(isinstance(nt.shape[1], torch.SymInt))
5204
5205            if base is not None:
5206                self.assertTrue(nt._is_view() and nt._base is base)
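                # replaying the view func on a fresh base should preserve whichever
                # seqlen entries were cached on the original NT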
5207                replay_cache = nt._view_func(torch.randn_like(nt._base))._metadata_cache
5208                self.assertEqual(
5209                    "min_seqlen" in replay_cache, cached_min_seqlen is not None
5210                )
5211                self.assertEqual(
5212                    "max_seqlen" in replay_cache, cached_max_seqlen is not None
5213                )
5214
5215            self.assertEqual(
5216                "min_seqlen" in nt._metadata_cache, cached_min_seqlen is not None
5217            )
5218            self.assertEqual(
5219                "max_seqlen" in nt._metadata_cache, cached_max_seqlen is not None
5220            )
5221
5222            if cached_min_seqlen is not None:
5223                self.assertEqual(nt._min_seqlen, cached_min_seqlen)
5224
5225            if cached_max_seqlen is not None:
5226                self.assertEqual(nt._max_seqlen, cached_max_seqlen)
5227
5228        if ref_nt is not None:
5229            self.assertEqual(nt.size(0), ref_nt.size(0))
5230            for n1, n2 in zip(nt.unbind(), ref_nt.unbind()):
5231                self.assertEqual(n1, n2)
5232
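    # Editorial sketch (not part of the original test logic): the cached
    # min / max seqlen checked by _validate_nt are derivable from the offsets,
    # e.g. offsets == [0, 2, 4, 6, 10] gives per-component lengths
    # offsets.diff() == [2, 2, 2, 4], hence min_seqlen == 2 and max_seqlen == 4.
    # The cache entries are only expected when the construction path populates them.
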
5233    @dtypes(torch.float, torch.double, torch.half)
5234    @parametrize("requires_grad", [False, True])
5235    @parametrize("components_require_grad", [False, True])
5236    def test_jagged_layout_construction_nested_tensor(
5237        self, device, dtype, requires_grad, components_require_grad
5238    ):
5239        for tensor_list in self._get_example_tensor_lists(
5240            include_list_of_lists=True, include_requires_grad=components_require_grad
5241        ):
5242            nt = torch.nested.nested_tensor(
5243                tensor_list,
5244                device=device,
5245                dtype=dtype,
5246                layout=torch.jagged,
5247                requires_grad=requires_grad,
5248            )
5249
5250            expected_dim = torch.as_tensor(tensor_list[0]).dim() + 1
5251            expected_batch_size = len(tensor_list)
5252            expected_contiguous = True
5253            expected_min_seqlen = min(
5254                (torch.tensor(t) if isinstance(t, list) else t).shape[0]
5255                for t in tensor_list
5256            )
5257            expected_max_seqlen = max(
5258                (torch.tensor(t) if isinstance(t, list) else t).shape[0]
5259                for t in tensor_list
5260            )
5261            self._validate_nt(
5262                nt,
5263                device,
5264                dtype,
5265                torch.jagged,
5266                requires_grad,
5267                expected_dim,
5268                expected_batch_size,
5269                expected_contiguous,
5270                expected_min_seqlen,
5271                expected_max_seqlen,
5272            )
5273
5274            # Make sure grads -don't- flow back into original tensors for nested_tensor()
5275            if requires_grad:
5276                (nt * 2).backward(torch.ones_like(nt))
5277            for t in tensor_list:
5278                t = t if isinstance(t, torch.Tensor) else torch.as_tensor(t)
5279                self.assertTrue(t.grad is None)
5280
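    # Minimal sketch (editorial addition, not invoked by the tests above; names
    # are local to this sketch): constructing a jagged NT from components of
    # shape (L_i, 3) is expected to yield dim == component dim + 1 and
    # batch_size == len(components), matching the expected_dim /
    # expected_batch_size bookkeeping used above.
    def _sketch_jagged_construction(self, device="cpu"):
        components = [torch.randn(L, 3, device=device) for L in (2, 5, 3)]
        nt = torch.nested.nested_tensor(components, layout=torch.jagged)
        return nt.dim(), nt.size(0)  # expected: (3, 3)
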
5281    @dtypes(torch.float, torch.double, torch.half)
5282    @parametrize("components_require_grad", [False, True])
5283    def test_jagged_layout_construction_as_nested_tensor(
5284        self, device, dtype, components_require_grad
5285    ):
5286        # NB: as_nested_tensor(tensor_list) doesn't support lists of lists for tensor_list
5287        for tensor_list in self._get_example_tensor_lists(
5288            include_list_of_lists=False, include_requires_grad=components_require_grad
5289        ):
5290            nt = torch.nested.as_nested_tensor(
5291                tensor_list, device=device, dtype=dtype, layout=torch.jagged
5292            )
5293
5294            # nt.requires_grad should be True if at least one component requires grad
5295            expected_dim = tensor_list[0].dim() + 1
5296            expected_batch_size = len(tensor_list)
5297            expected_contiguous = True
5298            expected_min_seqlen = min(
5299                (torch.tensor(t) if isinstance(t, list) else t).shape[0]
5300                for t in tensor_list
5301            )
5302            expected_max_seqlen = max(
5303                (torch.tensor(t) if isinstance(t, list) else t).shape[0]
5304                for t in tensor_list
5305            )
5306            self._validate_nt(
5307                nt,
5308                device,
5309                dtype,
5310                torch.jagged,
5311                components_require_grad,
5312                expected_dim,
5313                expected_batch_size,
5314                expected_contiguous,
5315                expected_min_seqlen,
5316                expected_max_seqlen,
5317            )
5318
5319            # Make sure grads flow back into original tensors for as_nested_tensor()
5320            if components_require_grad:
5321                (nt * 2).backward(torch.ones_like(nt))
5322                for t in tensor_list:
5323                    if t.requires_grad:
5324                        self.assertEqual(t.grad, torch.ones_like(t) * 2)
5325                    else:
5326                        self.assertTrue(t.grad is None)
5327
5328    @xfailIfTorchDynamo
5329    @unittest.skipIf(
5330        PYTORCH_CUDA_MEMCHECK, "is_pinned relies on an expected failure to detect the pointer property"
5331    )
5332    @onlyCUDA
5333    def test_jagged_layout_construction_with_pinned_memory(self, device):
5334        for tensor_list in self._get_example_tensor_lists():
5335            nt = torch.nested.nested_tensor(
5336                tensor_list, layout=torch.jagged, device="cpu", pin_memory=True
5337            )
5338
5339            expected_dim = torch.as_tensor(tensor_list[0]).dim() + 1
5340            expected_batch_size = len(tensor_list)
5341            expected_min_seqlen = min(
5342                (torch.tensor(t) if isinstance(t, list) else t).shape[0]
5343                for t in tensor_list
5344            )
5345            expected_max_seqlen = max(
5346                (torch.tensor(t) if isinstance(t, list) else t).shape[0]
5347                for t in tensor_list
5348            )
5349            self._validate_nt(
5350                nt,
5351                device="cpu",
5352                dtype=torch.float32,
5353                layout=torch.jagged,
5354                requires_grad=False,
5355                dim=expected_dim,
5356                batch_size=expected_batch_size,
5357                contiguous=True,
5358                cached_min_seqlen=expected_min_seqlen,
5359                cached_max_seqlen=expected_max_seqlen,
5360            )
5361            self.assertTrue(nt.is_pinned())
5362
5363    @dtypes(torch.float, torch.double, torch.half)
5364    @parametrize("requires_grad", [False, True])
5365    @parametrize("values_is_view", [False, True])
5366    def test_jagged_view_from_values_offsets(
5367        self, device, dtype, requires_grad, values_is_view
5368    ):
5369        if values_is_view:
5370            # make values a view of base
5371            base = torch.randn(
5372                2, 3, 4, 5, 6, device=device, dtype=dtype, requires_grad=requires_grad
5373            )
5374            values = base.flatten(0, -2)
5375        else:
5376            values = torch.randn(
5377                10, 5, device=device, dtype=dtype, requires_grad=requires_grad
5378            )
5379        offsets = torch.tensor([0, 2, 4, 6, 10], device=device, dtype=torch.int64)
5380
5381        nt = nested_view_from_values_offsets(values, offsets)
5382
5383        expected_dim = values.dim() + 1
5384        expected_batch_size = offsets.shape[0] - 1
5385        expected_base = base if values_is_view else values
5386        lengths = offsets.diff()
5387        self._validate_nt(
5388            nt,
5389            device,
5390            dtype,
5391            torch.jagged,
5392            requires_grad,
5393            expected_dim,
5394            expected_batch_size,
5395            # ensure NT is a proper view
5396            base=expected_base,
5397            contiguous=True,
5398            # if no min / max are passed, expect the metadata cache to be empty
5399            cached_min_seqlen=None,
5400            cached_max_seqlen=None,
5401        )
5402
5403        if requires_grad:
5404            # Make sure grads flow back
5405            (nt * 2).backward(torch.ones_like(nt))
5406
5407            @torch.compiler.disable
5408            def _check_grad(t):
5409                self.assertTrue(t.grad is not None)
5410                self.assertEqual(t.grad, torch.ones_like(t) * 2)
5411
5412            _check_grad(base if values_is_view else values)
5413
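    # Minimal sketch (editorial addition, not called by the tests; assumes the
    # view semantics exercised above): because nested_view_from_values_offsets()
    # returns a view of `values`, an in-place write to the base buffer should be
    # observable through the NT's components.
    def _sketch_values_offsets_view(self, device="cpu"):
        values = torch.zeros(10, 5, device=device)
        offsets = torch.tensor([0, 2, 4, 6, 10], device=device)
        nt = nested_view_from_values_offsets(values, offsets)
        values.fill_(1.0)
        return all(bool((comp == 1.0).all()) for comp in nt.unbind())
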
5414    @dtypes(torch.float)
5415    @parametrize("pass_min_max", [False, True])
5416    def test_nested_tensor_from_jagged(self, device, dtype, pass_min_max):
5417        # === construct from (values, offsets) ===
5418        values = torch.randn(10, 5, device=device, dtype=dtype)
5419        offsets = torch.tensor([0, 2, 4, 6, 10], device=device, dtype=torch.int64)
5420
5421        # compute min / max seqlen
5422        lengths = offsets.diff()
5423        min_seqlen = lengths.min().item()
5424        max_seqlen = lengths.max().item()
5425
5426        if pass_min_max:
5427            nt = torch.nested.nested_tensor_from_jagged(
5428                values, offsets=offsets, min_seqlen=min_seqlen, max_seqlen=max_seqlen
5429            )
5430        else:
5431            nt = torch.nested.nested_tensor_from_jagged(values, offsets=offsets)
5432        self._validate_nt(
5433            nt,
5434            device,
5435            dtype,
5436            torch.jagged,
5437            requires_grad=False,
5438            dim=3,
5439            batch_size=4,
5440            contiguous=True,
5441            cached_min_seqlen=(min_seqlen if pass_min_max else None),
5442            cached_max_seqlen=(max_seqlen if pass_min_max else None),
5443            base=values,
5444        )
5445
5446        # === construct from (values, offsets, lengths) ===
5447        lengths = torch.tensor([2, 1, 1, 2], device=device)
5448
5449        # compute min / max seqlen
5450        min_seqlen = lengths.min().item()
5451        max_seqlen = lengths.max().item()
5452
5453        if pass_min_max:
5454            nt = torch.nested.nested_tensor_from_jagged(
5455                values,
5456                offsets=offsets,
5457                lengths=lengths,
5458                min_seqlen=min_seqlen,
5459                max_seqlen=max_seqlen,
5460            )
5461        else:
5462            nt = torch.nested.nested_tensor_from_jagged(
5463                values, offsets=offsets, lengths=lengths
5464            )
5465
5466        # when both offsets / lengths are specified, expect non-contiguous
5467        self._validate_nt(
5468            nt,
5469            device,
5470            dtype,
5471            torch.jagged,
5472            requires_grad=False,
5473            dim=3,
5474            batch_size=4,
5475            contiguous=False,
5476            cached_min_seqlen=(min_seqlen if pass_min_max else None),
5477            cached_max_seqlen=(max_seqlen if pass_min_max else None),
5478            base=values,
5479        )
5480        self.assertIs(nt.lengths(), lengths)
5481
5482        # === construct from (values, lengths) ===
5483        values = torch.randn(14, 5, device=device, dtype=dtype)
5484        lengths = torch.tensor([2, 3, 4, 5], device=device)
5485
5486        # compute min / max seqlen
5487        min_seqlen = lengths.min().item()
5488        max_seqlen = lengths.max().item()
5489
5490        if pass_min_max:
5491            nt = torch.nested.nested_tensor_from_jagged(
5492                values, lengths=lengths, min_seqlen=min_seqlen, max_seqlen=max_seqlen
5493            )
5494        else:
5495            nt = torch.nested.nested_tensor_from_jagged(values, lengths=lengths)
5496
5497        # for now, if only lengths is specified, it is converted to offsets so the
5498        # result integrates best with the existing kernels
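        # (editorial note: for lengths == [2, 3, 4, 5] this is the exclusive
        # cumulative sum [0, 2, 5, 9, 14], i.e. the expected_offsets built below)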
5499        expected_offsets = torch.tensor([0, 2, 5, 9, 14], device=device)
5500        expected_nt = torch.nested.nested_tensor_from_jagged(
5501            values, offsets=expected_offsets
5502        )
5503        self._validate_nt(
5504            nt,
5505            device,
5506            dtype,
5507            torch.jagged,
5508            requires_grad=False,
5509            dim=3,
5510            batch_size=4,
5511            contiguous=True,
5512            cached_min_seqlen=(min_seqlen if pass_min_max else None),
5513            cached_max_seqlen=(max_seqlen if pass_min_max else None),
5514            base=values,
5515            ref_nt=expected_nt,
5516        )
5517
5518        # error case: no offsets or lengths
5519        with self.assertRaisesRegex(
5520            RuntimeError, "At least one of offsets or lengths is required"
5521        ):
5522            torch.nested.nested_tensor_from_jagged(values, offsets=None, lengths=None)
5523
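    # Minimal sketch (editorial addition, not called by the tests; helper name
    # is local to this sketch): one plausible way to compute the offsets that
    # the lengths-only construction path above is expected to produce.
    def _sketch_lengths_to_offsets(self, lengths):
        # exclusive cumulative sum, e.g. [2, 3, 4, 5] -> [0, 2, 5, 9, 14]
        return F.pad(lengths.cumsum(0), (1, 0))
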
5524    @onlyCPU
5525    def test_nested_tensor_from_jagged_fx_trace(self, device):
5526        def fn(x, y):
5527            return torch.nested.nested_tensor_from_jagged(x, y)
5528
5529        def user_unwrapped(x, y):
5530            return fn(x, y)
5531
5532        with self.assertRaisesRegex(
5533            RuntimeError,
5534            "torch.nested.nested_tensor_from_jagged does not support tracing with fx.symbolic_trace",
5535        ):
5536            torch.fx.symbolic_trace(user_unwrapped)
5537
5538    @dtypes(torch.float, torch.double, torch.half)
5539    @parametrize("dim", range(5))
5540    @parametrize(
5541        "layout",
5542        [torch.strided, torch.jagged],
5543        name_fn=lambda l: f"layout_{str(l).split('.')[1]}",
5544    )
5545    @parametrize("requires_grad", [False, True])
5546    @parametrize("contiguous", [False, True])
5547    def test_as_nested_tensor_from_tensor(
5548        self, device, dtype, dim, layout, requires_grad, contiguous
5549    ):
5550        if dim == 0:
5551            t = torch.tensor(3.0, requires_grad=requires_grad)
5552        else:
5553            t = torch.randn(*(3 for _ in range(dim)), requires_grad=requires_grad)
5554        assert t.dim() == dim
5555
5556        if dim < 2:
5557            # 0- and 1-dim tensors can't be converted to NTs
5558            with self.assertRaisesRegex(
5559                RuntimeError, "Expected tensor argument to have dim"
5560            ):
5561                nt = torch.nested.as_nested_tensor(
5562                    t, device=device, dtype=dtype, layout=layout
5563                )
5564            return
5565
5566        orig_t = t
5567        if not contiguous:
5568            t = t.transpose(0, 1)
5569
5570        nt = torch.nested.as_nested_tensor(t, device=device, dtype=dtype, layout=layout)
5571        expected_dim = t.dim()
5572        expected_batch_size = t.size(0)
5573        expected_seqlen = t.size(1) if layout == torch.jagged else None
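        # (every component of a dense rectangular input has the same length
        # t.size(1), so both cached seqlen entries are expected to hold it)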
5574        self._validate_nt(
5575            nt,
5576            device,
5577            dtype,
5578            layout,
5579            requires_grad=requires_grad,
5580            dim=dim,
5581            batch_size=expected_batch_size,
5582            contiguous=True,
5583            cached_min_seqlen=expected_seqlen,
5584            cached_max_seqlen=expected_seqlen,
5585        )
5586
5587        if torch.device(device) == t.device and dtype == t.dtype and contiguous:
5588            # should be the non-copying (view) case
5589            self.assertTrue(nt._is_view() and nt._base is t)
5590
5591        # should have equivalent components to construction from unbound tensor list
5592        nt_from_unbind = torch.nested.as_nested_tensor(
5593            list(t.unbind(0)), device=device, dtype=dtype, layout=layout
5594        )
5595        self.assertEqualIgnoringNestedInts(nt, nt_from_unbind)
5596
5597        # ensure call on a NT with the same properties returns the NT directly
5598        nt2 = torch.nested.as_nested_tensor(
5599            nt, device=device, dtype=dtype, layout=layout
5600        )
5601        self.assertTrue(nt is nt2)
5602
5603        # ensure call with device=None uses input tensor device
5604        nt3 = torch.nested.as_nested_tensor(
5605            t.to(device=device, dtype=dtype),
5606            device=None,
5607            dtype=None,
5608            layout=layout,
5609        )
5610        self._validate_nt(
5611            nt3,
5612            device,
5613            dtype,
5614            layout,
5615            requires_grad=requires_grad,
5616            dim=dim,
5617            batch_size=expected_batch_size,
5618            contiguous=True,
5619            cached_min_seqlen=expected_seqlen,
5620            cached_max_seqlen=expected_seqlen,
5621        )
5622
5623        # we don't support conversion between layouts this way atm
5624        other_layout = torch.strided if layout == torch.jagged else torch.jagged
5625        with self.assertRaisesRegex(
5626            RuntimeError, "Converting between nested tensor layouts is not supported"
5627        ):
5628            torch.nested.as_nested_tensor(
5629                nt, device=device, dtype=dtype, layout=other_layout
5630            )
5631
5632        if requires_grad:
5633            # make sure gradients flow back into inputs
5634            (nt * 2).backward(torch.ones_like(nt))
5635            self.assertEqual(orig_t.grad, torch.ones_like(orig_t) * 2)
5636
5637    @dtypes(torch.double, torch.half)
5638    @onlyCUDA
5639    def test_device_dtype_transfer_updates_offsets(self, device, dtype):
5640        for tensor_list in self._get_example_tensor_lists():
5641            orig_device = torch.device("cpu")
5642            orig_dtype = torch.float32
5643            nt = torch.nested.nested_tensor(
5644                tensor_list, layout=torch.jagged, device=orig_device, dtype=orig_dtype
5645            )
5646
5647            self.assertEqual(torch.int64, nt.offsets().dtype)
5648            nt = nt.to(device=device).to(dtype=dtype)
5649
5650            # offsets should still be int64 on the new device
5651            self.assertEqual(nt.values().device, nt.offsets().device)
5652            self.assertEqual(torch.int64, nt.offsets().dtype)
5653
5654    def test_unbind(self, device):
5655        for tensor_list in self._get_example_tensor_lists():
5656            nt = torch.nested.nested_tensor(
5657                tensor_list, layout=torch.jagged, device=device
5658            )  # ragged_idx = 1
5659            out = nt.unbind()
5660            self.assertEqual(len(out), len(tensor_list))
5661            for i, t in enumerate(out):
5662                self.assertEqual(t, tensor_list[i])
5663
5664    @parametrize("ragged_idx", [2, 3])
5665    def test_unbind_transpose(self, device, ragged_idx):
5666        for tensor_list in self._get_example_tensor_lists():
5667            nt = torch.nested.nested_tensor(
5668                tensor_list, layout=torch.jagged, device=device
5669            )
5670            if ragged_idx < nt.dim():
5671                nt = nt.transpose(1, ragged_idx)  # set ragged_idx
5672                out = nt.unbind()
5673                self.assertEqual(len(out), len(tensor_list))
5674                for i, t in enumerate(out):
5675                    self.assertEqual(
5676                        t.transpose(0, ragged_idx - 1), tensor_list[i]
5677                    )  # transpose back each element of result
5678
5679    def test_unbind_transpose_ragged_idx_last_dim(self, device):
5680        for tensor_list in self._get_example_tensor_lists():
5681            nt = torch.nested.nested_tensor(
5682                tensor_list, layout=torch.jagged, device=device
5683            ).transpose(1, -1)  # set ragged_idx = last dimension
5684            out = nt.unbind()
5685            self.assertEqual(len(out), len(tensor_list))
5686            for i, t in enumerate(out):
5687                self.assertEqual(
5688                    t.transpose(0, -1), tensor_list[i]
5689                )  # transpose back each element of result
5690
5691    def test_unbind_lengths(self, device):
5692        values = torch.randn(16, 128, device=device)
5693        offsets = torch.tensor([0, 8, 12, 13, 16], device=device)
5694        lengths = torch.tensor([6, 2, 1, 2], device=device)
5695        nt = torch.nested.nested_tensor_from_jagged(
5696            values, offsets=offsets, lengths=lengths
5697        )  # 3D nested tensor
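        # For these offsets/lengths the unbound components are values[0:6],
        # values[8:10], values[12:13] and values[13:15]; rows 6-7, 10-11 and 15
        # of `values` belong to no component (the holes that make this layout
        # non-contiguous, as exercised in test_nested_tensor_from_jagged above).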
5698
5699        tensor_list = []
5700        for i in range(offsets.shape[0] - 1):
5701            tensor_list.append(values[offsets[i] : (offsets[i] + lengths[i])])
5702
5703        out = nt.unbind()
5704        self.assertEqual(len(out), len(tensor_list))
5705        for i, t in enumerate(out):
5706            self.assertEqual(t, tensor_list[i])
5707
5708    def test_unbind_lengths_ragged_idx_1(self, device):
5709        values = torch.randn(16, 8, 128, device=device)
5710        offsets = torch.tensor([0, 8, 12, 13, 16], device=device)
5711        lengths = torch.tensor([6, 2, 1, 2], device=device)
5712        ragged_idx = 1
5713        nt = torch.nested._internal.nested_tensor.NestedTensor(
5714            values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx
5715        )  # 4D nested tensor
5716
5717        tensor_list = []
5718        for i in range(offsets.shape[0] - 1):
5719            tensor_list.append(values[offsets[i] : (offsets[i] + lengths[i]), :, :])
5720
5721        out = nt.unbind()
5722
5723        self.assertEqual(len(out), len(tensor_list))
5724        for i, t in enumerate(out):
5725            self.assertEqual(t, tensor_list[i])
5726
5727    def test_unbind_lengths_ragged_idx_equals_2_bad_dim(self, device):
5728        values = torch.randn(16, 8, 128, device=device)
5729        offsets = torch.tensor([0, 8, 12, 13, 16], device=device)
5730        lengths = torch.tensor([6, 2, 1, 2], device=device)
5731        ragged_idx = 2
5732        nt = torch.nested._internal.nested_tensor.NestedTensor(
5733            values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx
5734        )  # 4D nested tensor
5735
5736        self.assertRaisesRegex(
5737            RuntimeError,
5738            r"unbind\(\): nested tensor offsets and lengths.*",
5739            lambda: nt.unbind(),
5740        )
5741
5742    def test_unbind_lengths_ragged_idx_2(self, device):
5743        values = torch.randn(16, 8, 128, device=device)
5744        offsets = torch.tensor([0, 2, 4, 8], device=device)
5745        lengths = torch.tensor([2, 1, 3], device=device)
5746        ragged_idx = 2
5747        nt = torch.nested._internal.nested_tensor.NestedTensor(
5748            values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx
5749        )  # 4D nested tensor
5750
5751        tensor_list = []
5752        for i in range(offsets.shape[0] - 1):
5753            tensor_list.append(values[:, offsets[i] : (offsets[i] + lengths[i]), :])
5754
5755        out = nt.unbind()
5756
5757        self.assertEqual(len(out), len(tensor_list))
5758        for i, t in enumerate(out):
5759            self.assertEqual(t, tensor_list[i])
5760
5761    def test_unbind_lengths_ragged_idx_3(self, device):
5762        values = torch.randn(16, 8, 128, device=device)
5763        offsets = torch.tensor([0, 100, 128], device=device)
5764        lengths = torch.tensor([50, 28], device=device)
5765        ragged_idx = 3
5766        nt = torch.nested._internal.nested_tensor.NestedTensor(
5767            values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx
5768        )  # 4D nested tensor
5769
5770        tensor_list = []
5771        for i in range(offsets.shape[0] - 1):
5772            tensor_list.append(values[:, :, offsets[i] : (offsets[i] + lengths[i])])
5773
5774        out = nt.unbind()
5775
5776        self.assertEqual(len(out), len(tensor_list))
5777        for i, t in enumerate(out):
5778            self.assertEqual(t, tensor_list[i])
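    # Editorial note on the ragged_idx tests above (a reading of the behavior
    # they exercise, not a spec): _ragged_idx names which NT dimension is
    # ragged, and unbind() slices the dense values buffer along dimension
    # (ragged_idx - 1), e.g. dim 1 of `values` for ragged_idx == 2 and dim 2
    # for ragged_idx == 3.
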
5779
5780    @skipIfTorchDynamo(
5781        "TorchDynamo raises an error for ragged_idx == 0 earlier than eager PyTorch does"
5782    )
5783    def test_unbind_lengths_ragged_idx_0(self, device):
5784        values = torch.randn(16, 8, 128, device=device)
5785        offsets = torch.tensor([0, 100, 128], device=device)
5786        lengths = torch.tensor([50, 28], device=device)
5787        ragged_idx = 0
5788        nt = torch.nested._internal.nested_tensor.NestedTensor(
5789            values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx
5790        )  # 4D nested tensor
5791
5792        tensor_list = []
5793        for i in range(offsets.shape[0] - 1):
5794            tensor_list.append(values[:, :, offsets[i] : (offsets[i] + lengths[i])])
5795
5796        self.assertRaisesRegex(
5797            RuntimeError,
5798            r"unbind\(\): nested tensor.*out of bounds",
5799            lambda: nt.unbind(),
5800        )
5801
5802    def test_narrow(self, device):
5803        starts = torch.tensor([0, 1, 2, 3, 4], device=device, dtype=torch.int64)
5804        lengths = torch.tensor([3, 2, 2, 1, 5], device=device, dtype=torch.int64)
5805        buffer = (
5806            torch.arange(0, 10, device=device, dtype=torch.int64)
5807            .unsqueeze(0)
5808            .expand(5, -1)
5809            .clone()
5810            .detach()
5811        )
5812        nt = torch.nested.narrow(buffer, 1, starts, lengths, layout=torch.jagged)
5813
5814        self.assertTrue(nt._is_view() and nt._base is buffer)
5815
5816        # TODO: Use this approach when unbind is functional
5817        # unbound_nt = nt.unbind()
5818        # for i in range(starts.shape[0]):
5819        #     self.assertEqual(torch.arange(starts[i], starts[i] + lengths[i], device=device, dtype=torch.int64), unbound_nt[i])
5820        for i in range(starts.shape[0]):
5821            self.assertEqual(
5822                torch.arange(
5823                    starts[i], starts[i] + lengths[i], device=device, dtype=torch.int64
5824                ),
5825                nt.values()[nt.offsets()[i] : (nt.offsets()[i] + nt.lengths()[i])],
5826            )
5827
5828    def test_njt_cat(self, device):
5829        offsets = torch.tensor([0, 2, 3], device=device, dtype=torch.int64)
5830        values_1 = torch.randn(
5831            3, 2, dtype=torch.float64, device=device, requires_grad=True
5832        )
5833        values_2 = torch.randn(
5834            3, 4, dtype=torch.float64, device=device, requires_grad=True
5835        )
5836
5837        def grad_test_func(values_1, values_2, offsets):
5838            nt_1 = torch.nested.nested_tensor_from_jagged(values_1, offsets)
5839            nt_2 = torch.nested.nested_tensor_from_jagged(values_2, offsets)
5840            nt_3 = torch.cat([nt_1, nt_2], dim=-1)
5841            return nt_3.values()
5842
5843        assert gradcheck(
5844            grad_test_func,
5845            inputs=(values_1, values_2, offsets),
5846            check_batched_grad=False,
5847        )
5848
5849    def test_is_contiguous(self, device):
5850        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
5851        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
5852        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
5853        nt_contiguous = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
5854
5855        starts_nc = torch.tensor([0, 1, 2, 3, 4], device=device, dtype=torch.int64)
5856        lengths_nc = torch.tensor([3, 2, 2, 1, 5], device=device, dtype=torch.int64)
5857        narrow_base = (
5858            torch.arange(0, 10, device=device, dtype=torch.int64)
5859            .unsqueeze(0)
5860            .expand(5, -1)
5861            .clone()
5862        )
5863        nt_noncontiguous = torch.nested.narrow(
5864            narrow_base, 1, starts_nc, lengths_nc, layout=torch.jagged
5865        )
5866
5867        starts_c = torch.tensor([1, 0, 0, 0, 0], device=device, dtype=torch.int64)
5868        lengths_c = torch.tensor([9, 10, 10, 10, 8], device=device, dtype=torch.int64)
5869        nt_contiguous_narrow = torch.nested.narrow(
5870            narrow_base, 1, starts_c, lengths_c, layout=torch.jagged
5871        )
5872
5873        # Test contiguous case
5874        assert nt_contiguous.is_contiguous()
5875
5876        # Test narrow case
5877        assert not nt_noncontiguous.is_contiguous()
5878        assert nt_contiguous_narrow.is_contiguous()
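        # (nt_contiguous_narrow is presumably contiguous because its selected
        # regions abut in the flat 5x10 buffer: flat indices 1-9, 10-19, 20-29,
        # 30-39 and 40-47 form one unbroken run, unlike the gappy selection
        # behind nt_noncontiguous)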
5879
5880        # Test querying by memory_format
5881        self.assertTrue(
5882            nt_contiguous.is_contiguous(memory_format=torch.contiguous_format)
5883        )
5884        self.assertTrue(
5885            not nt_noncontiguous.is_contiguous(memory_format=torch.contiguous_format)
5886        )
5887        self.assertTrue(
5888            nt_contiguous_narrow.is_contiguous(memory_format=torch.contiguous_format)
5889        )
5890
5891    def test_layout_under_torch_dispatch_mode(self):
5892        from torch.testing._internal.logging_tensor import (
5893            capture_logs_with_logging_tensor_mode,
5894        )
5895
5896        nt = random_nt_from_dims(
5897            [2, None, 3], torch.device("cpu"), torch.float32, layout=torch.jagged
5898        )
5899
5900        with capture_logs_with_logging_tensor_mode():
5901            self.assertEqual(nt.layout, torch.jagged)
5902
5903    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
5904    @parametrize(
5905        "func", [torch.empty_like, torch.randn_like], name_fn=lambda f: f.__name__
5906    )
5907    def test_like_shape(self, func):
5908        nt = random_nt_from_dims(
5909            [2, None, 3], torch.device("cpu"), torch.float32, layout=torch.jagged
5910        )
5911        nt_like = func(nt)
5912
5913        for nt_ub in nt_like.unbind():
5914            t_like = func(nt_ub)
5915            self.assertEqual(nt_ub.shape, t_like.shape)
5916
5917    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
5918    @parametrize(
5919        "func", [torch.ones_like, torch.zeros_like], name_fn=lambda f: f.__name__
5920    )
5921    def test_like_value(self, func):
5922        nt = random_nt_from_dims(
5923            [2, None, 3], torch.device("cpu"), torch.float32, layout=torch.jagged
5924        )
5925        nt_like = func(nt)
5926
5927        for nt_ub in nt_like.unbind():
5928            t_like = func(nt_ub)
5929            self.assertEqual(nt_ub, t_like)
5930
5931    def test_noncontiguous_pointwise(self, device):
5932        a = torch.randn(2, 3, 4, requires_grad=True, dtype=torch.float64, device=device)
5933        b = torch.randn(3, 3, 4, requires_grad=True, dtype=torch.float64, device=device)
5934        c = torch.randn(4, 3, 4, requires_grad=True, dtype=torch.float64, device=device)
5935        nt = torch.nested.nested_tensor([a, b, c], layout=torch.jagged)
5936        # transpose ragged dim
5937        transposed = nt.transpose(1, 2)
5938        self.assertFalse(transposed.is_contiguous())
5939        clone = transposed.clone()
5940
5941        def check_nt_equality(x, y):
5942            self.assertEqual(x.values(), y.values())
5943            self.assertEqual(x.offsets(), y.offsets())
5944            self.assertEqual(x._ragged_idx, y._ragged_idx)
5945            self.assertEqual(x.shape, y.shape)
5946
5947        self.assertFalse(clone.is_contiguous())
5948        check_nt_equality(clone, transposed)
5949
5950        clone_contig = transposed.clone(memory_format=torch.contiguous_format)
5951        self.assertTrue(clone_contig.is_contiguous())
5952        check_nt_equality(clone_contig, transposed)
5953
5954        detached = transposed.detach()
5955        self.assertFalse(clone.is_contiguous())
5956        check_nt_equality(detached, transposed)
5957
5958    def test_permute(self, device):
5959        nt = random_nt_from_dims(
5960            [2, None, 3, 5], device, torch.float32, layout=torch.jagged
5961        )
5962        nt_shape = nt.shape
5963        nt_inner_shape = nt.values().shape
5964        with self.assertRaisesRegex(
5965            ValueError,
5966            r"permute\(\): number of dimensions in the tensor input \(4\) "
5967            + r"does not match the length of the desired ordering of dimensions \(3\).",
5968        ):
5969            nt.permute(0, 2, 1)
5970        with self.assertRaisesRegex(
5971            ValueError, r"permute\(\): duplicate dims are not allowed."
5972        ):
5973            nt.permute(0, 2, -2, 3)
5974        with self.assertRaisesRegex(
5975            ValueError, "Permute is not supported on the batch dimension for jagged NT"
5976        ):
5977            nt.permute(1, 0, 2, 3)
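        # Expected effect of permute(0, 2, 1, -1) on shape (2, j1, 3, 5): the NT
        # shape becomes (2, 3, j1, 5), the values buffer goes from (sum_j, 3, 5)
        # to (3, sum_j, 5), and the ragged dim moves to index 2 (checked below).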
5978        nt_permute = nt.permute(0, 2, 1, -1)
5979        self.assertEqual(
5980            nt_permute.shape, (nt_shape[0], nt_shape[2], nt_shape[1], nt_shape[3])
5981        )
5982        self.assertEqual(
5983            nt_permute.values().shape,
5984            (nt_inner_shape[1], nt_inner_shape[0], nt_inner_shape[2]),
5985        )
5986        self.assertEqual(nt_permute._ragged_idx, 2)
5987        self.assertEqual(nt_permute.permute(0, 2, 1, 3), nt)
5988
5989    def test_to_dtype(self, device):
5990        nt = random_nt_from_dims(
5991            [2, None, 3], device, torch.float32, layout=torch.jagged
5992        )
5993        nt_after = nt.to(torch.float64)
5994        self.assertEqual(torch.float32, nt.dtype)
5995        self.assertEqual(torch.float64, nt_after.dtype)
5996        self.assertEqual(torch.float64, nt_after.values().dtype)
5997        self.assertEqual(torch.int64, nt_after.offsets().dtype)
5998
5999        noncontiguous_nt = nt.transpose(1, 2)
6000        noncontiguous_nt_after = noncontiguous_nt.to(torch.bfloat16)
6001        self.assertEqual(torch.bfloat16, noncontiguous_nt_after.dtype)
6002        self.assertEqual(torch.bfloat16, noncontiguous_nt_after.values().dtype)
6003        self.assertEqual(torch.int64, noncontiguous_nt_after.offsets().dtype)
6004
6005    def test_to_copy(self, device):
6006        nt = torch.nested.nested_tensor(
6007            [
6008                torch.randn(
6009                    i + 2, 3, 4, requires_grad=True, dtype=torch.float64, device=device
6010                )
6011                for i in range(3)
6012            ],
6013            layout=torch.jagged,
6014        )
6015
6016        nt_copy_dtype = torch.ops.aten._to_copy(nt, dtype=torch.float16)
6017        self.assertEqual(torch.float16, nt_copy_dtype.dtype)
6018
6019        nt_t = nt.transpose(1, 2)
6020        nt_t_copy_dtype = torch.ops.aten._to_copy(nt_t, dtype=torch.float16)
6021        self.assertEqual(torch.float16, nt_t_copy_dtype.dtype)
6022
6023    def test_copy_(self, device):
6024        offsets = torch.tensor([0, 2, 4], device=device)
6025        a = torch.nested.nested_tensor_from_jagged(
6026            torch.zeros(4, 3, device=device), offsets
6027        )
6028        b = torch.nested.nested_tensor_from_jagged(
6029            torch.ones(4, 3, device=device), offsets
6030        )
6031        a.copy_(b)
6032        torch._dynamo.disable(self.assertEqual)(a, b)
6033
6034        offsets_2 = torch.tensor([0, 2, 4], device=device)
6035        c = torch.nested.nested_tensor_from_jagged(
6036            torch.ones(4, 3, device=device), offsets_2
6037        )
6038        # fail when tensors have the same size but not the exact same offset tensor.
6039        with self.assertRaisesRegex(
6040            RuntimeError,
6041            "copy_ only supports Nested Tensors that have same size and the exact same offset tensor.",
6042        ):
6043            a.copy_(c)
6044
6045        # fail when tensors have different sizes
6046        a = a.transpose(1, 2)
6047        with self.assertRaisesRegex(
6048            RuntimeError,
6049            "copy_ only supports Nested Tensors that have same size and the exact same offset tensor.",
6050        ):
6051            a.copy_(b)
6052
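    # Minimal sketch (editorial addition, not called by the tests; restates what
    # test_copy_ above exercises): copy_ between jagged NTs is expected to
    # require the exact same offsets tensor object, not merely equal values.
    def _sketch_copy_with_shared_offsets(self, device="cpu"):
        offsets = torch.tensor([0, 2, 4], device=device)
        dst = torch.nested.nested_tensor_from_jagged(
            torch.zeros(4, 3, device=device), offsets
        )
        src = torch.nested.nested_tensor_from_jagged(
            torch.ones(4, 3, device=device), offsets
        )
        dst.copy_(src)  # shares `offsets`, so this is expected to succeed
        return dst
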
6053    @skipIfTorchDynamo("Dynamo doesn't know how to trace prof.events()")
6054    def test_profiler_sequence_nr(self):
6055        with torch.profiler.profile() as prof:
6056            values = torch.randn(4, 6, requires_grad=True)
6057            offsets = torch.tensor([0, 2, 4])
6058            values = values * 2
6059            l = torch.nn.Linear(6, 8)
6060            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
6061
6062            nt = l(nt)
6063            val = nt.values()
6064
6065            loss = val.sum()
6066            loss.backward()
6067
6068        fwd_seq_nrs = []
6069        for evt in prof.events():
6070            if (
6071                "linear" in evt.name.lower()
6072                and "backward" not in evt.name.lower()
6073                and evt.sequence_nr != -1
6074            ):
6075                fwd_seq_nrs.append(evt.sequence_nr)
6076
6077        bwd_seq_nrs = []
6078        for evt in prof.events():
6079            if (
6080                "linear" in evt.name.lower()
6081                and "backward" in evt.name.lower()
6082                and "evaluate_function" not in evt.name.lower()
6083                and evt.sequence_nr != -1
6084            ):
6085                bwd_seq_nrs.append(evt.sequence_nr)
6086
6087        # There should only be one such event with a sequence number:
6088        # the PythonTLSSnapshot event. Note that it's not terrible if we
6089        # end up with multiple events sharing the same sequence number, so
6090        # we could relax this check if it becomes inconvenient to maintain
6091        # this property.
6092        self.assertEqual(len(fwd_seq_nrs), 1)
6093        self.assertEqual(len(bwd_seq_nrs), 1)
6094        self.assertEqual(fwd_seq_nrs[0], bwd_seq_nrs[0])
6095
6096    def test_is_same_size(self, device):
6097        def get_3_tensors():
6098            return [
6099                torch.randn(
6100                    i + 2, 3, 4, requires_grad=True, dtype=torch.float64, device=device
6101                )
6102                for i in range(3)
6103            ]
6104
6105        nt1, offsets1 = jagged_from_list(get_3_tensors(), None)
6106        nt2, offsets1 = jagged_from_list(get_3_tensors(), offsets1)
6107
6108        nt3, offsets2 = jagged_from_list(get_3_tensors(), None)
6109        nt4, offsets2 = jagged_from_list(get_3_tensors(), offsets2)
6110
6111        def check_size(nt1, nt2, nt3, nt4):
6112            self.assertTrue(torch.ops.aten.is_same_size(nt1, nt2))
6113            self.assertTrue(torch.ops.aten.is_same_size(nt3, nt4))
6114            self.assertFalse(torch.ops.aten.is_same_size(nt1, nt3))
6115
6116        check_size(nt1, nt2, nt3, nt4)
6117
6118        nt1_t, nt2_t, nt3_t, nt4_t = (x.transpose(1, 2) for x in (nt1, nt2, nt3, nt4))
6119        check_size(nt1_t, nt2_t, nt3_t, nt4_t)
6120
6121    @skipIfTorchDynamo("compiles internally")
6122    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
6123    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
6124    def test_specialize_dynamic_shape(self, device):
6125        values = torch.randn((18, 16), device=device)
6126        offsets = torch.tensor([0, 2, 3, 6, 15, 18], device=device)
6127        like_values = torch.randn_like(values)
6128
6129        # this marks values as dynamic
6130        nt = torch.nested.nested_tensor_from_jagged(values, offsets)
6131
6132        def fn(values, same_size):
6133            # here, the dynamic shape is specialized by same_size's shape
6134            # https://github.com/pytorch/pytorch/issues/127097
6135            # make sure this doesn't error out in torch.compile
6136            return values + same_size
6137
6138        self.assertEqual(
6139            fn(values, like_values),
6140            torch.compile(fn)(values, like_values),
6141        )
6142
6143    @skipIfTorchDynamo("compiles internally")
6144    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
6145    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
6146    def test_specialize_dynamic_shape_recompile(self, device):
6147        def generate_inp(total_len):
6148            values = torch.randn((total_len, 16), device=device)
6149            offsets = torch.tensor([0, 2, 3, 6, 15, total_len], device=device)
6150            like_values = torch.randn_like(values)
6151            return values, offsets, like_values
6152
6153        def check_results(ref_fn, res_fn, args):
6154            values, offsets, like_values = args
6155            # this may add dynamic shape markings
6156            # the goal of this test is to make sure that, whatever markings are there,
6157            # we eventually stop recompiling as the shape changes.
6158            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
6159
6160            self.assertEqual(ref_fn(values, like_values), res_fn(values, like_values))
6161
6162        def fn(values, same_size):
6163            return values + same_size
6164
6165        compile_counter = torch._dynamo.testing.CompileCounter()
6166
6167        compiled_fn = torch._dynamo.optimize(compile_counter, nopython=True)(fn)
6168        check_results(fn, compiled_fn, generate_inp(18))
6169        self.assertEqual(compile_counter.frame_count, 1)
6170
6171        check_results(fn, compiled_fn, generate_inp(19))
6172        # we'll probably recompile here with dynamic shapes - it's okay if not though.
6173        frame_count_2 = compile_counter.frame_count
6174        self.assertIn(frame_count_2, [1, 2])
6175
6176        # make sure that by now we've already compiled with dynamic shapes, so additional
6177        # shapes should not trigger additional recompiles.
6178        check_results(fn, compiled_fn, generate_inp(20))
6179        self.assertEqual(compile_counter.frame_count, frame_count_2)
6180
6181    # Note 1: Math fallback doesn't work with bfloat16 on CUDA
6182    # Note 2: ROCm doesn't support flash attention or mem_efficient attention for NT
6183    @unittest.skipIf(
6184        TEST_WITH_ROCM,
6185        "ROCm doesn't support flash attention or mem_efficient attention for NT",
6186    )
6187    @dtypes(
6188        *(
6189            [torch.float16, torch.bfloat16, torch.float32]
6190            if SM80OrLater
6191            else [torch.float16, torch.float32]
6192        )
6193    )
6194    def test_sdpa(self, device, dtype):
6195        batch_size = 1
6196        emb_dims = 128
6197        n_heads = 8
6198        head_dims = emb_dims // n_heads
6199
6200        sen1 = torch.randn(11, emb_dims, dtype=dtype, device=device)
6201        sen2 = torch.randn(13, emb_dims, dtype=dtype, device=device)
6202
6203        query = torch.nn.Linear(
6204            emb_dims, emb_dims, bias=False, device=device, dtype=dtype
6205        )
6206        key = torch.nn.Linear(
6207            emb_dims, emb_dims, bias=False, device=device, dtype=dtype
6208        )
6209        value = torch.nn.Linear(
6210            emb_dims, emb_dims, bias=False, device=device, dtype=dtype
6211        )
6212
6213        # Simplest case: 1 sentence, no batching
6214        x_d1 = sen1.unsqueeze(0)
6215        x_nt = torch.nested.as_nested_tensor([sen1], layout=torch.jagged)
6216
6217        # See note below for why we detach here.
6218        q_d1 = (
6219            query(x_d1)
6220            .view(batch_size, -1, n_heads, head_dims)
6221            .detach()
6222            .requires_grad_(True)
6223        )
6224        q_d1_t = q_d1.transpose(1, 2)
6225        k_d1 = (
6226            key(x_d1)
6227            .view(batch_size, -1, n_heads, head_dims)
6228            .detach()
6229            .requires_grad_(True)
6230        )
6231        k_d1_t = k_d1.transpose(1, 2)
6232        v_d1 = (
6233            value(x_d1)
6234            .view(batch_size, -1, n_heads, head_dims)
6235            .detach()
6236            .requires_grad_(True)
6237        )
6238        v_d1_t = v_d1.transpose(1, 2)
6239
6240        q_nt = (
6241            query(x_nt)
6242            .view(*x_nt.size()[0:2], n_heads, head_dims)
6243            .detach()
6244            .requires_grad_(True)
6245        )
6246        q_nt_t = q_nt.transpose(1, 2)
6247        k_nt = (
6248            key(x_nt)
6249            .view(*x_nt.size()[0:2], n_heads, head_dims)
6250            .detach()
6251            .requires_grad_(True)
6252        )
6253        k_nt_t = k_nt.transpose(1, 2)
6254        v_nt = (
6255            value(x_nt)
6256            .view(*x_nt.size()[0:2], n_heads, head_dims)
6257            .detach()
6258            .requires_grad_(True)
6259        )
6260        v_nt_t = v_nt.transpose(1, 2)
6261
6262        # High Precision Math Reference
6263        q_d1_f32 = q_d1.to(torch.float32)
6264        k_d1_f32 = k_d1.to(torch.float32)
6265        v_d1_f32 = v_d1.to(torch.float32)
6266        q_d1_f32_t = q_d1_f32.transpose(1, 2)
6267        k_d1_f32_t = k_d1_f32.transpose(1, 2)
6268        v_d1_f32_t = v_d1_f32.transpose(1, 2)
6269        out_ref = torch.ops.aten._scaled_dot_product_attention_math(
6270            q_d1_f32_t, k_d1_f32_t, v_d1_f32_t
6271        )[0]
6272        grads_ref = torch.autograd.grad(out_ref.sum(), (q_d1_f32, k_d1_f32, v_d1_f32))
6273
6274        # Low Precision Math Reference
6275        out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
6276            q_d1_t, k_d1_t, v_d1_t
6277        )[0]
6278        grads_lp_ref = torch.autograd.grad(out_lp_ref.sum(), (q_d1, k_d1, v_d1))
6279
6280        # Compute tolerances
6281        output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref)
6282        grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(grads_ref[0], grads_lp_ref[0])
6283        grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(grads_ref[1], grads_lp_ref[1])
6284        grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(grads_ref[2], grads_lp_ref[2])
6285        grad_atols = [grad_q_ref_atol, grad_k_ref_atol, grad_v_ref_atol]
6286        grad_rtols = [grad_q_ref_rtol, grad_k_ref_rtol, grad_v_ref_rtol]
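        # (the pattern above presumably derives atol/rtol from the gap between
        # the fp32 math reference and the low-precision math reference, so the
        # fused-kernel outputs below are compared with dtype-appropriate bounds)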
6287
6288        attn_d1 = torch.nn.functional.scaled_dot_product_attention(
6289            q_d1_t, k_d1_t, v_d1_t
6290        ).transpose(1, 2)
6291        attn_nt = torch.nn.functional.scaled_dot_product_attention(
6292            q_nt_t, k_nt_t, v_nt_t
6293        ).transpose(1, 2)
6294
6295        self.assertEqual(
6296            attn_d1,
6297            attn_nt.unbind()[0].unsqueeze(0),
6298            atol=output_ref_atol,
6299            rtol=output_ref_rtol,
6300        )
6301
6302        # Simple case: 2 sentences, no extra params
6303        x_d2 = sen2.unsqueeze(0)
6304        x_nt = torch.nested.as_nested_tensor([sen1, sen2], layout=torch.jagged)
6305
6306        # NB: we make sure the leaf tensor we compute gradients for is the viewed tensor
6307        # before it is transposed. This is because, today, we cannot backward through view()
6308        # or unbind() on a transposed tensor.
6309        q_d2 = (
6310            query(x_d2)
6311            .view(batch_size, -1, n_heads, head_dims)
6312            .detach()
6313            .requires_grad_(True)
6314        )
6315        q_d2_t = q_d2.transpose(1, 2)
6316        k_d2 = (
6317            key(x_d2)
6318            .view(batch_size, -1, n_heads, head_dims)
6319            .detach()
6320            .requires_grad_(True)
6321        )
6322        k_d2_t = k_d2.transpose(1, 2)
6323        v_d2 = (
6324            value(x_d2)
6325            .view(batch_size, -1, n_heads, head_dims)
6326            .detach()
6327            .requires_grad_(True)
6328        )
6329        v_d2_t = v_d2.transpose(1, 2)
6330
6331        q_nt = (
6332            query(x_nt)
6333            .view(*x_nt.size()[0:2], n_heads, head_dims)
6334            .detach()
6335            .requires_grad_(True)
6336        )
6337        q_nt_t = q_nt.transpose(1, 2)
6338        k_nt = (
6339            key(x_nt)
6340            .view(*x_nt.size()[0:2], n_heads, head_dims)
6341            .detach()
6342            .requires_grad_(True)
6343        )
6344        k_nt_t = k_nt.transpose(1, 2)
6345        v_nt = (
6346            value(x_nt)
6347            .view(*x_nt.size()[0:2], n_heads, head_dims)
6348            .detach()
6349            .requires_grad_(True)
6350        )
6351        v_nt_t = v_nt.transpose(1, 2)
6352
6353        attn_d2 = torch.nn.functional.scaled_dot_product_attention(
6354            q_d2_t, k_d2_t, v_d2_t
6355        ).transpose(1, 2)
6356        d1_grads = torch.autograd.grad(attn_d1.sum(), (q_d1, k_d1, v_d1))
6357        d2_grads = torch.autograd.grad(attn_d2.sum(), (q_d2, k_d2, v_d2))
6358
6359        # Simple case 3: batch_size = 1, seq_len = 1
6360        q_3 = torch.randn(1, 8, 16, dtype=dtype, device=device)
6361        q_nt_3 = torch.nested.as_nested_tensor([q_3], layout=torch.jagged)
6362        q_nt_3 = q_nt_3.transpose(1, 2)
6363        attn_out = torch.nn.functional.scaled_dot_product_attention(
6364            q_nt_3, q_nt_3, q_nt_3
6365        )
6366        self.assertEqual(attn_out.shape, q_nt_3.shape)
6367
6368        def check_forward_backward():
6369            attn_nt = torch.nn.functional.scaled_dot_product_attention(
6370                q_nt_t, k_nt_t, v_nt_t
6371            ).transpose(1, 2)
6372
6373            attn_nts = attn_nt.unbind()
6374            self.assertEqual(
6375                attn_d1,
6376                attn_nts[0].unsqueeze(0),
6377                atol=output_ref_atol,
6378                rtol=output_ref_rtol,
6379            )
6380            self.assertEqual(
6381                attn_d2,
6382                attn_nts[1].unsqueeze(0),
6383                atol=output_ref_atol,
6384                rtol=output_ref_rtol,
6385            )
6386
6387            nt_grads = torch.autograd.grad(attn_nt.values().sum(), (q_nt, k_nt, v_nt))
6388            for nt_grad, d1_grad, d2_grad, grad_atol, grad_rtol in zip(
6389                nt_grads, d1_grads, d2_grads, grad_atols, grad_rtols
6390            ):
6391                unbound_nt_grads = nt_grad.unbind()
6392                self.assertEqual(
6393                    d1_grad,
6394                    unbound_nt_grads[0].unsqueeze(0),
6395                    atol=grad_atol,
6396                    rtol=grad_rtol,
6397                )
6398                self.assertEqual(
6399                    d2_grad,
6400                    unbound_nt_grads[1].unsqueeze(0),
6401                    atol=grad_atol,
6402                    rtol=grad_rtol,
6403                )
6404
6405        # Default
6406        check_forward_backward()
6407
6408        # Test that the dispatcher works by enabling only mem-efficient and math (safe for all devices)
6409        with torch.backends.cuda.sdp_kernel(
6410            enable_flash=False, enable_mem_efficient=True, enable_math=True
6411        ):
6412            check_forward_backward()
6413
6414        # Test math fallback
6415        with torch.backends.cuda.sdp_kernel(
6416            enable_flash=False, enable_mem_efficient=False, enable_math=True
6417        ):
6418            # Math fallback doesn't work with bfloat16 on CUDA because
6419            # "group_gemm_dispatch" not implemented for 'BFloat16'
6420            if not (str(device).startswith("cuda") and dtype == torch.bfloat16):
6421                check_forward_backward()
6422
6423    @skipIfTorchDynamo("SDPA test compiles internally")
6424    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
6425    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
6426    # Guarding with sqrt() doesn't work on ROCm?
6427    @skipCUDAIfRocm
6428    @onlyCUDA
6429    @dtypes(
6430        *(
6431            [torch.float16, torch.bfloat16, torch.float32]
6432            if SM80OrLater
6433            else [torch.float16, torch.float32]
6434        )
6435    )
6436    def test_sdpa_compile(self, device, dtype):
6437        batch_size = 1
6438        emb_dims = 1024
6439        n_heads = 8
6440        head_dims = emb_dims // n_heads
6441
6442        sen1 = torch.randn(11, emb_dims, dtype=dtype, device=device)
6443        sen2 = torch.randn(13, emb_dims, dtype=dtype, device=device)
6444
6445        query = torch.nn.Linear(
6446            emb_dims, emb_dims, bias=False, device=device, dtype=dtype
6447        )
6448        key = torch.nn.Linear(
6449            emb_dims, emb_dims, bias=False, device=device, dtype=dtype
6450        )
6451        value = torch.nn.Linear(
6452            emb_dims, emb_dims, bias=False, device=device, dtype=dtype
6453        )
6454
6455        # Dense inputs: one sentence per size-1 batch, plus a jagged NT over both sentences
6456        x_d1 = sen1.unsqueeze(0)
6457        x_d2 = sen2.unsqueeze(0)
6458        x_nt = torch.nested.as_nested_tensor([sen1, sen2], layout=torch.jagged)
6459
6460        q_d1 = query(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
6461        k_d1 = key(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
6462        v_d1 = value(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
6463        q_d2 = query(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
6464        k_d2 = key(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
6465        v_d2 = value(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
6466
6467        q_nt = (
6468            query(x_nt)
6469            .view(*x_nt.size()[0:2], n_heads, head_dims)
6470            .detach()
6471            .transpose(1, 2)
6472        )
6473        k_nt = (
6474            key(x_nt)
6475            .view(*x_nt.size()[0:2], n_heads, head_dims)
6476            .detach()
6477            .transpose(1, 2)
6478        )
6479        v_nt = (
6480            value(x_nt)
6481            .view(*x_nt.size()[0:2], n_heads, head_dims)
6482            .detach()
6483            .transpose(1, 2)
6484        )
6485
6486        # High Precision Math Reference
6487        q_d1_f32 = q_d1.to(torch.float32)
6488        k_d1_f32 = k_d1.to(torch.float32)
6489        v_d1_f32 = v_d1.to(torch.float32)
6490        out_ref = torch.ops.aten._scaled_dot_product_attention_math(
6491            q_d1_f32, k_d1_f32, v_d1_f32
6492        )[0]
6493        # Low Precision Math Reference
6494        out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
6495            q_d1, k_d1, v_d1
6496        )[0]
6497        output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref)
6498
6499        attn_d1 = torch.nn.functional.scaled_dot_product_attention(
6500            q_d1, k_d1, v_d1
6501        ).transpose(1, 2)
6502        attn_d2 = torch.nn.functional.scaled_dot_product_attention(
6503            q_d2, k_d2, v_d2
6504        ).transpose(1, 2)
6505
6506        compiled_sdpa = torch.compile(torch.nn.functional.scaled_dot_product_attention)
6507        attn_nt = compiled_sdpa(q_nt, k_nt, v_nt).transpose(1, 2)
6508
6509        attn_nts = attn_nt.unbind()
6510        self.assertEqual(
6511            attn_d1,
6512            attn_nts[0].unsqueeze(0),
6513            atol=output_ref_atol,
6514            rtol=output_ref_rtol,
6515        )
6516        self.assertEqual(
6517            attn_d2,
6518            attn_nts[1].unsqueeze(0),
6519            atol=output_ref_atol,
6520            rtol=output_ref_rtol,
6521        )
6522
6523    @dtypes(torch.float32, torch.double, torch.half)
6524    def test_sdpa_with_constant_sequence_length(self, device, dtype):
6525        # shape (B, P*, S, D)
6526        # B: batch size
6527        # P*: ragged number of prompts
6528        # S: (constant) sequence length
6529        # D: embedding size
6530        query = random_nt_from_dims(
6531            [4, None, 8, 10],
6532            device=device,
6533            dtype=dtype,
6534            layout=torch.jagged,
6535            requires_grad=True,
6536        )
6537        key = random_nt_from_similar(query)
6538        value = random_nt_from_similar(query)
6539        output = F.scaled_dot_product_attention(query, key, value)
6540        self.assertTrue(isinstance(output, NestedTensor))
6541        output.values().sum().backward()
6542
6543        query_dense = query.clone().detach().requires_grad_(True)
6544        # should be equivalent to just running the buffers through
6545        output_dense = F.scaled_dot_product_attention(
6546            query_dense.values(), key.values(), value.values()
6547        )
6548        torch._dynamo.disable(self.assertEqual)(output._values, output_dense)
6549        output_dense.sum().backward()
6550        torch._dynamo.disable(self.assertEqual)(query.grad, query_dense.grad)
6551
6552    @onlyCUDA
6553    @unittest.skipIf(
6554        not PLATFORM_SUPPORTS_FUSED_ATTENTION,
6555        "Platform doesn't support flash or mem-efficient attention",
6556    )
6557    @dtypes(
6558        *(
6559            [torch.float16, torch.bfloat16, torch.float32]
6560            if SM80OrLater
6561            else [torch.float16, torch.float32]
6562        )
6563    )
6564    def test_sdpa_with_packed_in_proj(self, device, dtype):
6565        # shape (B, *, D)
6566        input_packed = random_nt_from_dims(
6567            [5, None, 10], device=device, dtype=dtype, layout=torch.jagged
6568        )
6569
6570        # Do input projection.
6571        num_heads = 2
6572        # should be a multiple of 4 for efficient kernels (e.g. flash / mem-efficient)
6573        head_dim = 8
6574        qkv_linear = torch.nn.Linear(10, num_heads * head_dim * 3).to(
6575            device=device, dtype=dtype
6576        )
6577
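        # Packed QKV: a single linear produces q, k, v, which are chunked along the
        # last dim and reshaped to (B, n_heads, *, head_dim) before SDPA.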
6578        def in_proj(input_packed, qkv_linear=qkv_linear):
6579            qkv_post_proj = qkv_linear(input_packed)
6580            # these are non-contiguous to trigger _is_safe_to_get_storage_as_tensor()
6581            q, k, v = qkv_post_proj.chunk(3, dim=-1)
6582            q = q.unflatten(-1, [num_heads, head_dim]).transpose(-2, -3)
6583            k = k.unflatten(-1, [num_heads, head_dim]).transpose(-2, -3)
6584            v = v.unflatten(-1, [num_heads, head_dim]).transpose(-2, -3)
6585            return q, k, v
6586
6587        q, k, v = in_proj(input_packed)
6588        output = F.scaled_dot_product_attention(q, k, v, attn_mask=None)
6589
6590        # compare to individually running unbound components through
6591        for in_component, out_component in zip(
6592            input_packed.unbind(), output.transpose(-2, -3).unbind()
6593        ):
6594            q, k, v = in_proj(in_component)
6595            out = F.scaled_dot_product_attention(q, k, v).transpose(-2, -3)
6596
6597            # Low Precision Math Reference
6598            out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(q, k, v)[
6599                0
6600            ].transpose(-2, -3)
6601            output_ref_atol, output_ref_rtol = get_tolerances(
6602                out, out_lp_ref, fudge_factor=2
6603            )
6604
6605            self.assertEqual(
6606                out, out_component, atol=output_ref_atol, rtol=output_ref_rtol
6607            )
6608
6609    @skipIfTorchDynamo("SDPA test compiles internally")
6610    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
6611    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
6612    # mha_varlen_fwd not supported on ROCm
6613    @skipCUDAIfRocm
6614    @onlyCUDA
6615    @dtypes(
6616        *(
6617            [torch.float16, torch.bfloat16, torch.float32]
6618            if SM80OrLater
6619            else [torch.float16, torch.float32]
6620        )
6621    )
6622    def test_sdpa_backwards(self, device, dtype):
6623        values = torch.randn(9, 3, 256, requires_grad=True, device=device, dtype=dtype)
6624        offsets = torch.tensor([0, 1, 3, 5, 9], device=device, dtype=torch.int64)
6625
6626        @torch.compile
6627        def f(values, offsets):
6628            nt = convert_jagged_to_nested_tensor(values, offsets, max_length=4)
6629            nt = nt.transpose(-2, -3)
6630            # purposefully introduce a graph break to trigger view replay for the subclass view input
6631            torch.tensor(1).item()
6632            output = F.scaled_dot_product_attention(nt, nt, nt).transpose(-2, -3)
6633            return convert_nt_to_jagged(output)
6634
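        # Run the compiled jagged SDPA and check that gradients flow back to the
        # dense values buffer.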
6635        output = f(values, offsets)
6636        output.sum().backward()
6637        self.assertEqual(values.grad, torch.ones_like(values))
6638
6639    @unittest.skipIf(
6640        not PLATFORM_SUPPORTS_FUSED_ATTENTION,
6641        "Platform doesn't support flash or mem-efficient attention",
6642    )
6643    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
6644    @skipCUDAIfRocm
6645    @onlyCUDA
6646    @skipIfTorchDynamo()
6647    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
6648    def test_sdpa_autocast(self, device):
6649        def fn_nt(values32, values16, offsets):
6650            nt32 = convert_jagged_to_nested_tensor(values32, offsets, max_length=16)
6651            nt16 = convert_jagged_to_nested_tensor(values16, offsets, max_length=16)
6652            nt32 = nt32.transpose(1, 2)
6653            nt16 = nt16.transpose(1, 2)
6654            return F.scaled_dot_product_attention(nt32, nt16, nt32)
6655
6656        def fn_dense(x32, x16):
6657            x32 = x32.view(8, 16, 4, 16).transpose(1, 2)
6658            x16 = x16.view(8, 16, 4, 16).transpose(1, 2)
6659            return F.scaled_dot_product_attention(x32, x16, x32)
6660
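        # Dense and jagged paths share the same underlying data: 8 sequences of
        # length 16 each (offsets step by 16).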
6661        values32 = torch.randn((8 * 16, 4, 16), device=device, dtype=torch.float32)
6662        values16 = torch.randn((8 * 16, 4, 16), device=device, dtype=torch.float16)
6663        offsets = torch.arange(0, 8 * 16 + 1, 16, device=device, dtype=torch.int32)
6664
6665        x32 = values32.clone()
6666        x16 = values16.clone()
6667
6668        with torch.autocast(device_type="cuda", dtype=torch.float16):
6669            out_dense_eager = fn_dense(x32, x16)
6670            out_dense_compiled = torch.compile(fn_dense)(x32, x16)
6671            out_nt_eager = fn_nt(values32, values16, offsets)
6672            out_nt_compiled = torch.compile(fn_nt)(values32, values16, offsets)
6673
6674        self.assertEqual(out_dense_eager, out_dense_compiled)
6675        self.assertEqual(
6676            out_dense_eager.transpose(1, 2),
6677            out_nt_eager.values().transpose(0, 1).view(8, 16, 4, 16),
6678        )
6679        self.assertEqual(
6680            out_dense_eager.transpose(1, 2),
6681            out_nt_compiled.values().transpose(0, 1).view(8, 16, 4, 16),
6682        )
6683
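        # Backward check: each dense/NT x eager/compiled path gets its own leaf
        # tensors so their gradients can be compared directly.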
6684        def get_values():
6685            return tuple(
6686                x.clone().detach().requires_grad_(True) for x in (values32, values16)
6687            )
6688
6689        v32_dense_eager, v16_dense_eager = get_values()
6690        v32_dense_compile, v16_dense_compile = get_values()
6691        v32_nt_eager, v16_nt_eager = get_values()
6692        v32_nt_compile, v16_nt_compile = get_values()
6693
6694        with torch.autocast(device_type="cuda", dtype=torch.float16):
6695            loss_dense_eager = fn_dense(v32_dense_eager, v16_dense_eager).sum()
6696            loss_dense_compile = torch.compile(fn_dense)(
6697                v32_dense_compile, v16_dense_compile
6698            ).sum()
6699            loss_nt_eager = fn_nt(v32_nt_eager, v16_nt_eager, offsets).values().sum()
6700            loss_nt_compile = (
6701                torch.compile(fn_nt)(v32_nt_compile, v16_nt_compile, offsets)
6702                .values()
6703                .sum()
6704            )
6705
6706        loss_dense_eager.backward()
6707        loss_dense_compile.backward()
6708        loss_nt_eager.backward()
6709        loss_nt_compile.backward()
6710
6711        self.assertEqual(v32_dense_eager.grad, v32_dense_compile.grad)
6712        self.assertEqual(v32_dense_eager.grad, v32_nt_eager.grad)
6713        self.assertEqual(v32_dense_eager.grad, v32_nt_compile.grad)
6714
6715        self.assertEqual(v16_dense_eager.grad, v16_dense_compile.grad)
6716        self.assertEqual(v16_dense_eager.grad, v16_nt_eager.grad)
6717        self.assertEqual(v16_dense_eager.grad, v16_nt_compile.grad)
6718
6719    @unittest.skipIf(
6720        not PLATFORM_SUPPORTS_FUSED_ATTENTION,
6721        "Platform doesn't support flash or mem-efficient attention",
6722    )
6723    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
6724    @skipCUDAIfRocm
6725    @onlyCUDA
6726    @skipIfTorchDynamo()
6727    def test_sdpa_flop_counter(self, device):
6728        from torch.utils.flop_counter import FlopCounterMode
6729
6730        def get_flops(nt):
6731            flop_counter = FlopCounterMode(display=False)
6732            with flop_counter:
6733                ret = torch.nn.functional.scaled_dot_product_attention(nt, nt, nt)
6734                ret.values().sum().backward()
6735            return flop_counter.get_total_flops()
6736
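        # FLOP counts should match between real CUDA inputs and meta inputs of the
        # same shape.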
6737        values = torch.randn(
6738            (8 * 16, 4, 16), requires_grad=True, device=device, dtype=torch.float16
6739        )
6740        offsets = torch.arange(0, 8 * 16 + 1, 16, device=device, dtype=torch.int32)
6741        nt = convert_jagged_to_nested_tensor(values, offsets, max_length=16)
6742
6743        values_meta = torch.randn(
6744            (8 * 16, 4, 16), requires_grad=True, device="meta", dtype=torch.float16
6745        )
6746        offsets_meta = torch.arange(0, 8 * 16 + 1, 16, device="meta", dtype=torch.int32)
6747        nt_meta = convert_jagged_to_nested_tensor(values_meta, offsets_meta, max_length=16)
6748
6749        self.assertEqual(get_flops(nt), get_flops(nt_meta))
6750
6751    @skipIfTorchDynamo()
6752    def test_nested_tensor_activation_checkpoint(self, device):
6753        values = torch.randn(
6754            9, 3, 256, requires_grad=True, device=device, dtype=torch.float32
6755        )
6756        lengths = torch.tensor([1, 2, 3, 3], device=device, dtype=torch.int64)
6757        offsets = F.pad(lengths, pad=(1, 0)).cumsum(dim=0)
6758
6759        def fn(values, offsets):
6760            nt = convert_jagged_to_nested_tensor(values, offsets, max_length=4)
6761            return convert_nt_to_jagged(nt).sum()
6762
6763        checkpoint(fn, values, offsets, use_reentrant=False).backward()
6764        self.assertIsNotNone(values.grad)
6765
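        # Selective checkpointing: outputs of the listed ops (here aten.cumsum) are
        # saved during forward rather than recomputed in backward.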
6766        context_fn = partial(
6767            create_selective_checkpoint_contexts, [torch.ops.aten.cumsum.default]
6768        )
6769
6770        values.grad = None
6771
6772        def fn(values, lengths):
6773            offsets = F.pad(lengths, pad=(1, 0)).cumsum(dim=0)
6774            nt = convert_jagged_to_nested_tensor(values, offsets, max_length=4)
6775            return convert_nt_to_jagged(nt).sum()
6776
6777        checkpoint(
6778            fn, values, lengths, use_reentrant=False, context_fn=context_fn
6779        ).backward()
6780        self.assertIsNotNone(values.grad)
6781
6782    # Internally-defined NT use cases are lifted here for maximum test realism.
6783    # TODO: Remove these when ViewNestedFromBuffer, etc. are deprecated.
6784    @skipCUDAIfRocm  # not needed
6785    @skipIfTorchDynamo("compiles internally")
6786    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
6787    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
6788    @parametrize("use_legacy_api", [True, False])
6789    @skipCPUIf(True, "SDPA Math NT fallback causes failure: see issue #133644")
6790    def test_dummy_mha_with_nt(self, device, use_legacy_api):
6791        bs = 3
6792        d1 = 2
6793        d2 = 4
6794        d3 = 16
6795        n_heads = 2
6796        d_head = d3 // n_heads
6797        max_length_1 = 10
6798        max_length_2 = 20
6799        torch.manual_seed(0)
6800
6801        class mha(torch.nn.Module):
6802            def __init__(self, use_legacy_api) -> None:
6803                super().__init__()
6804                torch.manual_seed(0)
6805                self.linear = torch.nn.Linear(d2, d3, device=device)
6806                self.use_legacy_api = use_legacy_api
6807
6808            def forward(self, query, value, offsets):
6809                value = self.linear(value)
6810                if self.use_legacy_api:
6811                    key = convert_jagged_to_nested_tensor_legacy(
6812                        value, offsets, max_length_1
6813                    )
6814                    value = convert_jagged_to_nested_tensor_legacy(
6815                        value, offsets, max_length_2
6816                    )
6817                    query = convert_dense_to_nested_tensor_legacy(query)
6818                else:
6819                    key = convert_jagged_to_nested_tensor(value, offsets, max_length_1)
6820                    value = convert_jagged_to_nested_tensor(
6821                        value, offsets, max_length_2
6822                    )
6823                    query = convert_dense_to_nested_tensor(query)
6824                q = query.view(bs, -1, n_heads, d_head).transpose(1, 2)
6825                k = key.view(bs, -1, n_heads, d_head).transpose(1, 2)
6826                v = value.view(bs, -1, n_heads, d_head).transpose(1, 2)
6827
6828                with torch.nn.attention.sdpa_kernel(
6829                    [
6830                        torch.nn.attention.SDPBackend.FLASH_ATTENTION,
6831                        torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
6832                    ]
6833                ):
6834                    attn_output = torch.nn.functional.scaled_dot_product_attention(
6835                        q,
6836                        k,
6837                        v,
6838                        attn_mask=None,
6839                        dropout_p=0.0,
6840                        is_causal=False,
6841                    )
6842                attn_output = attn_output.transpose(1, 2)
6843                if self.use_legacy_api:
6844                    attn_output = convert_nt_to_jagged_legacy(attn_output)
6845                else:
6846                    attn_output = convert_nt_to_jagged(attn_output)
6847                return attn_output, key._max_seqlen, value._max_seqlen
6848
6849        query = torch.rand(bs, d1, d3, device=device)
6850        value = torch.rand(30, d2, requires_grad=True, device=device)
6851        # total_length must be greater than max_length, otherwise flash_attn backward will fail
6852        offsets = torch.tensor([0, 2, 3, 30], device=device)
6853
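        # Construct the module, fx-trace it, then compile the traced graph.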
6854        m = mha(use_legacy_api)
6855        symbolic_traced: torch.fx.GraphModule = torch.fx.symbolic_trace(m)
6856        m = torch.compile(symbolic_traced)
6857        attn_output, cached_key_max_seqlen, cached_value_max_seqlen = m(
6858            query, value, offsets
6859        )
6860        loss = attn_output.sum()
6861        # Check that the NT can be fx-traced and compiled with torch.compile, and that backward works
6862        loss.backward()
6863
6864        # Check that value.requires_grad is not lost after tracing and compiling
6865        value_grad = value.grad  # save for comparison later
6866        self.assertIsNotNone(value_grad)
6867        # check that max_seqlen is cached properly
6868        self.assertEqual(cached_key_max_seqlen, max_length_1)
6869        self.assertEqual(cached_value_max_seqlen, max_length_2)
6870
6871        # check that the output is numerically equivalent to eager mode
6872        m_eager = mha(use_legacy_api)
6873
6874        value.grad = None
6875        attn_output_eager, _, _ = m_eager(query, value, offsets)
6876        attn_output_eager.sum().backward()
6877        self.assertTrue(torch.allclose(attn_output_eager, attn_output))
6878        self.assertTrue(torch.allclose(value_grad, value.grad))
6879
6880    @dtypes(torch.float32)
6881    def test_apply_(self, device, dtype):
6882        nt = random_nt_from_dims(
6883            [5, None, 10],
6884            device=device,
6885            dtype=dtype,
6886            layout=torch.jagged,
6887            requires_grad=True,
6888        )
6889
6890        def f(x):
6891            return x * 2
6892
6893        if device != "cpu":
6894            with self.assertRaisesRegex(
6895                TypeError, "apply_ is only implemented on CPU tensors"
6896            ):
6897                nt.apply_(f)
6898            return
6899
6900        before = nt._values.clone().detach()
6901
6902        nt.apply_(f)
6903        expected = f(before)
6904        self.assertEqual(expected, nt._values)
6905        # apply_ should modify the values component in place without adding to the autograd graph
6906        self.assertIsNone(nt.grad)
6907        self.assertIsNone(nt._values.grad_fn)
6908
6909    @dtypes(torch.float64, torch.float32, torch.half)
6910    def test_jagged_padded_dense_conversion_kernels(self, device, dtype):
6911        values = torch.randn(10, 5, device=device, dtype=dtype)
6912        offsets = torch.tensor([0, 1, 3, 8, 10], device=device, dtype=torch.int64)
6913        max_length = offsets.diff().max().item()
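        # with these offsets, per-batch lengths are [1, 2, 5, 2], so max_length == 5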
6914        padding_value = 1.3
6915
6916        # convert jagged -> padded dense
6917        padded = torch.ops.aten._jagged_to_padded_dense_forward(
6918            values, [offsets], [max_length], padding_value
6919        )
6920
6921        batch_size = offsets.shape[0] - 1
6922        expected_padded_shape = (batch_size, max_length, values.shape[-1])
6923        self.assertEqual(padded.shape, expected_padded_shape)
6924
6925        # convert padded dense -> jagged
6926        total_L = values.shape[0]
6927        output_jagged = torch.ops.aten._padded_dense_to_jagged_forward(
6928            padded, [offsets], total_L
6929        )
6930
6931        # should be equivalent to the original values
6932        self.assertEqual(values, output_jagged)
6933
6934        # success case: truncate to max length as needed
6935        trunc_max_length = max_length - 1
6936        trunc_padded = torch.ops.aten._jagged_to_padded_dense_forward(
6937            values, [offsets], [trunc_max_length], padding_value
6938        )
6939        self.assertEqual(padded[:, :trunc_max_length, :], trunc_padded)
6940
6941        # specific to CPU impls
6942        if device == "cpu":
6943            # error case: multiple offsets on CPU; the CPU kernels only support a single jagged dim for now
6944            with self.assertRaisesRegex(
6945                RuntimeError, "only a single jagged dim is supported"
6946            ):
6947                torch.ops.aten._jagged_to_padded_dense_forward(
6948                    values, [offsets, offsets], [max_length, max_length], padding_value
6949                )
6950
6951            with self.assertRaisesRegex(
6952                RuntimeError, "only a single jagged dim is supported"
6953            ):
6954                torch.ops.aten._padded_dense_to_jagged_forward(
6955                    padded, [offsets, offsets], total_L
6956                )
6957
6958            # error case: > 1D offsets
6959            offsets2d = offsets.unsqueeze(-1)
6960            with self.assertRaisesRegex(RuntimeError, "expected 1D offsets"):
6961                torch.ops.aten._jagged_to_padded_dense_forward(
6962                    values, [offsets2d], [max_length], padding_value
6963                )
6964
6965            with self.assertRaisesRegex(RuntimeError, "expected 1D offsets"):
6966                torch.ops.aten._padded_dense_to_jagged_forward(
6967                    padded, [offsets2d], total_L
6968                )
6969
6970            # error case: final offset != total_L
6971            offsets_wrong = offsets.clone().detach()
6972            offsets_wrong[-1] = total_L + 1
6973            with self.assertRaisesRegex(
6974                RuntimeError, "final offset should match total_L value"
6975            ):
6976                torch.ops.aten._padded_dense_to_jagged_forward(
6977                    padded, [offsets_wrong], total_L
6978                )
6979
6980            # error case: 1D padded input
6981            padded_wrong = padded.flatten().clone().detach()
6982            with self.assertRaisesRegex(RuntimeError, "expected padded dim >= 2"):
6983                torch.ops.aten._padded_dense_to_jagged_forward(
6984                    padded_wrong, [offsets], total_L
6985                )
6986
6987            # error case: batch item has length > max length
6988            # max_length is 5 above; 7 here
6989            offsets_wrong = torch.tensor(
6990                [0, 1, 8, 9, 10], device=device, dtype=torch.int64
6991            )
6992            with self.assertRaisesRegex(RuntimeError, "found batch item of length"):
6993                torch.ops.aten._padded_dense_to_jagged_forward(
6994                    padded, [offsets_wrong], total_L
6995                )
6996
6997    @dtypes(torch.float32)
6998    @skipIfTorchDynamo("Test compiles internally")
6999    @unittest.skipIf(
7000        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
7001    )
7002    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
7003    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
7004    @skipCUDAIfRocm
7005    def test_compile_preserves_metadata_cache(self, device, dtype):
7006        # shape (B, *, D)
7007        nt = random_nt_from_dims(
7008            [4, None, 3, 16],
7009            device=device,
7010            dtype=dtype,
7011            layout=torch.jagged,
7012            requires_grad=True,
7013        )
7014
7015        # expect min / max seqlen to be stored here
7016        cache = dict(nt._metadata_cache)
7017
7018        @torch.compile
7019        def f(nt):
7020            q = nt.transpose(-3, -2)
7021            output = F.scaled_dot_product_attention(q, q, q).transpose(-3, -2)
7022            return output
7023
7024        output = f(nt)
7025        output.backward(torch.ones_like(output))
7026        self.assertEqual(output._metadata_cache, cache)
7027
7028    @dtypes(torch.float32)
7029    @skipIfTorchDynamo("Test compiles internally")
7030    @unittest.skipIf(
7031        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
7032    )
7033    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
7034    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
7035    @skipCUDAIfRocm
7036    def test_compile_with_dynamic_max_seq_len(self, device, dtype):
7037        # shape (B, *, D)
7038        # max seq len: 18
7039        nt = torch.nested.nested_tensor(
7040            [
7041                torch.randn(2, 5),
7042                torch.randn(3, 5),
7043                torch.randn(18, 5),
7044            ],
7045            layout=torch.jagged,
7046        )
7047
7048        # max seq len: 19
7049        nt2 = torch.nested.nested_tensor(
7050            [
7051                torch.randn(2, 5),
7052                torch.randn(3, 5),
7053                torch.randn(19, 5),
7054            ],
7055            layout=torch.jagged,
7056        )
7057
7058        def f(nt):
7059            # TODO: Replace with public API when we can use @properties
7060            return torch.ones_like(nt) * nt._get_max_seqlen()
7061
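        # nt and nt2 differ in max seqlen (18 vs. 19); expect no recompile, i.e. the
        # max seqlen should be treated dynamically rather than specialized on.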
7062        for dynamic in [False, True, None]:
7063            self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))
7064
7065    @dtypes(torch.float32)
7066    @skipIfTorchDynamo("Test compiles internally")
7067    @unittest.skipIf(
7068        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
7069    )
7070    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
7071    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
7072    @skipCUDAIfRocm
7073    def test_compile_with_dynamic_min_seq_len(self, device, dtype):
7074        # shape (B, *, D)
7075        # min seq len: 7
7076        nt = torch.nested.nested_tensor(
7077            [
7078                torch.randn(7, 5),
7079                torch.randn(8, 5),
7080                torch.randn(9, 5),
7081            ],
7082            layout=torch.jagged,
7083        )
7084
7085        # min seq len: 8
7086        nt2 = torch.nested.nested_tensor(
7087            [
7088                torch.randn(8, 5),
7089                torch.randn(9, 5),
7090                torch.randn(10, 5),
7091            ],
7092            layout=torch.jagged,
7093        )
7094
7095        def f(nt):
7096            # TODO: Replace with public API when we can use @properties
7097            return torch.ones_like(nt) * nt._get_min_seqlen()
7098
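        # nt and nt2 differ in min seqlen (7 vs. 8); expect no recompile.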
7099        for dynamic in [False, True, None]:
7100            self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))
7101
7102    @dtypes(torch.float32)
7103    @skipIfTorchDynamo("Test compiles internally")
7104    @unittest.skipIf(
7105        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
7106    )
7107    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
7108    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
7109    @skipCUDAIfRocm
7110    def test_compile_with_propagated_dynamic_max_seq_len(self, device, dtype):
7111        # shape (B, *, D)
7112        # max seq len: 18
7113        nt = torch.nested.nested_tensor(
7114            [
7115                torch.randn(2, 5),
7116                torch.randn(3, 5),
7117                torch.randn(18, 5),
7118            ],
7119            layout=torch.jagged,
7120        )
7121
7122        # max seq len: 19
7123        nt2 = torch.nested.nested_tensor(
7124            [
7125                torch.randn(2, 5),
7126                torch.randn(3, 5),
7127                torch.randn(19, 5),
7128            ],
7129            layout=torch.jagged,
7130        )
7131
7132        def f(nt):
7133            nt2 = nt.sin() + 1
7134            # TODO: Replace with public API when we can use @properties
7135            return torch.ones_like(nt2) * nt2._get_max_seqlen()
7136
7137        ref = f(nt)
7138        output = torch.compile(f, fullgraph=True, dynamic=False)(nt)
7139        self.assertEqual(ref, output)
7140
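        # the propagated max seqlen on the intermediate NT should also avoid
        # recompiles across inputs with differing max seqlens.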
7141        for dynamic in [False, True, None]:
7142            self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))
7143
7144    @dtypes(torch.float32, torch.double, torch.half)
7145    def test_unbind_backward(self, device, dtype):
7146        nt = torch.nested.nested_tensor(
7147            [
7148                torch.randn(2, 4, device=device),
7149                torch.randn(5, 4, device=device),
7150                torch.randn(3, 4, device=device),
7151            ],
7152            layout=torch.jagged,
7153            requires_grad=True,
7154        )
7155
7156        a, b, c = nt.unbind()
7157        b.sum().backward()
7158
7159        @torch._dynamo.disable
7160        def check(nt):
7161            expected_grad = torch.zeros_like(nt)
7162            expected_grad.unbind()[1].add_(1.0)
7163            self.assertEqual(nt.grad, expected_grad)
7164
7165        check(nt)
7166
7167
7168FORWARD_FAILURES = {
7169    # === BEGIN NotImplementedError SECTION ===
7170    # unary
7171    "nn.functional.celu",
7172    "nn.functional.elu",
7173    "nn.functional.hardshrink",
7174    "nn.functional.hardsigmoid",
7175    "nn.functional.hardtanh",
7176    "nn.functional.logsigmoid",
7177    "nn.functional.mish",
7178    "nn.functional.relu6",
7179    "nn.functional.rrelu",
7180    "nn.functional.selu",
7181    "nn.functional.softplus",
7182    "nn.functional.softshrink",
7183    "nn.functional.threshold",
7184    "rad2deg",
7185    # binary
7186    "__rsub__",
7187    "complex",
7188    "floor_divide",
7189    "polar",
7190    "rsub",
7191    # reduction
7192    "all",
7193    "amax",
7194    "amin",
7195    "any",
7196    "argmax",
7197    "argmin",
7198    "count_nonzero",
7199    "linalg.vector_norm",
7200    "nansum",
7201    "std",
7202    "std.unbiased",
7203    "var",
7204    "var.unbiased",
7205    # === BEGIN UNSUPPORTED SECTION ===
7206    # RuntimeError: mean(): not supported for NestedTensor on dim=1
7207    "mean",
7208    # ValueError: expects strided tensor (got torch.jagged tensor)
7209    "masked.amax",
7210    "masked.amin",
7211    "masked.argmax",
7212    "masked.argmin",
7213    "masked.logsumexp",
7214    "masked.mean",
7215    "masked.norm",
7216    "masked.prod",
7217    "masked.std",
7218    "masked.sum",
7219    "masked.var",
7220    # === BEGIN BUG SECTION ===
7221    # Returns a tuple of Tensors so it doesn't work with NJT's unary pointwise logic
7222    "frexp",
7223    # Need to adjust the sample input func to pass the correct arguments
7224    "nn.functional.prelu",
7225    # TypeError: fill() received an invalid combination of arguments
7226    # got (NestedTensor), but expected one of:
7227    # * (Tensor input, Tensor value)
7228    # * (Tensor input, Number value)
7229    "fill",
7230    # RuntimeError: unsupported tensor layout: Jagged
7231    "jiterator_binary",
7232    "jiterator_binary_return_by_ref",
7233    "jiterator_unary",
7234    # Bug found: sum() with keepdim=True returns invalid shape
7235    "sum",
7236    # RuntimeError: prod(): keepdim=True must be set for NestedTensor
7237    "prod",
7238    # RuntimeError: "jagged_to_padded_dense" not implemented for 'Bool'
7239    "nanmean",
7240}
7241
7242BACKWARD_FAILURES = {
7243    *FORWARD_FAILURES,
7244    # TODO: categorize these
7245    "__rpow__",
7246    "atanh",
7247    "cdouble",
7248    "cfloat",
7249    "chalf",
7250    "clamp_max",
7251    "clamp_min",
7252    "copysign",
7253    "float_power",
7254    "max.binary",
7255    "maximum",
7256    "min.binary",
7257    "minimum",
7258    "pow",
7259    "sgn",
7260    "sinc",
7261    "special.i1",
7262    "special.i1e",
7263    # clone() on a "non-contiguous with holes" NJT allocates a new offsets -> new nested int
7264    # RuntimeError: Function CloneBackward0 returned an invalid gradient at index 0 -
7265    # got [3, j29, 5] but expected shape compatible with [3, j28, 5]
7266    "clone",
7267    # Calling into torch.ops.aten.size directly
7268    "masked_select",
7269}
7270
7271COMPILE_FORWARD_FAILURES = {
7272    *FORWARD_FAILURES,
7273    # clone() on "non-contiguous with holes" NJTs currently uses unbind(), leading to
7274    # a data-dependent error in torch.compile
7275    "clone",
7276}
7277
7278COMPARE_TENSOR_COMPONENT_EQUALITY = {
7279    # masked_select is expected to output a different shape
7280    "masked_select",
7281}
7282
7283
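# Applies unittest.expectedFailure only to parametrizations whose op full_name
# appears in the given failure list.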
7284def withXFails(failure_list):
7285    return decorateIf(
7286        unittest.expectedFailure,
7287        lambda params: params["op"].full_name in failure_list,
7288    )
7289
7290
7291# OpInfo-based NJT tests. These tests use an NJT-specific op_db generated from the standard
7292# op_db. Note that certain tradeoffs were made with respect to coverage vs. test runtime:
7293#   * All tests run with dtype=torch.float32 only
7294class TestNestedTensorOpInfo(NestedTensorTestCase):
7295    # TODO: move this
7296    def _gen_grad_outputs(self, out_val):
7297        if isinstance(out_val, (list, tuple)):
7298            return tuple(torch.ones_like(c) for c in out_val)
7299        else:
7300            return (torch.ones_like(out_val),)
7301
7302    @withXFails(FORWARD_FAILURES)
7303    @ops([op for op in njt_op_db if op.supports_njt], allowed_dtypes=(torch.float32,))
7304    def test_forward(self, device, dtype, op):
7305        for sample in op.sample_inputs(device=device, dtype=dtype, requires_grad=False):
7306            # compare to reference, but expect different nested int
7307            out = op.op(sample.input, *sample.args, **sample.kwargs)
7308            out_ref = op.ref(op, sample)
7309            self.assertEqualIgnoringNestedInts(out, out_ref)
7310
7311    @withXFails(BACKWARD_FAILURES)
7312    @ops(
7313        [op for op in njt_op_db if op.supports_njt and op.supports_autograd],
7314        allowed_dtypes=(torch.float32,),
7315    )
7316    def test_backward(self, device, dtype, op):
7317        for sample in op.sample_inputs(device=device, dtype=dtype, requires_grad=True):
7318            # compare to reference, but expect different nested int
7319            out = op.op(sample.input, *sample.args, **sample.kwargs)
7320            out_ref = op.ref(op, sample)
7321            self.assertEqualIgnoringNestedInts(out, out_ref)
7322
7323            inps, _ = tree_flatten((sample.input, sample.args, sample.kwargs))
7324            g_inps = [
7325                inp
7326                for inp in inps
7327                if isinstance(inp, torch.Tensor) and inp.requires_grad
7328            ]
7329            if len(g_inps) > 0:
7330                grads = torch.autograd.grad(
7331                    out, inputs=g_inps, grad_outputs=self._gen_grad_outputs(out)
7332                )
7333
7334                grads_ref = torch.autograd.grad(
7335                    out_ref,
7336                    inputs=g_inps,
7337                    grad_outputs=self._gen_grad_outputs(out_ref),
7338                )
7339
7340                self.assertEqual(grads, grads_ref)
7341
7342    @withXFails(COMPILE_FORWARD_FAILURES)
7343    @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
7344    @ops([op for op in njt_op_db if op.supports_njt], allowed_dtypes=(torch.float32,))
7345    def test_compile_forward(self, device, dtype, op):
7346        for sample in op.sample_inputs(device=device, dtype=dtype, requires_grad=False):
7347            torch.compiler.reset()
7348
7349            op_fn = op.op
7350
7351            def f(*args, **kwargs):
7352                return op_fn(*args, **kwargs)
7353
7354            compiled_f = torch.compile(
7355                f, fullgraph=True, backend="aot_eager_decomp_partition"
7356            )
7357
7358            out_ref = f(sample.input, *sample.args, **sample.kwargs)
7359            out_compile = compiled_f(sample.input, *sample.args, **sample.kwargs)
7360
7361            if op.full_name in COMPARE_TENSOR_COMPONENT_EQUALITY:
7362                self.assertEqualIgnoringNestedInts(out_compile, out_ref)
7363            else:
7364                self.assertEqual(out_compile, out_ref)
7365
7366    @withXFails(BACKWARD_FAILURES)
7367    @ops(
7368        [op for op in njt_op_db if op.supports_njt and op.supports_autograd],
7369        allowed_dtypes=(torch.float32,),
7370    )
7371    @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
7372    def test_compile_backward(self, device, dtype, op):
7373        for sample in op.sample_inputs(device=device, dtype=dtype, requires_grad=True):
7374            torch.compiler.reset()
7375
7376            op_fn = op.op
7377
7378            def f(*args, **kwargs):
7379                return op_fn(*args, **kwargs)
7380
7381            compiled_f = torch.compile(
7382                f, fullgraph=True, backend="aot_eager_decomp_partition"
7383            )
7384
7385            out_ref = f(sample.input, *sample.args, **sample.kwargs)
7386            out_compile = compiled_f(sample.input, *sample.args, **sample.kwargs)
7387
7388            self.assertEqual(out_compile, out_ref)
7389
7390            inps, _ = tree_flatten((sample.input, sample.args, sample.kwargs))
7391            g_inps = [
7392                inp
7393                for inp in inps
7394                if isinstance(inp, torch.Tensor) and inp.requires_grad
7395            ]
7396            if len(g_inps) > 0:
7397                grads_compile = torch.autograd.grad(
7398                    out_compile,
7399                    inputs=g_inps,
7400                    grad_outputs=self._gen_grad_outputs(out_compile),
7401                )
7402
7403                grads_ref = torch.autograd.grad(
7404                    out_ref, inputs=g_inps, grad_outputs=self._gen_grad_outputs(out_ref)
7405                )
7406
7407                self.assertEqual(grads_compile, grads_ref)
7408
7409
7410instantiate_parametrized_tests(TestNestedTensor)
7411instantiate_device_type_tests(TestNestedTensorDeviceType, globals())
7412instantiate_device_type_tests(TestNestedTensorAutograd, globals())
7413instantiate_device_type_tests(TestNestedTensorSubclass, globals())
7414instantiate_device_type_tests(TestNestedTensorOpInfo, globals())
7415
7416if __name__ == "__main__":
7417    run_tests()
7418