# Owner(s): ["module: nestedtensor"]

import io
import itertools
import math
import sys
import unittest
from functools import partial
from typing import Optional, Tuple

import numpy as np

import torch
import torch._dynamo
import torch._dynamo.testing
import torch.nn
import torch.nn.functional as F
from torch.nested._internal.nested_tensor import (
    buffer_from_jagged,
    jagged_from_list,
    nested_view_from_values_offsets,
    NestedTensor,
    ViewNestedFromBuffer,
)
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FUSED_ATTENTION,
    SM70OrLater,
    SM80OrLater,
)
from torch.testing._internal.common_device_type import (
    dtypes,
    dtypesIfCUDA,
    instantiate_device_type_tests,
    onlyCPU,
    onlyCUDA,
    ops,
    PYTORCH_CUDA_MEMCHECK,
    skipCPUIf,
    skipCUDAIf,
    skipCUDAIfRocm,
    skipMeta,
)
from torch.testing._internal.common_dtype import floating_types_and_half
from torch.testing._internal.common_utils import (
    decorateIf,
    freeze_rng_state,
    gradcheck,
    instantiate_parametrized_tests,
    IS_FBCODE,
    IS_WINDOWS,
    markDynamoStrictTest,
    NestedTensorTestCase,
    parametrize,
    run_tests,
    skipIfSlowGradcheckEnv,
    skipIfTorchDynamo,
    subtest,
    TEST_WITH_ROCM,
    xfailIfTorchDynamo,
)
from torch.testing._internal.opinfo.definitions.nested import njt_op_db
from torch.utils._pytree import tree_flatten
from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts


# Tests are ported from pytorch/nestedtensor.
# This makes porting as_nested_tensor easier in the future.


def _iter_constructors():
    # yield as_nested_tensor
    yield torch.nested.nested_tensor


# Returns True if the function recompiles between inputs1 and inputs2 with the
# specified dynamic setting.
def _recompiles_for_inputs(fn, inputs1, inputs2, dynamic=True):
    compile_count = [0]

    def counter(gm, example_inputs):
        compile_count[0] += 1
        return gm

    compiled_f = torch.compile(fn, fullgraph=True, backend=counter, dynamic=dynamic)
    compiled_f(*inputs1)
    compiled_f(*inputs2)
    return compile_count[0] > 1
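# Illustrative usage (a non-executed sketch; the toy function and shapes below
# are arbitrary assumptions, not part of the test suite). With dynamic=False a
# shape change between calls is expected to force a recompile, while
# dynamic=True should reuse the same graph:
#
#   _recompiles_for_inputs(torch.sin, (torch.randn(2, 3),), (torch.randn(4, 3),), dynamic=False)  # -> True
#   _recompiles_for_inputs(torch.sin, (torch.randn(2, 3),), (torch.randn(4, 3),), dynamic=True)   # -> False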


# Helper function to generate a pair of random nested tensors:
# one is contiguous, the other is not, but they appear to have the same entries.
# Each output nested tensor consists of
# * `len(ragged_sizes)` matrices
# * matrices[i].shape == (20, ragged_sizes[i])


def random_nt_noncontiguous_pair(ragged_sizes, device="cpu", dtype=torch.float16):
    xs = []
    for size in ragged_sizes:
        xs.append(torch.randn((size, 20), device=device, dtype=dtype))
    # contiguous nested tensor
    ys = []
    for x in xs:
        ys.append(x.transpose(-1, -2))
    nt_contiguous = torch.nested.nested_tensor(ys)
    # noncontiguous nested tensor
    n = len(ragged_sizes)
    nt_noncontiguous = torch.nested.nested_tensor(xs).transpose(-1, -2)
    return nt_contiguous, nt_noncontiguous
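# Illustrative usage (a non-executed sketch; the ragged sizes are arbitrary):
#
#   nt_c, nt_nc = random_nt_noncontiguous_pair((2, 3, 6, 7))
#   # nt_c is contiguous and nt_nc is not, but both hold components of shape
#   # (20, 2), (20, 3), (20, 6), (20, 7) with matching values.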


# Helper function to pad a noncontiguous nested tensor.
# It can be replaced once to_padded_tensor supports noncontiguous memory.


def noncontiguous_to_padded_tensor(input, shape=None):
    tensors = input.unbind()
    ntensors = len(tensors)
    assert ntensors > 0
    if shape is None:
        shape = []
        for size in tensors[0].shape:
            shape.append(size)
        for i in range(1, ntensors):
            new_shape = tensors[i].shape
            for j in range(len(shape)):
                shape[j] = max(shape[j], new_shape[j])
        shape = [ntensors] + shape
    result = tensors[0].new_zeros(shape)
    for itensor in range(ntensors):
        tensor = tensors[itensor]
        view = result[itensor]
        for idim in range(tensor.dim()):
            view = view.narrow(idim, 0, tensor.size(idim))
        view.copy_(tensor)
    return result
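# Illustrative usage (a non-executed sketch): for a noncontiguous NT this is
# expected to match what to_padded_tensor produces for its contiguous twin.
#
#   nt_c, nt_nc = random_nt_noncontiguous_pair((2, 3, 6, 7))
#   padded = noncontiguous_to_padded_tensor(nt_nc)
#   # padded should equal torch.nested.to_padded_tensor(nt_c, 0.0)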


# Helper function to generate a random nested tensor


def random_nt(
    device,
    dtype,
    num_tensors,
    max_dims,
    min_dims=None,
    layout=torch.strided,
    require_non_empty=True,
):
    if min_dims is None:
        min_dims = tuple([0] * len(max_dims))

    assert len(max_dims) == len(min_dims)
    for min_dim, max_dim in zip(min_dims, max_dims):
        assert max_dim > min_dim, "random_nt: max_dim must be greater than min_dim"
        assert min_dim >= 0, "random_nt: min_dim must be non-negative"
        if require_non_empty:
            assert not (
                min_dim == 0 and max_dim == 1
            ), "random_nt: zero cannot be the only possible value if require_non_empty is True"

    if require_non_empty:
        # Select a random idx that will be required to be non-empty
        non_zero_idx = torch.randint(low=0, high=num_tensors, size=(1,)).item()

    ts1 = []
    for i, _ in enumerate(range(num_tensors)):
        tensor_dims = []
        for min_dim, max_dim in zip(min_dims, max_dims):
            new_min_dim = min_dim
            if require_non_empty and i == non_zero_idx and min_dim == 0:
                new_min_dim = 1
            tensor_dims.append(
                torch.randint(low=new_min_dim, high=max_dim, size=(1,)).item()
            )
        t1 = torch.randn(tensor_dims, device=device, dtype=dtype)
        ts1.append(t1)

    return torch.nested.nested_tensor(ts1, device=device, dtype=dtype, layout=layout)
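# Illustrative usage (a non-executed sketch; device, dtype, and sizes are arbitrary):
#
#   nt = random_nt(torch.device("cpu"), torch.float32, num_tensors=4, max_dims=(4, 4))
#   # nt has 4 components, each of shape (i, j) with 0 <= i, j < 4, and at
#   # least one component is guaranteed to be non-empty.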


# Alternate approach to generating a random NT.
# dims should be something like [5, None, 10], with None indicating that a
# random ragged structure should be used
def random_nt_from_dims(
    dims, device=None, dtype=None, layout=torch.strided, requires_grad=False
):
    sizes = [
        [
            d if d is not None else torch.randint(2, 10, size=(1,)).item()
            for d in dims[1:]
        ]
        for d in range(dims[0])
    ]
    return torch.nested.nested_tensor(
        [torch.randn(*size) for size in sizes],
        device=device,
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
    )
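# Illustrative usage (a non-executed sketch; the dims below are arbitrary):
#
#   nt = random_nt_from_dims([5, None, 10])
#   # nt has 5 components of shape (r_i, 10), where each ragged length r_i is
#   # drawn uniformly from [2, 10).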


# Creates an NT matching another NT's number of components and
# shape / ragged structure for all dims specified to be -1.
def random_nt_from_similar(other, dims=None):
    if dims is None:
        return torch.randn_like(other)
    assert len(dims) == other.dim()
    assert dims[0] == -1 or dims[0] == other.size(0)

    ret_sizes = []
    for t in other.unbind():
        other_size = t.shape
        ret_size = []
        for i, d in enumerate(dims[1:]):
            if d == -1:
                ret_size.append(other_size[i])
            else:
                ret_size.append(d)
        ret_sizes.append(ret_size)

    return torch.nested.nested_tensor(
        [torch.randn(*size) for size in ret_sizes], device=other.device
    )
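# Illustrative usage (a non-executed sketch):
#
#   x = random_nt_from_dims([5, None, 10])
#   y = random_nt_from_similar(x, dims=[-1, -1, 8])
#   # y matches x's number of components and ragged structure in dim 1, but
#   # uses 8 instead of 10 for the last dim.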


# makes naming nice for tests that parametrize over layout.
def layout_name(layout):
    # e.g. "torch.jagged" -> "jagged"
    return layout.__repr__().split(".")[-1]


def get_op_name(layout):
    # e.g. "<OpOverload(op='aten.sum', overload='dim_IntList')>" -> "sum"
    return layout.__name__.split(".")[0].split("_")[-1]


# Helper function for test_dummy_mha_with_nt
@torch.fx.wrap
def convert_dense_to_nested_tensor_legacy(values):
    offsets = torch.arange(
        0, values.shape[0] * values.shape[1] + 1, values.shape[1], device=values.device
    )
    metadata_cache = {"max_seqlen": values.shape[1], "min_seqlen": 1}
    nt = ViewNestedFromBuffer.apply(
        values.view(-1, values.shape[-1]), offsets, metadata_cache
    )
    return nt
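# Illustrative note (an assumption about the intended input, not asserted by
# the helper itself): for a dense `values` of shape (B, S, D), the offsets
# computed above are [0, S, 2*S, ..., B*S], so the resulting jagged NT has B
# components, each of length S.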


# Helper function for test_dummy_mha_with_nt
@torch.fx.wrap
def convert_jagged_to_nested_tensor_legacy(
    values: torch.Tensor, offsets: torch.Tensor, max_length: int
) -> torch.Tensor:
    metadata_cache = {"max_seqlen": max_length, "min_seqlen": 1}
    nt = ViewNestedFromBuffer.apply(values, offsets, metadata_cache)
    return nt


# Helper function for test_dummy_mha_with_nt
@torch.fx.wrap
def convert_nt_to_jagged_legacy(nt):
    return buffer_from_jagged(nt)


# Helper function for test_dummy_mha_with_nt
@torch.fx.wrap
def convert_dense_to_nested_tensor(values):
    nt = torch.nested.as_nested_tensor(values, layout=torch.jagged)
    return nt


# Helper function for test_dummy_mha_with_nt
@torch.fx.wrap
def convert_jagged_to_nested_tensor(
    values: torch.Tensor, offsets: torch.Tensor, max_length: int
) -> torch.Tensor:
    nt = torch.nested.nested_tensor_from_jagged(
        values, offsets, lengths=None, min_seqlen=1, max_seqlen=max_length
    )
    return nt
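# Illustrative note (an assumption about the expected arguments): `values` is
# the packed (total_length, D) buffer and `offsets` holds B + 1 increasing
# indices into it, so component i of the returned NT views
# values[offsets[i]:offsets[i + 1]].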
285*da0073e9SAndroid Build Coastguard Worker
286*da0073e9SAndroid Build Coastguard Worker
287*da0073e9SAndroid Build Coastguard Worker# Helper function for test_dummy_mha_with_nt
288*da0073e9SAndroid Build Coastguard Workerdef convert_nt_to_jagged(nt):
289*da0073e9SAndroid Build Coastguard Worker    return nt.values()
290*da0073e9SAndroid Build Coastguard Worker
291*da0073e9SAndroid Build Coastguard Worker
292*da0073e9SAndroid Build Coastguard Worker@markDynamoStrictTest
293*da0073e9SAndroid Build Coastguard Workerclass TestNestedTensor(NestedTensorTestCase):
294*da0073e9SAndroid Build Coastguard Worker    @parametrize("batch_size", [2, 4])
295*da0073e9SAndroid Build Coastguard Worker    @parametrize("max_seq_len", [3, 5])
296*da0073e9SAndroid Build Coastguard Worker    @parametrize("vocab_size", [10, 20])
297*da0073e9SAndroid Build Coastguard Worker    def test_2d_nested_tensor(self, batch_size, max_seq_len, vocab_size):
298*da0073e9SAndroid Build Coastguard Worker        data = []
299*da0073e9SAndroid Build Coastguard Worker        nested_tensor_ref_list = []
300*da0073e9SAndroid Build Coastguard Worker        for _ in range(batch_size):
301*da0073e9SAndroid Build Coastguard Worker            if max_seq_len == 0:
302*da0073e9SAndroid Build Coastguard Worker                length = 0
303*da0073e9SAndroid Build Coastguard Worker            else:
304*da0073e9SAndroid Build Coastguard Worker                length = np.random.randint(low=1, high=max_seq_len)
305*da0073e9SAndroid Build Coastguard Worker            row = list(np.random.randint(low=0, high=vocab_size, size=(length,)))
306*da0073e9SAndroid Build Coastguard Worker            data.append(row)
307*da0073e9SAndroid Build Coastguard Worker            nested_tensor_ref_list.append(torch.Tensor(row))
308*da0073e9SAndroid Build Coastguard Worker        nested_tensor = torch.nested.nested_tensor(data, dtype=torch.int64)
309*da0073e9SAndroid Build Coastguard Worker        nested_tensor_list = nested_tensor.unbind()
310*da0073e9SAndroid Build Coastguard Worker        for id in range(batch_size):
311*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
312*da0073e9SAndroid Build Coastguard Worker                nested_tensor_list[id], nested_tensor_ref_list[id].type(torch.int64)
313*da0073e9SAndroid Build Coastguard Worker            )
314*da0073e9SAndroid Build Coastguard Worker
315*da0073e9SAndroid Build Coastguard Worker    @parametrize("batch_size", [2, 4])
316*da0073e9SAndroid Build Coastguard Worker    @parametrize("max_seq_len", [3, 5])
317*da0073e9SAndroid Build Coastguard Worker    @parametrize("vocab_size", [10, 20])
318*da0073e9SAndroid Build Coastguard Worker    def test_3d_nested_tensor(self, batch_size, max_seq_len, vocab_size):
319*da0073e9SAndroid Build Coastguard Worker        data = []
320*da0073e9SAndroid Build Coastguard Worker        nested_tensor_ref_list = []
321*da0073e9SAndroid Build Coastguard Worker        for _ in range(batch_size):
322*da0073e9SAndroid Build Coastguard Worker            if max_seq_len == 0:
323*da0073e9SAndroid Build Coastguard Worker                length = 0
324*da0073e9SAndroid Build Coastguard Worker            else:
325*da0073e9SAndroid Build Coastguard Worker                length = np.random.randint(low=1, high=max_seq_len)
326*da0073e9SAndroid Build Coastguard Worker            row = list(np.random.randint(low=0, high=vocab_size, size=(length,)))
327*da0073e9SAndroid Build Coastguard Worker            row = [list(item * np.arange(max_seq_len)) for item in row]
328*da0073e9SAndroid Build Coastguard Worker            data.append(row)
329*da0073e9SAndroid Build Coastguard Worker            nested_tensor_ref_list.append(torch.Tensor(row))
330*da0073e9SAndroid Build Coastguard Worker        nested_tensor = torch.nested.nested_tensor(data, dtype=torch.int64)
331*da0073e9SAndroid Build Coastguard Worker        nested_tensor_list = nested_tensor.unbind()
332*da0073e9SAndroid Build Coastguard Worker        for id in range(batch_size):
333*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
334*da0073e9SAndroid Build Coastguard Worker                nested_tensor_list[id], nested_tensor_ref_list[id].type(torch.int64)
335*da0073e9SAndroid Build Coastguard Worker            )
336*da0073e9SAndroid Build Coastguard Worker
337*da0073e9SAndroid Build Coastguard Worker    @parametrize("batch_size", [2, 4])
338*da0073e9SAndroid Build Coastguard Worker    @parametrize("max_seq_len", [3, 5])
339*da0073e9SAndroid Build Coastguard Worker    @parametrize("vocab_size", [10, 20])
340*da0073e9SAndroid Build Coastguard Worker    def test_3d_nested_tensor_float(self, batch_size, max_seq_len, vocab_size):
341*da0073e9SAndroid Build Coastguard Worker        data = []
342*da0073e9SAndroid Build Coastguard Worker        nested_tensor_ref_list = []
343*da0073e9SAndroid Build Coastguard Worker        for _ in range(batch_size):
344*da0073e9SAndroid Build Coastguard Worker            if max_seq_len == 0:
345*da0073e9SAndroid Build Coastguard Worker                length = 0
346*da0073e9SAndroid Build Coastguard Worker            else:
347*da0073e9SAndroid Build Coastguard Worker                length = np.random.randint(low=1, high=max_seq_len)
348*da0073e9SAndroid Build Coastguard Worker            row = list(
349*da0073e9SAndroid Build Coastguard Worker                np.random.randint(low=0, high=vocab_size, size=(length,)).astype(float)
350*da0073e9SAndroid Build Coastguard Worker            )
351*da0073e9SAndroid Build Coastguard Worker            row = [list(item * np.arange(max_seq_len)) for item in row]
352*da0073e9SAndroid Build Coastguard Worker            data.append(row)
353*da0073e9SAndroid Build Coastguard Worker            nested_tensor_ref_list.append(torch.Tensor(row))
354*da0073e9SAndroid Build Coastguard Worker        nested_tensor = torch.nested.nested_tensor(data, dtype=torch.float)
355*da0073e9SAndroid Build Coastguard Worker        nested_tensor_list = nested_tensor.unbind()
356*da0073e9SAndroid Build Coastguard Worker        for id in range(batch_size):
357*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
358*da0073e9SAndroid Build Coastguard Worker                nested_tensor_list[id], nested_tensor_ref_list[id].type(torch.float)
359*da0073e9SAndroid Build Coastguard Worker            )
360*da0073e9SAndroid Build Coastguard Worker
361*da0073e9SAndroid Build Coastguard Worker    @torch.inference_mode()
362*da0073e9SAndroid Build Coastguard Worker    def _test_unbind_case(self, a, b):
363*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor([a, b])
364*da0073e9SAndroid Build Coastguard Worker        a1, b1 = nt.unbind()
365*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(a is not a1)
366*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(b is not b1)
367*da0073e9SAndroid Build Coastguard Worker
368*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor([a, b], dtype=a.dtype)
369*da0073e9SAndroid Build Coastguard Worker        a1, b1 = nt.unbind(0)
370*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a, a1)
371*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b, b1)
372*da0073e9SAndroid Build Coastguard Worker
373*da0073e9SAndroid Build Coastguard Worker        a = torch.randn((2, 3)).add_(1)
374*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor([a])
375*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a, nt.unbind(0)[0])
376*da0073e9SAndroid Build Coastguard Worker
377*da0073e9SAndroid Build Coastguard Worker    @torch.inference_mode()
378*da0073e9SAndroid Build Coastguard Worker    def test_unbind_0(self):
379*da0073e9SAndroid Build Coastguard Worker        self._test_unbind_case(torch.tensor([1, 2]), torch.tensor([7, 8]))
380*da0073e9SAndroid Build Coastguard Worker
381*da0073e9SAndroid Build Coastguard Worker    @torch.inference_mode()
382*da0073e9SAndroid Build Coastguard Worker    def test_unbind_1(self):
383*da0073e9SAndroid Build Coastguard Worker        self._test_unbind_case(torch.tensor([1]), torch.tensor([7]))
384*da0073e9SAndroid Build Coastguard Worker
385*da0073e9SAndroid Build Coastguard Worker    @torch.inference_mode()
386*da0073e9SAndroid Build Coastguard Worker    def test_unbind_3(self):
387*da0073e9SAndroid Build Coastguard Worker        self._test_unbind_case(torch.tensor([1.0]), torch.tensor([]))
388*da0073e9SAndroid Build Coastguard Worker
389*da0073e9SAndroid Build Coastguard Worker    @torch.inference_mode()
390*da0073e9SAndroid Build Coastguard Worker    def test_unbind_4(self):
391*da0073e9SAndroid Build Coastguard Worker        self._test_unbind_case(torch.tensor([]), torch.tensor([]))
392*da0073e9SAndroid Build Coastguard Worker
393*da0073e9SAndroid Build Coastguard Worker    @torch.inference_mode()
394*da0073e9SAndroid Build Coastguard Worker    def test_unbind_dim(self):
395*da0073e9SAndroid Build Coastguard Worker        def _test_fn(unbind_fn):
396*da0073e9SAndroid Build Coastguard Worker            a = torch.rand(3, 2)
397*da0073e9SAndroid Build Coastguard Worker            b = torch.rand(2, 3)
398*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor([a, b])
399*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError, lambda: unbind_fn(nt, 1))
400*da0073e9SAndroid Build Coastguard Worker
401*da0073e9SAndroid Build Coastguard Worker        # Both of these tests are necessary, because we're using
402*da0073e9SAndroid Build Coastguard Worker        # torch_function.
403*da0073e9SAndroid Build Coastguard Worker        _test_fn(lambda x, dim: x.unbind(dim))
404*da0073e9SAndroid Build Coastguard Worker        # TODO: Re-enable this once using torch_dispatch
405*da0073e9SAndroid Build Coastguard Worker        # _test_fn(lambda x, dim: torch.unbind(x, dim))
406*da0073e9SAndroid Build Coastguard Worker
407*da0073e9SAndroid Build Coastguard Worker    @torch.inference_mode()
408*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor(self):
409*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(
410*da0073e9SAndroid Build Coastguard Worker            TypeError, lambda: torch.nested.nested_tensor(torch.tensor([3.0]))
411*da0073e9SAndroid Build Coastguard Worker        )
412*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(TypeError, lambda: torch.nested.nested_tensor(4.0))
413*da0073e9SAndroid Build Coastguard Worker
414*da0073e9SAndroid Build Coastguard Worker    @torch.inference_mode()
415*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_matching_dim(self):
416*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
417*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
418*da0073e9SAndroid Build Coastguard Worker            "Found dimension 1 for Tensor at index 1 and dimension 0 for Tensor at index 0.",
419*da0073e9SAndroid Build Coastguard Worker            lambda: torch.nested.nested_tensor([torch.tensor(1.0), torch.tensor([])]),
420*da0073e9SAndroid Build Coastguard Worker        )
421*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
422*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
423*da0073e9SAndroid Build Coastguard Worker            "Found dimension 1 for Tensor at index 2 and dimension 0 for Tensor at index 1.",
424*da0073e9SAndroid Build Coastguard Worker            lambda: torch.nested.nested_tensor(
425*da0073e9SAndroid Build Coastguard Worker                [torch.tensor(1.0), torch.tensor(2.0), torch.tensor([])]
426*da0073e9SAndroid Build Coastguard Worker            ),
427*da0073e9SAndroid Build Coastguard Worker        )
428*da0073e9SAndroid Build Coastguard Worker
429*da0073e9SAndroid Build Coastguard Worker    @torch.inference_mode()
430*da0073e9SAndroid Build Coastguard Worker    def test_default_nested_tensor(self):
431*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(TypeError, lambda: torch.nested.nested_tensor())
432*da0073e9SAndroid Build Coastguard Worker        default_nested_tensor = torch.nested.nested_tensor([])
433*da0073e9SAndroid Build Coastguard Worker        default_tensor = torch.tensor([])
434*da0073e9SAndroid Build Coastguard Worker        # self.assertEqual(default_nested_tensor.nested_dim(), 1)
435*da0073e9SAndroid Build Coastguard Worker        # self.assertEqual(default_nested_tensor.nested_size(), ())
436*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(default_nested_tensor.dim(), default_tensor.dim())
437*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(default_nested_tensor.layout, default_tensor.layout)
438*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(default_nested_tensor.device, default_tensor.device)
439*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(default_nested_tensor.dtype, default_tensor.dtype)
440*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
441*da0073e9SAndroid Build Coastguard Worker            default_nested_tensor.requires_grad, default_tensor.requires_grad
442*da0073e9SAndroid Build Coastguard Worker        )
443*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(default_tensor.grad)
444*da0073e9SAndroid Build Coastguard Worker        # TODO: Re-enable once we have a performance driven
445*da0073e9SAndroid Build Coastguard Worker        # use case and implementation.
446*da0073e9SAndroid Build Coastguard Worker        # self.assertEqual(default_nested_tensor.is_pinned(),
447*da0073e9SAndroid Build Coastguard Worker        #                  default_tensor.is_pinned())
448*da0073e9SAndroid Build Coastguard Worker
449*da0073e9SAndroid Build Coastguard Worker    @torch.inference_mode()
450*da0073e9SAndroid Build Coastguard Worker    def test_dim(self):
451*da0073e9SAndroid Build Coastguard Worker        for constructor in _iter_constructors():
452*da0073e9SAndroid Build Coastguard Worker            a1 = constructor([])
453*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a1.dim(), 1)
454*da0073e9SAndroid Build Coastguard Worker            a1 = constructor([torch.tensor(3.0)])
455*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a1.dim(), 1)
456*da0073e9SAndroid Build Coastguard Worker            a1 = constructor([torch.tensor([1, 2, 3, 4])])
457*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a1.dim(), 2)
458*da0073e9SAndroid Build Coastguard Worker
459*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_FBCODE, "numel is not virtual in fbcode.")
460*da0073e9SAndroid Build Coastguard Worker    @torch.inference_mode()
461*da0073e9SAndroid Build Coastguard Worker    def test_numel(self):
462*da0073e9SAndroid Build Coastguard Worker        for constructor in _iter_constructors():
463*da0073e9SAndroid Build Coastguard Worker            a1 = constructor([])
464*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a1.numel(), 0)
465*da0073e9SAndroid Build Coastguard Worker            a1 = constructor([torch.tensor(3.0), torch.tensor(4.0)])
466*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a1.numel(), 2)
467*da0073e9SAndroid Build Coastguard Worker            a1 = constructor([torch.randn(2, 2, 2)])
468*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a1.numel(), 8)
469*da0073e9SAndroid Build Coastguard Worker            a1 = constructor([torch.randn([1, 2, 3]), torch.randn(3, 2, 1)])
470*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a1.numel(), 12)
471*da0073e9SAndroid Build Coastguard Worker            a1 = constructor([torch.randn([1, 1, 3]), torch.randn(3, 2, 4)])
472*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a1.numel(), 27)
473*da0073e9SAndroid Build Coastguard Worker            a1 = constructor([torch.randn([5, 5, 5]), torch.randn(6, 6, 6)])
474*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a1.numel(), 341)
475*da0073e9SAndroid Build Coastguard Worker
476*da0073e9SAndroid Build Coastguard Worker            # Interesting edge case
477*da0073e9SAndroid Build Coastguard Worker            a1 = constructor([torch.randn([1, 2, 3]), torch.randn(1, 2, 0)])
478*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a1.numel(), 6)
479*da0073e9SAndroid Build Coastguard Worker
480*da0073e9SAndroid Build Coastguard Worker    @torch.inference_mode()
481*da0073e9SAndroid Build Coastguard Worker    def test_size(self):
482*da0073e9SAndroid Build Coastguard Worker        for constructor in _iter_constructors():
483*da0073e9SAndroid Build Coastguard Worker            a1 = constructor([])
484*da0073e9SAndroid Build Coastguard Worker            self.assertRaisesRegex(
485*da0073e9SAndroid Build Coastguard Worker                RuntimeError,
486*da0073e9SAndroid Build Coastguard Worker                "NestedTensorImpl doesn't support sizes",
487*da0073e9SAndroid Build Coastguard Worker                lambda: a1.size(),
488*da0073e9SAndroid Build Coastguard Worker            )
489*da0073e9SAndroid Build Coastguard Worker
490*da0073e9SAndroid Build Coastguard Worker    def test_size_dim(self):
491*da0073e9SAndroid Build Coastguard Worker        a = torch.nested.nested_tensor([])
492*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.size(0), 0)
493*da0073e9SAndroid Build Coastguard Worker
494*da0073e9SAndroid Build Coastguard Worker        a = torch.nested.nested_tensor([torch.tensor(1)])
495*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.size(0), 1)
496*da0073e9SAndroid Build Coastguard Worker
497*da0073e9SAndroid Build Coastguard Worker        a = torch.nested.nested_tensor([torch.tensor(1), torch.tensor(2)])
498*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.size(0), 2)
499*da0073e9SAndroid Build Coastguard Worker
500*da0073e9SAndroid Build Coastguard Worker        a = torch.nested.nested_tensor([torch.rand(1, 2), torch.rand(1, 8)])
501*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.size(0), 2)
502*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.size(1), 1)
503*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
504*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
505*da0073e9SAndroid Build Coastguard Worker            "Given dimension 2 is irregular and does not have a size",
506*da0073e9SAndroid Build Coastguard Worker            lambda: a.size(2),
507*da0073e9SAndroid Build Coastguard Worker        )
508*da0073e9SAndroid Build Coastguard Worker
509*da0073e9SAndroid Build Coastguard Worker        a = torch.nested.nested_tensor([torch.rand(3, 4), torch.rand(5, 4)])
510*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.size(0), 2)
511*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
512*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
513*da0073e9SAndroid Build Coastguard Worker            "Given dimension 1 is irregular and does not have a size",
514*da0073e9SAndroid Build Coastguard Worker            lambda: a.size(1),
515*da0073e9SAndroid Build Coastguard Worker        )
516*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.size(2), 4)
517*da0073e9SAndroid Build Coastguard Worker
518*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_FBCODE, "stride is not virtual in fbcode.")
519*da0073e9SAndroid Build Coastguard Worker    @torch.inference_mode()
520*da0073e9SAndroid Build Coastguard Worker    def test_stride(self):
521*da0073e9SAndroid Build Coastguard Worker        for constructor in _iter_constructors():
522*da0073e9SAndroid Build Coastguard Worker            a1 = constructor([])
523*da0073e9SAndroid Build Coastguard Worker            self.assertRaisesRegex(
524*da0073e9SAndroid Build Coastguard Worker                RuntimeError,
525*da0073e9SAndroid Build Coastguard Worker                "NestedTensorImpl doesn't support strides",
526*da0073e9SAndroid Build Coastguard Worker                lambda: a1.stride(),
527*da0073e9SAndroid Build Coastguard Worker            )
528*da0073e9SAndroid Build Coastguard Worker
529*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_FBCODE, "is_contiguous is not virtual in fbcode.")
530*da0073e9SAndroid Build Coastguard Worker    @torch.inference_mode()
531*da0073e9SAndroid Build Coastguard Worker    def test_is_contiguous(self):
532*da0073e9SAndroid Build Coastguard Worker        # Test empty case
533*da0073e9SAndroid Build Coastguard Worker        nt_empty = torch.nested.nested_tensor([])
534*da0073e9SAndroid Build Coastguard Worker        assert nt_empty.is_contiguous()
535*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt_empty, nt_empty.contiguous())
536*da0073e9SAndroid Build Coastguard Worker
537*da0073e9SAndroid Build Coastguard Worker        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7))
538*da0073e9SAndroid Build Coastguard Worker
539*da0073e9SAndroid Build Coastguard Worker        # Test contiguous case
540*da0073e9SAndroid Build Coastguard Worker        assert nt_contiguous.is_contiguous()
541*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt_contiguous, nt_contiguous.contiguous())
542*da0073e9SAndroid Build Coastguard Worker
543*da0073e9SAndroid Build Coastguard Worker        # Test non_contiguous case
544*da0073e9SAndroid Build Coastguard Worker        assert not nt_noncontiguous.is_contiguous()
545*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt_contiguous, nt_noncontiguous.contiguous())
546*da0073e9SAndroid Build Coastguard Worker
547*da0073e9SAndroid Build Coastguard Worker        # Test querying by memory_format
548*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
549*da0073e9SAndroid Build Coastguard Worker            nt_contiguous.is_contiguous(memory_format=torch.contiguous_format)
550*da0073e9SAndroid Build Coastguard Worker        )
551*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
552*da0073e9SAndroid Build Coastguard Worker            not nt_noncontiguous.is_contiguous(memory_format=torch.contiguous_format)
553*da0073e9SAndroid Build Coastguard Worker        )
554*da0073e9SAndroid Build Coastguard Worker
555*da0073e9SAndroid Build Coastguard Worker    @torch.inference_mode()
556*da0073e9SAndroid Build Coastguard Worker    def test_repr_string(self):
557*da0073e9SAndroid Build Coastguard Worker        a = torch.nested.nested_tensor([])
558*da0073e9SAndroid Build Coastguard Worker        expected = "nested_tensor([\n\n])"
559*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(str(a), expected)
560*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(repr(a), expected)
561*da0073e9SAndroid Build Coastguard Worker
562*da0073e9SAndroid Build Coastguard Worker        a = torch.nested.nested_tensor([torch.tensor(1.0)])
563*da0073e9SAndroid Build Coastguard Worker        expected = "nested_tensor([\n  tensor(1.)\n])"
564*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(str(a), expected)
565*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(repr(a), expected)
566*da0073e9SAndroid Build Coastguard Worker
567*da0073e9SAndroid Build Coastguard Worker        a = torch.nested.nested_tensor([torch.tensor([[1, 2]]), torch.tensor([[4, 5]])])
568*da0073e9SAndroid Build Coastguard Worker        expected = "nested_tensor([\n  tensor([[1, 2]]),\n  tensor([[4, 5]])\n])"
569*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(str(a), expected)
570*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(repr(a), expected)
571*da0073e9SAndroid Build Coastguard Worker
572*da0073e9SAndroid Build Coastguard Worker    def test_to_padded_tensor_on_empty_tensor(self):
573*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor([])
574*da0073e9SAndroid Build Coastguard Worker        empty = torch.nested.to_padded_tensor(nt, 4)
575*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(empty, torch.tensor([]))
576*da0073e9SAndroid Build Coastguard Worker
577*da0073e9SAndroid Build Coastguard Worker    def test_nested_namespace(self):
578*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor([torch.randn(2, 3), torch.randn(4, 5)])
579*da0073e9SAndroid Build Coastguard Worker        result = nt.to_padded_tensor(4)
580*da0073e9SAndroid Build Coastguard Worker        nested_namespace_result = torch.nested.to_padded_tensor(nt, 4)
581*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, nested_namespace_result)
582*da0073e9SAndroid Build Coastguard Worker
583*da0073e9SAndroid Build Coastguard Worker    def test_to(self):
584*da0073e9SAndroid Build Coastguard Worker        ntensors = 4
585*da0073e9SAndroid Build Coastguard Worker        nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4))
586*da0073e9SAndroid Build Coastguard Worker
587*da0073e9SAndroid Build Coastguard Worker        def test_copy_behavior(t, non_blocking=False):
588*da0073e9SAndroid Build Coastguard Worker            self.assertIs(t, t.to(t, non_blocking=non_blocking))
589*da0073e9SAndroid Build Coastguard Worker            self.assertIs(t, t.to(t.dtype, non_blocking=non_blocking))
590*da0073e9SAndroid Build Coastguard Worker            self.assertIs(t, t.to(torch.empty_like(t), non_blocking=non_blocking))
591*da0073e9SAndroid Build Coastguard Worker            self.assertIsNot(t, t.to(t, non_blocking=non_blocking, copy=True))
592*da0073e9SAndroid Build Coastguard Worker            self.assertIsNot(t, t.to(t.dtype, non_blocking=non_blocking, copy=True))
593*da0073e9SAndroid Build Coastguard Worker            self.assertIsNot(
594*da0073e9SAndroid Build Coastguard Worker                t, t.to(torch.empty_like(t), non_blocking=non_blocking, copy=True)
595*da0073e9SAndroid Build Coastguard Worker            )
596*da0073e9SAndroid Build Coastguard Worker
597*da0073e9SAndroid Build Coastguard Worker            devices = [t.device]
598*da0073e9SAndroid Build Coastguard Worker            if t.device.type == "cuda":
599*da0073e9SAndroid Build Coastguard Worker                if t.device.index == -1:
600*da0073e9SAndroid Build Coastguard Worker                    devices.append(f"cuda:{torch.cuda.current_device()}")
601*da0073e9SAndroid Build Coastguard Worker                elif t.device.index == torch.cuda.current_device():
602*da0073e9SAndroid Build Coastguard Worker                    devices.append("cuda")
603*da0073e9SAndroid Build Coastguard Worker            for device in devices:
604*da0073e9SAndroid Build Coastguard Worker                self.assertIs(t, t.to(device, non_blocking=non_blocking))
605*da0073e9SAndroid Build Coastguard Worker                self.assertIs(t, t.to(device, t.dtype, non_blocking=non_blocking))
606*da0073e9SAndroid Build Coastguard Worker                self.assertIsNot(t, t.to(device, non_blocking=non_blocking, copy=True))
607*da0073e9SAndroid Build Coastguard Worker                self.assertIsNot(
608*da0073e9SAndroid Build Coastguard Worker                    t, t.to(device, t.dtype, non_blocking=non_blocking, copy=True)
609*da0073e9SAndroid Build Coastguard Worker                )
610*da0073e9SAndroid Build Coastguard Worker
611*da0073e9SAndroid Build Coastguard Worker        test_copy_behavior(nt)
612*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt.device, nt.to("cpu").device)
613*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt.device, nt.to("cpu", dtype=torch.float32).device)
614*da0073e9SAndroid Build Coastguard Worker        self.assertIs(torch.float32, nt.to("cpu", dtype=torch.float32).dtype)
615*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt.device, nt.to(torch.float32).device)
616*da0073e9SAndroid Build Coastguard Worker        self.assertIs(torch.float32, nt.to(dtype=torch.float32).dtype)
617*da0073e9SAndroid Build Coastguard Worker
618*da0073e9SAndroid Build Coastguard Worker        def test_data_ptr(getter):
619*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(getter(nt), getter(nt.to("cpu")))
620*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
621*da0073e9SAndroid Build Coastguard Worker                getter(nt), getter(nt.to(dtype=nt.dtype, device=nt.device, copy=False))
622*da0073e9SAndroid Build Coastguard Worker            )
623*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(getter(nt), getter(nt.to("cpu", copy=False)))
624*da0073e9SAndroid Build Coastguard Worker            self.assertNotEqual(getter(nt), getter(nt.to("cpu", copy=True)))
625*da0073e9SAndroid Build Coastguard Worker
626*da0073e9SAndroid Build Coastguard Worker        test_data_ptr(lambda nt: nt.data_ptr())
627*da0073e9SAndroid Build Coastguard Worker
628*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
629*da0073e9SAndroid Build Coastguard Worker            for non_blocking in [True, False]:
630*da0073e9SAndroid Build Coastguard Worker                for cuda in [
631*da0073e9SAndroid Build Coastguard Worker                    "cuda",
632*da0073e9SAndroid Build Coastguard Worker                    "cuda:0" if torch.cuda.device_count() == 1 else "cuda:1",
633*da0073e9SAndroid Build Coastguard Worker                ]:
634*da0073e9SAndroid Build Coastguard Worker                    nt2 = random_nt(cuda, torch.float32, ntensors, (4, 4))
635*da0073e9SAndroid Build Coastguard Worker                    test_copy_behavior(nt2, non_blocking)
636*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(
637*da0073e9SAndroid Build Coastguard Worker                        nt2.device, nt2.to(cuda, non_blocking=non_blocking).device
638*da0073e9SAndroid Build Coastguard Worker                    )
639*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(
640*da0073e9SAndroid Build Coastguard Worker                        nt.device, nt2.to("cpu", non_blocking=non_blocking).device
641*da0073e9SAndroid Build Coastguard Worker                    )
642*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(
643*da0073e9SAndroid Build Coastguard Worker                        nt2.device, nt.to(cuda, non_blocking=non_blocking).device
644*da0073e9SAndroid Build Coastguard Worker                    )
645*da0073e9SAndroid Build Coastguard Worker                    self.assertIs(
646*da0073e9SAndroid Build Coastguard Worker                        torch.int32,
647*da0073e9SAndroid Build Coastguard Worker                        nt2.to(
648*da0073e9SAndroid Build Coastguard Worker                            "cpu", dtype=torch.int32, non_blocking=non_blocking
649*da0073e9SAndroid Build Coastguard Worker                        ).dtype,
650*da0073e9SAndroid Build Coastguard Worker                    )
651*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(
652*da0073e9SAndroid Build Coastguard Worker                        nt.device,
653*da0073e9SAndroid Build Coastguard Worker                        nt2.to(
654*da0073e9SAndroid Build Coastguard Worker                            "cpu", dtype=torch.int32, non_blocking=non_blocking
655*da0073e9SAndroid Build Coastguard Worker                        ).device,
656*da0073e9SAndroid Build Coastguard Worker                    )
657*da0073e9SAndroid Build Coastguard Worker                    self.assertIs(torch.int32, nt2.to(dtype=torch.int32).dtype)
658*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(nt2.device, nt2.to(dtype=torch.int32).device)
659*da0073e9SAndroid Build Coastguard Worker
660*da0073e9SAndroid Build Coastguard Worker    def test_copy_(self):
661*da0073e9SAndroid Build Coastguard Worker        ntensors = 4
662*da0073e9SAndroid Build Coastguard Worker        nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4))
663*da0073e9SAndroid Build Coastguard Worker        nt_copy = torch.empty_like(nt)
664*da0073e9SAndroid Build Coastguard Worker        nt_copy.copy_(nt)
665*da0073e9SAndroid Build Coastguard Worker
666*da0073e9SAndroid Build Coastguard Worker        for nt_ub, nt_copy_ub in zip(nt.unbind(), nt_copy):
667*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt_ub, nt_copy_ub)
668*da0073e9SAndroid Build Coastguard Worker
669*da0073e9SAndroid Build Coastguard Worker        nt_error = torch.nested.nested_tensor([torch.tensor([0, 0])])
670*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
671*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
672*da0073e9SAndroid Build Coastguard Worker            "copy_ only supports tensors that are the same size for Nested implementations",
673*da0073e9SAndroid Build Coastguard Worker            lambda: nt_error.copy_(nt),
674*da0073e9SAndroid Build Coastguard Worker        )
675*da0073e9SAndroid Build Coastguard Worker
676*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
677*da0073e9SAndroid Build Coastguard Worker            nt = random_nt(torch.device("cuda"), torch.float32, ntensors, (4, 4))
678*da0073e9SAndroid Build Coastguard Worker            nt_copy = torch.empty_like(nt, device=torch.device("cpu"))
679*da0073e9SAndroid Build Coastguard Worker            nt_copy.copy_(nt, non_blocking=True)
680*da0073e9SAndroid Build Coastguard Worker            torch.cuda.current_stream(torch.cuda.current_device()).synchronize()
681*da0073e9SAndroid Build Coastguard Worker            for nt_ub, nt_copy_ub in zip(nt.unbind(), nt_copy):
682*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(nt_ub, nt_copy_ub)
683*da0073e9SAndroid Build Coastguard Worker
684*da0073e9SAndroid Build Coastguard Worker            nt_copy = torch.empty_like(nt, device=torch.device("cpu"))
685*da0073e9SAndroid Build Coastguard Worker            nt_copy.copy_(nt, non_blocking=False)
686*da0073e9SAndroid Build Coastguard Worker            for nt_ub, nt_copy_ub in zip(nt.unbind(), nt_copy):
687*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(nt_ub, nt_copy_ub)
688*da0073e9SAndroid Build Coastguard Worker
689*da0073e9SAndroid Build Coastguard Worker    def test_fill_(self):
690*da0073e9SAndroid Build Coastguard Worker        ntensors = 4
691*da0073e9SAndroid Build Coastguard Worker        nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4))
692*da0073e9SAndroid Build Coastguard Worker        nt.fill_(10.0)
693*da0073e9SAndroid Build Coastguard Worker        for nt_ub in nt.unbind():
694*da0073e9SAndroid Build Coastguard Worker            t = torch.empty_like(nt_ub)
695*da0073e9SAndroid Build Coastguard Worker            t.fill_(10.0)
696*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt_ub, t)
697*da0073e9SAndroid Build Coastguard Worker
698*da0073e9SAndroid Build Coastguard Worker        fill_tensor = torch.tensor([11.0])
699*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
700*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
701*da0073e9SAndroid Build Coastguard Worker            "fill_ only supports 0-dimension value tensor",
702*da0073e9SAndroid Build Coastguard Worker            lambda: nt.fill_(fill_tensor),
703*da0073e9SAndroid Build Coastguard Worker        )
704*da0073e9SAndroid Build Coastguard Worker
705*da0073e9SAndroid Build Coastguard Worker        nt.fill_(fill_tensor[0])
706*da0073e9SAndroid Build Coastguard Worker        for nt_ub in nt.unbind():
707*da0073e9SAndroid Build Coastguard Worker            t = torch.empty_like(nt_ub)
708*da0073e9SAndroid Build Coastguard Worker            t.fill_(11.0)
709*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt_ub, t)
710*da0073e9SAndroid Build Coastguard Worker
711*da0073e9SAndroid Build Coastguard Worker    def test_zero_(self):
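        # zero_ should behave like fill_(0.0) on every component.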
712*da0073e9SAndroid Build Coastguard Worker        ntensors = 4
713*da0073e9SAndroid Build Coastguard Worker        nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4))
714*da0073e9SAndroid Build Coastguard Worker        nt.zero_()
715*da0073e9SAndroid Build Coastguard Worker        for nt_ub in nt.unbind():
716*da0073e9SAndroid Build Coastguard Worker            t = torch.empty_like(nt_ub)
717*da0073e9SAndroid Build Coastguard Worker            t.fill_(0.0)
718*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt_ub, t)
719*da0073e9SAndroid Build Coastguard Worker
720*da0073e9SAndroid Build Coastguard Worker    @parametrize(
721*da0073e9SAndroid Build Coastguard Worker        "func",
722*da0073e9SAndroid Build Coastguard Worker        [torch.ones_like, torch.zeros_like, torch.randn_like],
723*da0073e9SAndroid Build Coastguard Worker        name_fn=lambda f: f.__name__,
724*da0073e9SAndroid Build Coastguard Worker    )
725*da0073e9SAndroid Build Coastguard Worker    def test_like_functions(self, func):
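        # *_like factory functions should preserve the nested structure.
        # Reseeding the RNG before the per-component calls keeps randn_like
        # comparable component by component.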
726*da0073e9SAndroid Build Coastguard Worker        ntensors = 4
727*da0073e9SAndroid Build Coastguard Worker        nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4))
728*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(1)
729*da0073e9SAndroid Build Coastguard Worker        nt_like = func(nt)
730*da0073e9SAndroid Build Coastguard Worker
731*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(1)
732*da0073e9SAndroid Build Coastguard Worker        for nt_ub in nt_like.unbind():
733*da0073e9SAndroid Build Coastguard Worker            t_like = func(nt_ub)
734*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt_ub, t_like)
735*da0073e9SAndroid Build Coastguard Worker
736*da0073e9SAndroid Build Coastguard Worker    def test_cat(self):
737*da0073e9SAndroid Build Coastguard Worker        # dim=0 success case
738*da0073e9SAndroid Build Coastguard Worker        # The ragged structures of the inputs don't need to match.
739*da0073e9SAndroid Build Coastguard Worker        x = random_nt_from_dims([5, None, 10])
740*da0073e9SAndroid Build Coastguard Worker        y = random_nt_from_dims([3, 4, None])
741*da0073e9SAndroid Build Coastguard Worker        output = torch.cat([x, y], dim=0)
742*da0073e9SAndroid Build Coastguard Worker        for out_component, xy_component in zip(
743*da0073e9SAndroid Build Coastguard Worker            output.unbind(), itertools.chain(x.unbind(), y.unbind())
744*da0073e9SAndroid Build Coastguard Worker        ):
745*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out_component, xy_component)
746*da0073e9SAndroid Build Coastguard Worker
747*da0073e9SAndroid Build Coastguard Worker        # dim=-1 success case
748*da0073e9SAndroid Build Coastguard Worker        # shape (B, *, D)
749*da0073e9SAndroid Build Coastguard Worker        x = random_nt_from_dims([5, None, 10])
750*da0073e9SAndroid Build Coastguard Worker        # shape (B, *, D'); same structure as x but dim=-1 differs
751*da0073e9SAndroid Build Coastguard Worker        y = random_nt_from_similar(x, dims=[-1, -1, 8])
752*da0073e9SAndroid Build Coastguard Worker        # should be shape (B, *, D + D') when supported
753*da0073e9SAndroid Build Coastguard Worker        output = torch.cat([x, y], dim=-1)
754*da0073e9SAndroid Build Coastguard Worker        for out_component, x_component, y_component in zip(
755*da0073e9SAndroid Build Coastguard Worker            output.unbind(), x.unbind(), y.unbind()
756*da0073e9SAndroid Build Coastguard Worker        ):
757*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
758*da0073e9SAndroid Build Coastguard Worker                out_component, torch.cat([x_component, y_component], dim=-1)
759*da0073e9SAndroid Build Coastguard Worker            )
760*da0073e9SAndroid Build Coastguard Worker
761*da0073e9SAndroid Build Coastguard Worker        # dim between 0 and -1 success case
762*da0073e9SAndroid Build Coastguard Worker        x = random_nt_from_dims([5, None, 2, 3])
763*da0073e9SAndroid Build Coastguard Worker        # same structure as x but dim=2 differs
764*da0073e9SAndroid Build Coastguard Worker        y = random_nt_from_similar(x, dims=[-1, -1, 4, -1])
765*da0073e9SAndroid Build Coastguard Worker        output = torch.cat([x, y], dim=2)
766*da0073e9SAndroid Build Coastguard Worker        for out_component, x_component, y_component in zip(
767*da0073e9SAndroid Build Coastguard Worker            output.unbind(), x.unbind(), y.unbind()
768*da0073e9SAndroid Build Coastguard Worker        ):
769*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
770*da0073e9SAndroid Build Coastguard Worker                out_component, torch.cat([x_component, y_component], dim=1)
771*da0073e9SAndroid Build Coastguard Worker            )
772*da0073e9SAndroid Build Coastguard Worker
773*da0073e9SAndroid Build Coastguard Worker        # error case: mixed NT / dense inputs
774*da0073e9SAndroid Build Coastguard Worker        x = random_nt_from_dims([5, None, 2])
775*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(5, 3, 2)
776*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
777*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "expected each tensor in given list to be nested"
778*da0073e9SAndroid Build Coastguard Worker        ):
779*da0073e9SAndroid Build Coastguard Worker            torch.cat([x, y], dim=-1)
780*da0073e9SAndroid Build Coastguard Worker
781*da0073e9SAndroid Build Coastguard Worker        # error case: NTs with different dims
782*da0073e9SAndroid Build Coastguard Worker        x = random_nt_from_dims([5, None, 2])
783*da0073e9SAndroid Build Coastguard Worker        y = random_nt_from_dims([5, None, 2, 3])
784*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
785*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
786*da0073e9SAndroid Build Coastguard Worker            "expected all nested tensors to have matching ragged structures outside of the concatenated dim",
787*da0073e9SAndroid Build Coastguard Worker        ):
788*da0073e9SAndroid Build Coastguard Worker            torch.cat([x, y], dim=-1)
789*da0073e9SAndroid Build Coastguard Worker
790*da0073e9SAndroid Build Coastguard Worker        # error case: non-contiguous NT
791*da0073e9SAndroid Build Coastguard Worker        x, y = random_nt_noncontiguous_pair((2, 3, 4), dtype=torch.float32)
792*da0073e9SAndroid Build Coastguard Worker        # transpose to put ragged dim next to batch dim
793*da0073e9SAndroid Build Coastguard Worker        x, y = x.transpose(-2, -1), y.transpose(-2, -1)
794*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
795*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "only contiguous nested tensors are supported"
796*da0073e9SAndroid Build Coastguard Worker        ):
797*da0073e9SAndroid Build Coastguard Worker            torch.cat([x, y], dim=-1)
798*da0073e9SAndroid Build Coastguard Worker
799*da0073e9SAndroid Build Coastguard Worker        # error case: multiple ragged dims in inputs
800*da0073e9SAndroid Build Coastguard Worker        x = random_nt_from_dims([5, None, None, 2])
801*da0073e9SAndroid Build Coastguard Worker        y = random_nt_from_similar(x)
802*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
803*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
804*da0073e9SAndroid Build Coastguard Worker            "only nested tensors with a single ragged dim next to the batch dim are supported",
805*da0073e9SAndroid Build Coastguard Worker        ):
806*da0073e9SAndroid Build Coastguard Worker            torch.cat([x, y], dim=-1)
807*da0073e9SAndroid Build Coastguard Worker
808*da0073e9SAndroid Build Coastguard Worker        # error case: ragged dim not next to batch dim
809*da0073e9SAndroid Build Coastguard Worker        x = random_nt_from_dims([5, 2, None])
810*da0073e9SAndroid Build Coastguard Worker        y = random_nt_from_similar(x)
811*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
812*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
813*da0073e9SAndroid Build Coastguard Worker            "only nested tensors with a single ragged dim next to the batch dim are supported",
814*da0073e9SAndroid Build Coastguard Worker        ):
815*da0073e9SAndroid Build Coastguard Worker            torch.cat([x, y], dim=1)
816*da0073e9SAndroid Build Coastguard Worker
817*da0073e9SAndroid Build Coastguard Worker        # error case: NTs with different batch sizes
818*da0073e9SAndroid Build Coastguard Worker        x = random_nt_from_dims([5, None, 2])
819*da0073e9SAndroid Build Coastguard Worker        y = random_nt_from_dims([3, None, 2])
820*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
821*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
822*da0073e9SAndroid Build Coastguard Worker            "expected all nested tensors to have matching ragged structures outside of the concatenated dim",
823*da0073e9SAndroid Build Coastguard Worker        ):
824*da0073e9SAndroid Build Coastguard Worker            torch.cat([x, y], dim=-1)
825*da0073e9SAndroid Build Coastguard Worker
826*da0073e9SAndroid Build Coastguard Worker        # error case: NTs with different ragged structures
827*da0073e9SAndroid Build Coastguard Worker        x = torch.nested.nested_tensor(
828*da0073e9SAndroid Build Coastguard Worker            [
829*da0073e9SAndroid Build Coastguard Worker                torch.randn(2, 6),
830*da0073e9SAndroid Build Coastguard Worker                torch.randn(4, 6),
831*da0073e9SAndroid Build Coastguard Worker                torch.randn(5, 6),
832*da0073e9SAndroid Build Coastguard Worker            ]
833*da0073e9SAndroid Build Coastguard Worker        )
834*da0073e9SAndroid Build Coastguard Worker        y = torch.nested.nested_tensor(
835*da0073e9SAndroid Build Coastguard Worker            [
836*da0073e9SAndroid Build Coastguard Worker                torch.randn(5, 6),
837*da0073e9SAndroid Build Coastguard Worker                torch.randn(4, 6),
838*da0073e9SAndroid Build Coastguard Worker                torch.randn(2, 6),
839*da0073e9SAndroid Build Coastguard Worker            ]
840*da0073e9SAndroid Build Coastguard Worker        )
841*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
842*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
843*da0073e9SAndroid Build Coastguard Worker            "expected all nested tensors to have matching ragged structures outside of the concatenated dim",
844*da0073e9SAndroid Build Coastguard Worker        ):
845*da0073e9SAndroid Build Coastguard Worker            torch.cat([x, y], dim=-1)
846*da0073e9SAndroid Build Coastguard Worker
847*da0073e9SAndroid Build Coastguard Worker
848*da0073e9SAndroid Build Coastguard Worker@markDynamoStrictTest
849*da0073e9SAndroid Build Coastguard Workerclass TestNestedTensorDeviceType(NestedTensorTestCase):
850*da0073e9SAndroid Build Coastguard Worker    # Helper function to generate a pair of random nested tensors
851*da0073e9SAndroid Build Coastguard Worker    # the two nested tensors have the same shapes
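    # Illustrative usage (hypothetical values, not executed here):
    #   nt1, nt2 = self.random_nt_pair(device, dtype, 3, (4, 4))
    # Each component's shape has dims drawn independently from [0, max_dim).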
852*da0073e9SAndroid Build Coastguard Worker    def random_nt_pair(self, device, dtype, num_tensors, max_dims):
853*da0073e9SAndroid Build Coastguard Worker        ts1 = []
854*da0073e9SAndroid Build Coastguard Worker        ts2 = []
855*da0073e9SAndroid Build Coastguard Worker        for _ in range(num_tensors):
856*da0073e9SAndroid Build Coastguard Worker            tensor_dims = tuple(
857*da0073e9SAndroid Build Coastguard Worker                [
858*da0073e9SAndroid Build Coastguard Worker                    torch.randint(low=0, high=max_dim, size=(1,)).item()
859*da0073e9SAndroid Build Coastguard Worker                    for max_dim in max_dims
860*da0073e9SAndroid Build Coastguard Worker                ]
861*da0073e9SAndroid Build Coastguard Worker            )
862*da0073e9SAndroid Build Coastguard Worker            t1 = torch.randn(tensor_dims, device=device, dtype=dtype)
863*da0073e9SAndroid Build Coastguard Worker            t2 = torch.randn(tensor_dims, device=device, dtype=dtype)
864*da0073e9SAndroid Build Coastguard Worker            ts1.append(t1)
865*da0073e9SAndroid Build Coastguard Worker            ts2.append(t2)
866*da0073e9SAndroid Build Coastguard Worker        return (
867*da0073e9SAndroid Build Coastguard Worker            torch.nested.nested_tensor(ts1, device=device, dtype=dtype),
868*da0073e9SAndroid Build Coastguard Worker            torch.nested.nested_tensor(ts2, device=device, dtype=dtype),
869*da0073e9SAndroid Build Coastguard Worker        )
870*da0073e9SAndroid Build Coastguard Worker
871*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_types_and_half())
872*da0073e9SAndroid Build Coastguard Worker    def test_detach(self, device, dtype):
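        # detach() on a nested tensor should return a nested tensor that no
        # longer requires grad; operating on the detached result must not
        # create a grad_fn.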
873*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 4, device=device, dtype=dtype, requires_grad=False)
874*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(5, 4, device=device, dtype=dtype, requires_grad=False)
875*da0073e9SAndroid Build Coastguard Worker        x = torch.nested.nested_tensor([a, b], requires_grad=True)
876*da0073e9SAndroid Build Coastguard Worker
877*da0073e9SAndroid Build Coastguard Worker        x_detach = x.detach()
878*da0073e9SAndroid Build Coastguard Worker
879*da0073e9SAndroid Build Coastguard Worker        z = x_detach * 4
880*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(x_detach.requires_grad)
881*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(z.requires_grad)
882*da0073e9SAndroid Build Coastguard Worker
883*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 4, device=device, dtype=dtype, requires_grad=True)
884*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(5, 4, device=device, dtype=dtype, requires_grad=True)
885*da0073e9SAndroid Build Coastguard Worker        x = torch.nested.as_nested_tensor([a, b])
886*da0073e9SAndroid Build Coastguard Worker
887*da0073e9SAndroid Build Coastguard Worker        y = x * 2
888*da0073e9SAndroid Build Coastguard Worker        y = y.detach()
889*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(y.requires_grad)
890*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(y.grad_fn)
891*da0073e9SAndroid Build Coastguard Worker
892*da0073e9SAndroid Build Coastguard Worker        z = x + y
893*da0073e9SAndroid Build Coastguard Worker        torch.nested.to_padded_tensor(z, 0).sum().backward()
894*da0073e9SAndroid Build Coastguard Worker        # This is an incorrect gradient, but we assume that's what the user
895*da0073e9SAndroid Build Coastguard Worker        # wanted. detach() is an advanced option.
896*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.grad, torch.ones(2, 4, device=device, dtype=dtype))
897*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b.grad, torch.ones(5, 4, device=device, dtype=dtype))
898*da0073e9SAndroid Build Coastguard Worker
899*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.float16, torch.double)
900*da0073e9SAndroid Build Coastguard Worker    def test_unbind_noncontiguous(self, device, dtype):
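        # unbind() should yield the same components for a contiguous nested
        # tensor and its non-contiguous counterpart holding the same values.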
901*da0073e9SAndroid Build Coastguard Worker        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
902*da0073e9SAndroid Build Coastguard Worker            (2, 3, 6, 7), device, dtype
903*da0073e9SAndroid Build Coastguard Worker        )
904*da0073e9SAndroid Build Coastguard Worker        ub_contiguous = nt_contiguous.unbind()
905*da0073e9SAndroid Build Coastguard Worker        ub_noncontiguous = nt_noncontiguous.unbind()
906*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(ub_contiguous), len(ub_noncontiguous))
907*da0073e9SAndroid Build Coastguard Worker        n = len(ub_contiguous)
908*da0073e9SAndroid Build Coastguard Worker        for i in range(n):
909*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ub_contiguous[i], ub_noncontiguous[i])
910*da0073e9SAndroid Build Coastguard Worker
911*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
912*da0073e9SAndroid Build Coastguard Worker    @skipMeta
913*da0073e9SAndroid Build Coastguard Worker    def test_to_then_from_padded_tensor_no_transform0213(self, device, dtype):
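        # Round trip: convert a nested tensor to a padded dense tensor and back
        # via torch._nested_from_padded_and_nested_example; components and
        # device should be preserved.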
914*da0073e9SAndroid Build Coastguard Worker        t = torch.randn(4, 4, 4, device=device, dtype=dtype)
915*da0073e9SAndroid Build Coastguard Worker        ts = list(torch.unbind(t))
916*da0073e9SAndroid Build Coastguard Worker        ts[0] = ts[0][:-1]
917*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
918*da0073e9SAndroid Build Coastguard Worker        padded = torch.nested.to_padded_tensor(nt, 0)
919*da0073e9SAndroid Build Coastguard Worker
920*da0073e9SAndroid Build Coastguard Worker        nt_to = torch._nested_from_padded_and_nested_example(padded, nt)
921*da0073e9SAndroid Build Coastguard Worker
922*da0073e9SAndroid Build Coastguard Worker        for t1, t2 in zip(nt.unbind(), nt_to.unbind()):
923*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t1, t2)
924*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt.device, nt_to.device)
925*da0073e9SAndroid Build Coastguard Worker
926*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
927*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.float, torch.half)
928*da0073e9SAndroid Build Coastguard Worker    @skipMeta
929*da0073e9SAndroid Build Coastguard Worker    @torch.inference_mode()
930*da0073e9SAndroid Build Coastguard Worker    def test_layer_norm(self, device, dtype):
931*da0073e9SAndroid Build Coastguard Worker        def _test(size):
932*da0073e9SAndroid Build Coastguard Worker            # Simple shapes test
933*da0073e9SAndroid Build Coastguard Worker            t0 = torch.randn(2, size, device=device, dtype=dtype, requires_grad=False)
934*da0073e9SAndroid Build Coastguard Worker            t1 = torch.randn(2, size, device=device, dtype=dtype, requires_grad=False)
935*da0073e9SAndroid Build Coastguard Worker            ts = [t0, t1, t0, t1]
936*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
937*da0073e9SAndroid Build Coastguard Worker            layer_norm = torch.nn.LayerNorm(size, device=device, dtype=dtype)
938*da0073e9SAndroid Build Coastguard Worker            nt_result = layer_norm(nt)
939*da0073e9SAndroid Build Coastguard Worker            for nt_subresult, t in zip(nt_result.unbind(), ts):
940*da0073e9SAndroid Build Coastguard Worker                t_result = layer_norm(t.reshape(1, -1, size).squeeze(0))
941*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(nt_subresult, t_result)
942*da0073e9SAndroid Build Coastguard Worker
943*da0073e9SAndroid Build Coastguard Worker            # More complex nt test with different lengths for each tensor
944*da0073e9SAndroid Build Coastguard Worker            t0 = torch.randn(4, size, device=device, dtype=dtype, requires_grad=False)
945*da0073e9SAndroid Build Coastguard Worker            t1 = torch.randn(10, size, device=device, dtype=dtype, requires_grad=False)
946*da0073e9SAndroid Build Coastguard Worker            t2 = torch.randn(7, size, device=device, dtype=dtype, requires_grad=False)
947*da0073e9SAndroid Build Coastguard Worker            ts = [t0, t1, t2, t0, t2]
948*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
949*da0073e9SAndroid Build Coastguard Worker            layer_norm = torch.nn.LayerNorm(size, device=device, dtype=dtype)
950*da0073e9SAndroid Build Coastguard Worker            nt_result = layer_norm(nt)
951*da0073e9SAndroid Build Coastguard Worker            for nt_subresult, t in zip(nt_result.unbind(), ts):
952*da0073e9SAndroid Build Coastguard Worker                t_result = layer_norm(t.reshape(1, -1, size).squeeze(0))
953*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(nt_subresult, t_result)
954*da0073e9SAndroid Build Coastguard Worker
955*da0073e9SAndroid Build Coastguard Worker            if size <= 128:
956*da0073e9SAndroid Build Coastguard Worker                # Test with multidimensional tensors after irregular dim
957*da0073e9SAndroid Build Coastguard Worker                # (run only with smaller dimensions to ensure fast execution)
958*da0073e9SAndroid Build Coastguard Worker                t0 = torch.randn(
959*da0073e9SAndroid Build Coastguard Worker                    4, size, size, 4, device=device, dtype=dtype, requires_grad=False
960*da0073e9SAndroid Build Coastguard Worker                )
961*da0073e9SAndroid Build Coastguard Worker                t1 = torch.randn(
962*da0073e9SAndroid Build Coastguard Worker                    10, size, size, 4, device=device, dtype=dtype, requires_grad=False
963*da0073e9SAndroid Build Coastguard Worker                )
964*da0073e9SAndroid Build Coastguard Worker                t2 = torch.randn(
965*da0073e9SAndroid Build Coastguard Worker                    7, size, size, 4, device=device, dtype=dtype, requires_grad=False
966*da0073e9SAndroid Build Coastguard Worker                )
967*da0073e9SAndroid Build Coastguard Worker                ts = [t0, t1, t2, t0, t2]
968*da0073e9SAndroid Build Coastguard Worker                nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
969*da0073e9SAndroid Build Coastguard Worker                layer_norm = torch.nn.LayerNorm(
970*da0073e9SAndroid Build Coastguard Worker                    (size, size, 4), device=device, dtype=dtype
971*da0073e9SAndroid Build Coastguard Worker                )
972*da0073e9SAndroid Build Coastguard Worker                nt_result = layer_norm(nt)
973*da0073e9SAndroid Build Coastguard Worker                for nt_subresult, t in zip(nt_result.unbind(), ts):
974*da0073e9SAndroid Build Coastguard Worker                    t_result = layer_norm(t.reshape(1, -1, size, size, 4).squeeze(0))
975*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(nt_subresult, t_result)
976*da0073e9SAndroid Build Coastguard Worker
977*da0073e9SAndroid Build Coastguard Worker                # Test where the normalized shape covers only a subset of the trailing dimensions
978*da0073e9SAndroid Build Coastguard Worker                layer_norm = torch.nn.LayerNorm((size, 4), device=device, dtype=dtype)
979*da0073e9SAndroid Build Coastguard Worker                nt_result = layer_norm(nt)
980*da0073e9SAndroid Build Coastguard Worker                for nt_subresult, t in zip(nt_result.unbind(), ts):
981*da0073e9SAndroid Build Coastguard Worker                    t_result = layer_norm(t.reshape(1, -1, size, size, 4).squeeze(0))
982*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(nt_subresult, t_result)
983*da0073e9SAndroid Build Coastguard Worker
984*da0073e9SAndroid Build Coastguard Worker        for size in (1024, 1023, 513, 512, 256, 128, 2, 4, 32):
985*da0073e9SAndroid Build Coastguard Worker            _test(size)
986*da0073e9SAndroid Build Coastguard Worker
987*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
988*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.float, torch.half)
989*da0073e9SAndroid Build Coastguard Worker    @skipMeta
990*da0073e9SAndroid Build Coastguard Worker    @torch.inference_mode()
991*da0073e9SAndroid Build Coastguard Worker    def test_layer_norm_breaking(self, device, dtype):
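        # Error cases: a normalized_shape that extends into the irregular
        # (ragged) dimension, or whose sizes don't match the components'
        # trailing dims, should raise.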
992*da0073e9SAndroid Build Coastguard Worker        size = 128
993*da0073e9SAndroid Build Coastguard Worker        t0 = torch.randn(
994*da0073e9SAndroid Build Coastguard Worker            4, size, size, 4, device=device, dtype=dtype, requires_grad=False
995*da0073e9SAndroid Build Coastguard Worker        )
996*da0073e9SAndroid Build Coastguard Worker        t1 = torch.randn(
997*da0073e9SAndroid Build Coastguard Worker            10, size, size, 4, device=device, dtype=dtype, requires_grad=False
998*da0073e9SAndroid Build Coastguard Worker        )
999*da0073e9SAndroid Build Coastguard Worker        t2 = torch.randn(
1000*da0073e9SAndroid Build Coastguard Worker            7, size, size, 4, device=device, dtype=dtype, requires_grad=False
1001*da0073e9SAndroid Build Coastguard Worker        )
1002*da0073e9SAndroid Build Coastguard Worker        ts = [t0, t1, t2, t0, t2]
1003*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
1004*da0073e9SAndroid Build Coastguard Worker        layer_norm = torch.nn.LayerNorm((4, size, size, 4), device=device, dtype=dtype)
1005*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1006*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1007*da0073e9SAndroid Build Coastguard Worker            "normalized_shape extends into irregular dimensions for the nested tensor",
1008*da0073e9SAndroid Build Coastguard Worker            lambda: layer_norm(nt),
1009*da0073e9SAndroid Build Coastguard Worker        )
1010*da0073e9SAndroid Build Coastguard Worker        layer_norm = torch.nn.LayerNorm((size + 1, size, 4), device=device, dtype=dtype)
1011*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1012*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1013*da0073e9SAndroid Build Coastguard Worker            "The shape at dimension 0",
1014*da0073e9SAndroid Build Coastguard Worker            lambda: layer_norm(nt),
1015*da0073e9SAndroid Build Coastguard Worker        )
1016*da0073e9SAndroid Build Coastguard Worker
1017*da0073e9SAndroid Build Coastguard Worker    @parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name)
1018*da0073e9SAndroid Build Coastguard Worker    def test_embedding(self, device, layout):
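        # nn.Embedding applied to a nested tensor of int64 indices should match
        # applying the same embedding to each component separately.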
1019*da0073e9SAndroid Build Coastguard Worker        inputs = [
1020*da0073e9SAndroid Build Coastguard Worker            torch.randint(100, (L,), device=device, dtype=torch.int64)
1021*da0073e9SAndroid Build Coastguard Worker            for L in torch.randint(5, 50, (8,))
1022*da0073e9SAndroid Build Coastguard Worker        ]
1023*da0073e9SAndroid Build Coastguard Worker        x = torch.nested.nested_tensor(
1024*da0073e9SAndroid Build Coastguard Worker            inputs, device=device, dtype=torch.int64, layout=layout
1025*da0073e9SAndroid Build Coastguard Worker        )
1026*da0073e9SAndroid Build Coastguard Worker        emb = torch.nn.Embedding(100, 8, device=device)
1027*da0073e9SAndroid Build Coastguard Worker        y = emb(x)
1028*da0073e9SAndroid Build Coastguard Worker
1029*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.disable
1030*da0073e9SAndroid Build Coastguard Worker        def check(inputs, y):
1031*da0073e9SAndroid Build Coastguard Worker            ys = y.unbind()
1032*da0073e9SAndroid Build Coastguard Worker            for i, inp in enumerate(inputs):
1033*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(emb(inp), ys[i])
1034*da0073e9SAndroid Build Coastguard Worker
1035*da0073e9SAndroid Build Coastguard Worker        check(inputs, y)
1036*da0073e9SAndroid Build Coastguard Worker
1037*da0073e9SAndroid Build Coastguard Worker    @skipMeta
1038*da0073e9SAndroid Build Coastguard Worker    @torch.inference_mode()
1039*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_types_and_half())
1040*da0073e9SAndroid Build Coastguard Worker    def test_masked_fill(self, device, dtype):
1041*da0073e9SAndroid Build Coastguard Worker        # masked_fill with a nested boolean mask; compare against per-component masked_fill
1042*da0073e9SAndroid Build Coastguard Worker        (nt, mask) = self.random_nt_pair(device, dtype, 4, (4, 4))
1043*da0073e9SAndroid Build Coastguard Worker        mask = torch.nested.nested_tensor([m < 0 for m in mask.unbind()])
1044*da0073e9SAndroid Build Coastguard Worker        ref = torch.nested.nested_tensor(
1045*da0073e9SAndroid Build Coastguard Worker            [t.masked_fill(m, 0) for (t, m) in zip(nt.unbind(), mask.unbind())]
1046*da0073e9SAndroid Build Coastguard Worker        )
1047*da0073e9SAndroid Build Coastguard Worker        out = nt.masked_fill(mask, 0)
1048*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, out)
1049*da0073e9SAndroid Build Coastguard Worker
1050*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.float16)
1051*da0073e9SAndroid Build Coastguard Worker    def test_to_padded_tensor_simple(self, device, dtype):
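        # One component is shortened by a row, so the padded output should
        # equal the original dense tensor with that row replaced by the
        # padding value.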
1052*da0073e9SAndroid Build Coastguard Worker        t = torch.randn(4, 4, 4, device=device, dtype=dtype)
1053*da0073e9SAndroid Build Coastguard Worker        ts = list(torch.unbind(t))
1054*da0073e9SAndroid Build Coastguard Worker        ts[0] = ts[0][:-1]
1055*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
1056*da0073e9SAndroid Build Coastguard Worker        for padding_value in (0, 1):
1057*da0073e9SAndroid Build Coastguard Worker            padded = torch.nested.to_padded_tensor(nt, padding_value)
1058*da0073e9SAndroid Build Coastguard Worker
1059*da0073e9SAndroid Build Coastguard Worker            correct_output = t.clone()
1060*da0073e9SAndroid Build Coastguard Worker            if padding_value == 0:
1061*da0073e9SAndroid Build Coastguard Worker                correct_output[0][-1] = torch.zeros_like(correct_output[0][-1])
1062*da0073e9SAndroid Build Coastguard Worker            else:
1063*da0073e9SAndroid Build Coastguard Worker                correct_output[0][-1] = torch.ones_like(correct_output[0][-1])
1064*da0073e9SAndroid Build Coastguard Worker
1065*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(padded, correct_output)
1066*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(padded.device, torch.device(device))
1067*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(padded.dtype, dtype)
1068*da0073e9SAndroid Build Coastguard Worker
1069*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.float16)
1070*da0073e9SAndroid Build Coastguard Worker    def test_to_padded_tensor_output_size(self, device, dtype):
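        # When an explicit output_size larger than the components is given,
        # the extra room should be filled with the padding value.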
1071*da0073e9SAndroid Build Coastguard Worker        t = torch.randn(4, 4, 4, device=device, dtype=dtype)
1072*da0073e9SAndroid Build Coastguard Worker        output_size = (4, 6, 5)
1073*da0073e9SAndroid Build Coastguard Worker        ts = list(torch.unbind(t))
1074*da0073e9SAndroid Build Coastguard Worker        ts[0] = ts[0][:-1]
1075*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
1076*da0073e9SAndroid Build Coastguard Worker        for padding_value in (0, 1):
1077*da0073e9SAndroid Build Coastguard Worker            padded = torch.nested.to_padded_tensor(
1078*da0073e9SAndroid Build Coastguard Worker                nt, padding_value, output_size=output_size
1079*da0073e9SAndroid Build Coastguard Worker            )
1080*da0073e9SAndroid Build Coastguard Worker            correct_output = (
1081*da0073e9SAndroid Build Coastguard Worker                torch.ones(output_size, device=device, dtype=dtype) * padding_value
1082*da0073e9SAndroid Build Coastguard Worker            )
1083*da0073e9SAndroid Build Coastguard Worker            correct_output[:4, :4, :4] = t.clone()
1084*da0073e9SAndroid Build Coastguard Worker            if padding_value == 0:
1085*da0073e9SAndroid Build Coastguard Worker                correct_output[0][3] = torch.zeros_like(correct_output[0][3])
1086*da0073e9SAndroid Build Coastguard Worker            else:
1087*da0073e9SAndroid Build Coastguard Worker                correct_output[0][3] = torch.ones_like(correct_output[0][3])
1088*da0073e9SAndroid Build Coastguard Worker
1089*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(padded, correct_output)
1090*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(padded.device, torch.device(device))
1091*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(padded.dtype, dtype)
1092*da0073e9SAndroid Build Coastguard Worker
1093*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.float16, torch.double)
1094*da0073e9SAndroid Build Coastguard Worker    def test_to_padded_tensor_dim2(self, device, dtype):
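        # 1-D components of different lengths should be padded out to the
        # longest length (that of ts[2]) with the padding value.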
1095*da0073e9SAndroid Build Coastguard Worker        ts = [
1096*da0073e9SAndroid Build Coastguard Worker            torch.randn(160, device=device, dtype=dtype),
1097*da0073e9SAndroid Build Coastguard Worker            torch.randn(1240, device=device, dtype=dtype),
1098*da0073e9SAndroid Build Coastguard Worker            torch.randn(2400, device=device, dtype=dtype),
1099*da0073e9SAndroid Build Coastguard Worker        ]
1100*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
1101*da0073e9SAndroid Build Coastguard Worker        pad = 42
1102*da0073e9SAndroid Build Coastguard Worker        correct_output = []
1103*da0073e9SAndroid Build Coastguard Worker        for t in ts:
1104*da0073e9SAndroid Build Coastguard Worker            next_output = torch.ones_like(ts[2]) * pad
1105*da0073e9SAndroid Build Coastguard Worker            correct_output.append(next_output)
1106*da0073e9SAndroid Build Coastguard Worker            next_output[: t.size(0)].copy_(t)
1107*da0073e9SAndroid Build Coastguard Worker        correct_output = torch.stack(correct_output)
1108*da0073e9SAndroid Build Coastguard Worker        padded = torch.nested.to_padded_tensor(nt, pad)
1109*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(padded, correct_output)
1110*da0073e9SAndroid Build Coastguard Worker
1111*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.float16, torch.double)
1112*da0073e9SAndroid Build Coastguard Worker    def test_to_padded_tensor_dim3(self, device, dtype):
1113*da0073e9SAndroid Build Coastguard Worker        ts = [
1114*da0073e9SAndroid Build Coastguard Worker            torch.randn(16, 21, device=device, dtype=dtype),
1115*da0073e9SAndroid Build Coastguard Worker            torch.randn(24, 32, device=device, dtype=dtype),
1116*da0073e9SAndroid Build Coastguard Worker            torch.randn(40, 53, device=device, dtype=dtype),
1117*da0073e9SAndroid Build Coastguard Worker        ]
1118*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
1119*da0073e9SAndroid Build Coastguard Worker        pad = 42
1120*da0073e9SAndroid Build Coastguard Worker        correct_output = []
1121*da0073e9SAndroid Build Coastguard Worker        for t in ts:
1122*da0073e9SAndroid Build Coastguard Worker            next_output = torch.ones_like(ts[2]) * pad
1123*da0073e9SAndroid Build Coastguard Worker            correct_output.append(next_output)
1124*da0073e9SAndroid Build Coastguard Worker            next_output[: t.size(0), : t.size(1)].copy_(t)
1125*da0073e9SAndroid Build Coastguard Worker        correct_output = torch.stack(correct_output)
1126*da0073e9SAndroid Build Coastguard Worker        padded = torch.nested.to_padded_tensor(nt, pad)
1127*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(padded, correct_output)
1128*da0073e9SAndroid Build Coastguard Worker
1129*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.float16, torch.double)
1130*da0073e9SAndroid Build Coastguard Worker    def test_to_padded_tensor_dim4(self, device, dtype):
1131*da0073e9SAndroid Build Coastguard Worker        ts = [
1132*da0073e9SAndroid Build Coastguard Worker            torch.randn(16, 21, 13, device=device, dtype=dtype),
1133*da0073e9SAndroid Build Coastguard Worker            torch.randn(24, 32, 14, device=device, dtype=dtype),
1134*da0073e9SAndroid Build Coastguard Worker            torch.randn(40, 53, 16, device=device, dtype=dtype),
1135*da0073e9SAndroid Build Coastguard Worker        ]
1136*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
1137*da0073e9SAndroid Build Coastguard Worker        pad = 42
1138*da0073e9SAndroid Build Coastguard Worker        correct_output = []
1139*da0073e9SAndroid Build Coastguard Worker        for t in ts:
1140*da0073e9SAndroid Build Coastguard Worker            next_output = torch.ones_like(ts[2]) * pad
1141*da0073e9SAndroid Build Coastguard Worker            correct_output.append(next_output)
1142*da0073e9SAndroid Build Coastguard Worker            next_output[: t.size(0), : t.size(1), : t.size(2)].copy_(t)
1143*da0073e9SAndroid Build Coastguard Worker        correct_output = torch.stack(correct_output)
1144*da0073e9SAndroid Build Coastguard Worker        padded = torch.nested.to_padded_tensor(nt, pad)
1145*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(padded, correct_output)
1146*da0073e9SAndroid Build Coastguard Worker
1147*da0073e9SAndroid Build Coastguard Worker    # TODO: test noncontiguous to_padded_tensor
1148*da0073e9SAndroid Build Coastguard Worker    # For now this tests the functionality of noncontiguous_to_padded_tensor
1149*da0073e9SAndroid Build Coastguard Worker    # and the error message of to_padded_tensor, since to_padded_tensor
1150*da0073e9SAndroid Build Coastguard Worker    # does not support noncontiguous buffers yet.
1151*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.float16, torch.double)
1152*da0073e9SAndroid Build Coastguard Worker    @torch.inference_mode()
1153*da0073e9SAndroid Build Coastguard Worker    def test_to_padded_tensor_noncontiguous(self, device, dtype):
1154*da0073e9SAndroid Build Coastguard Worker        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
1155*da0073e9SAndroid Build Coastguard Worker            (2, 3, 6, 7), device, dtype
1156*da0073e9SAndroid Build Coastguard Worker        )
1157*da0073e9SAndroid Build Coastguard Worker        # test noncontiguous_to_padded_tensor functionality
1158*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1159*da0073e9SAndroid Build Coastguard Worker            torch.nested.to_padded_tensor(nt_contiguous, 0.0),
1160*da0073e9SAndroid Build Coastguard Worker            noncontiguous_to_padded_tensor(nt_noncontiguous),
1161*da0073e9SAndroid Build Coastguard Worker        )
1162*da0073e9SAndroid Build Coastguard Worker        # test to_padded_tensor error message
1163*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1164*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1165*da0073e9SAndroid Build Coastguard Worker            r"for now to_padded_tensor only supports contiguous nested tensor",
1166*da0073e9SAndroid Build Coastguard Worker            lambda: torch.nested.to_padded_tensor(nt_noncontiguous, 0.0),
1167*da0073e9SAndroid Build Coastguard Worker        )
1168*da0073e9SAndroid Build Coastguard Worker
1169*da0073e9SAndroid Build Coastguard Worker    @skipMeta
1170*da0073e9SAndroid Build Coastguard Worker    def test_device_checks(self, device):
1171*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor([], device=device)
1172*da0073e9SAndroid Build Coastguard Worker        is_cuda = "cuda" in str(device)
1173*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt.is_cuda, is_cuda)
1174*da0073e9SAndroid Build Coastguard Worker
1175*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.float16, torch.double)
1176*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_indexing(self, device, dtype):
1177*da0073e9SAndroid Build Coastguard Worker        # edge case: empty nested tensor
1178*da0073e9SAndroid Build Coastguard Worker        nt0 = torch.nested.nested_tensor([])
1179*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(IndexError, lambda: nt0[0])
1180*da0073e9SAndroid Build Coastguard Worker        # normal case
1181*da0073e9SAndroid Build Coastguard Worker        x0 = torch.randn((2, 5), device=device, dtype=dtype)
1182*da0073e9SAndroid Build Coastguard Worker        x1 = torch.randn((3, 4), device=device, dtype=dtype)
1183*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor([x0, x1])
1184*da0073e9SAndroid Build Coastguard Worker        # single index: only integer indexing is supported in the batch dimension
1185*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt[0], x0)
1186*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt[-1], x1)
1187*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(IndexError, lambda: nt[2])
1188*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(IndexError, lambda: nt[-3])
1189*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(NotImplementedError, lambda: nt[:])
1190*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt[...], nt)
1191*da0073e9SAndroid Build Coastguard Worker        # tuple of indices: only integer indexing is supported in the batch dimension,
1192*da0073e9SAndroid Build Coastguard Worker        # plus all possible indexing in the original tensor dimensions
1193*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt[0, 0, 0], x0[0, 0])
1194*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt[0, 1, :], x0[1, :])
1195*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt[1, ...], x1)
1196*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(IndexError, lambda: nt[1, 4, 2])
1197*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(NotImplementedError, lambda: nt[:, 1, 1])
1198*da0073e9SAndroid Build Coastguard Worker        # test select on non-batch dimensions
1199*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt.select(1, 0)[0], x0.select(0, 0))
1200*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt.select(1, 0)[1], x1.select(0, 0))
1201*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(IndexError, lambda: nt.select(1, 3))
1202*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt.select(2, 0)[0], x0.select(1, 0))
1203*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt.select(2, 0)[1], x1.select(1, 0))
1204*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(IndexError, lambda: nt.select(2, 5))
1205*da0073e9SAndroid Build Coastguard Worker        # make sure indexing returns a view
1206*da0073e9SAndroid Build Coastguard Worker        nt[0].fill_(100.0)
1207*da0073e9SAndroid Build Coastguard Worker        answer = torch.tensor(100.0, device=device, dtype=dtype).expand((2, 5))
1208*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt[0], answer)
1209*da0073e9SAndroid Build Coastguard Worker        nt[1, 1, :].fill_(200.0)
1210*da0073e9SAndroid Build Coastguard Worker        answer = torch.tensor(200.0, device=device, dtype=dtype).expand(4)
1211*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt[1, 1, :], answer)
1212*da0073e9SAndroid Build Coastguard Worker
1213*da0073e9SAndroid Build Coastguard Worker        # Test that indexing works when requires_grad_(True)
1214*da0073e9SAndroid Build Coastguard Worker        # previously this was failing because the backward kernel for select.int uses .sizes()
1215*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor([x0, x1]).requires_grad_(True)
1216*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt[0], x0)
1217*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt[-1], x1)
1218*da0073e9SAndroid Build Coastguard Worker        grad_x0 = torch.randn((2, 5), device=device, dtype=dtype)
1219*da0073e9SAndroid Build Coastguard Worker        nt[0].backward(grad_x0)
1220*da0073e9SAndroid Build Coastguard Worker        expected_grad = torch.nested.nested_tensor(
1221*da0073e9SAndroid Build Coastguard Worker            [grad_x0, torch.zeros((3, 4), device=device, dtype=dtype)]
1222*da0073e9SAndroid Build Coastguard Worker        )
1223*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt.grad, expected_grad)
1224*da0073e9SAndroid Build Coastguard Worker
1225*da0073e9SAndroid Build Coastguard Worker    @parametrize(
1226*da0073e9SAndroid Build Coastguard Worker        "func",
1227*da0073e9SAndroid Build Coastguard Worker        [
1228*da0073e9SAndroid Build Coastguard Worker            subtest(torch.nn.functional.relu, name="relu"),
1229*da0073e9SAndroid Build Coastguard Worker            subtest(torch.nn.functional.relu_, name="relu_"),
1230*da0073e9SAndroid Build Coastguard Worker            subtest(torch.nn.functional.gelu, name="gelu"),
1231*da0073e9SAndroid Build Coastguard Worker            subtest(torch._C._nn.gelu_, name="gelu_"),
1232*da0073e9SAndroid Build Coastguard Worker            subtest(torch.tanh, name="tanh"),
1233*da0073e9SAndroid Build Coastguard Worker            subtest(torch.tanh_, name="tanh_"),
1234*da0073e9SAndroid Build Coastguard Worker            subtest(torch.neg, name="neg"),
1235*da0073e9SAndroid Build Coastguard Worker            subtest(torch.nn.functional.silu, name="silu"),
1236*da0073e9SAndroid Build Coastguard Worker            subtest(partial(torch.nn.functional.silu, inplace=True), name="silu_"),
1237*da0073e9SAndroid Build Coastguard Worker            subtest(torch.abs, name="abs"),
1238*da0073e9SAndroid Build Coastguard Worker            subtest(torch.abs_, name="abs_"),
1239*da0073e9SAndroid Build Coastguard Worker            subtest(torch.sgn, name="sgn"),
1240*da0073e9SAndroid Build Coastguard Worker            subtest(torch.logical_not, name="logical_not"),
1241*da0073e9SAndroid Build Coastguard Worker            subtest(torch.sin, name="sin"),
1242*da0073e9SAndroid Build Coastguard Worker            subtest(torch.cos, name="cos"),
1243*da0073e9SAndroid Build Coastguard Worker        ],
1244*da0073e9SAndroid Build Coastguard Worker    )
1245*da0073e9SAndroid Build Coastguard Worker    def test_activations(self, device, func):
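        # Unary/activation ops should apply component-wise on a contiguous
        # nested tensor; the non-contiguous input is expected to raise because
        # these ops require a contiguous buffer.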
1246*da0073e9SAndroid Build Coastguard Worker        nt, nt_noncontiguous = random_nt_noncontiguous_pair(
1247*da0073e9SAndroid Build Coastguard Worker            (2, 3, 6, 7), device=device, dtype=torch.float32
1248*da0073e9SAndroid Build Coastguard Worker        )
1249*da0073e9SAndroid Build Coastguard Worker        nested_result = func(nt)
1250*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(nested_result.is_nested)
1251*da0073e9SAndroid Build Coastguard Worker        for t, t_res in zip(nt.unbind(), nested_result.unbind()):
1252*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(func(t), t_res)
1253*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1254*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1255*da0073e9SAndroid Build Coastguard Worker            "NestedTensor must be contiguous to get buffer.",
1256*da0073e9SAndroid Build Coastguard Worker            lambda: func(nt_noncontiguous),
1257*da0073e9SAndroid Build Coastguard Worker        )
1258*da0073e9SAndroid Build Coastguard Worker
1259*da0073e9SAndroid Build Coastguard Worker    @parametrize("func", [subtest(torch.ge, name="ge"), subtest(torch.eq, name="eq")])
1260*da0073e9SAndroid Build Coastguard Worker    def test_binary_ops_with_scalar(self, device, func):
1261*da0073e9SAndroid Build Coastguard Worker        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
1262*da0073e9SAndroid Build Coastguard Worker            (2, 3, 6, 7), device=device, dtype=torch.float32
1263*da0073e9SAndroid Build Coastguard Worker        )
1264*da0073e9SAndroid Build Coastguard Worker        scalar = 0.0
1265*da0073e9SAndroid Build Coastguard Worker
1266*da0073e9SAndroid Build Coastguard Worker        # should work regardless of contiguity
1267*da0073e9SAndroid Build Coastguard Worker        for nt in (nt_contiguous, nt_noncontiguous):
1268*da0073e9SAndroid Build Coastguard Worker            nested_result = func(nt, scalar)
1269*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(nested_result.is_nested)
1270*da0073e9SAndroid Build Coastguard Worker            for t, t_res in zip(nt.unbind(), nested_result.unbind()):
1271*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(func(t, scalar), t_res)
1272*da0073e9SAndroid Build Coastguard Worker
1273*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_types_and_half())
1274*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_chunk(self, device, dtype):
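        # chunk along the last dim should match nested tensors built from
        # chunking each component; the resulting chunks are non-contiguous.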
1275*da0073e9SAndroid Build Coastguard Worker        # Transformer use case
1276*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(3, 3 * 4, device=device, dtype=dtype)
1277*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(2, 3 * 4, device=device, dtype=dtype)
1278*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(1, 3 * 4, device=device, dtype=dtype)
1279*da0073e9SAndroid Build Coastguard Worker        a_chunks = a.chunk(3, dim=-1)
1280*da0073e9SAndroid Build Coastguard Worker        b_chunks = b.chunk(3, dim=-1)
1281*da0073e9SAndroid Build Coastguard Worker        c_chunks = c.chunk(3, dim=-1)
1282*da0073e9SAndroid Build Coastguard Worker
1283*da0073e9SAndroid Build Coastguard Worker        a_nt = [a_chunks[0], b_chunks[0], c_chunks[0]]
1284*da0073e9SAndroid Build Coastguard Worker        b_nt = [a_chunks[1], b_chunks[1], c_chunks[1]]
1285*da0073e9SAndroid Build Coastguard Worker        c_nt = [a_chunks[2], b_chunks[2], c_chunks[2]]
1286*da0073e9SAndroid Build Coastguard Worker
1287*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor([a, b, c])
1288*da0073e9SAndroid Build Coastguard Worker        chunked = nt.chunk(3, dim=-1)
1289*da0073e9SAndroid Build Coastguard Worker
1290*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(chunked[0], torch.nested.nested_tensor(a_nt))
1291*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(chunked[1], torch.nested.nested_tensor(b_nt))
1292*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(chunked[2], torch.nested.nested_tensor(c_nt))
1293*da0073e9SAndroid Build Coastguard Worker
1294*da0073e9SAndroid Build Coastguard Worker        for chunk in chunked:
1295*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(chunk.is_contiguous())
1296*da0073e9SAndroid Build Coastguard Worker
1297*da0073e9SAndroid Build Coastguard Worker        # Failure chunking on ragged dimensions
1298*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1299*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1300*da0073e9SAndroid Build Coastguard Worker            "Chunk for nested tensors is currently only supported for the last dimension.",
1301*da0073e9SAndroid Build Coastguard Worker            lambda: torch.chunk(nt, 5, dim=1),
1302*da0073e9SAndroid Build Coastguard Worker        )
1303*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1304*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1305*da0073e9SAndroid Build Coastguard Worker            "Chunk for nested tensors is currently only supported for the last dimension.",
1306*da0073e9SAndroid Build Coastguard Worker            lambda: torch.chunk(nt, 5, dim=0),
1307*da0073e9SAndroid Build Coastguard Worker        )
1308*da0073e9SAndroid Build Coastguard Worker
1309*da0073e9SAndroid Build Coastguard Worker        # Failure on non-contiguous nt
1310*da0073e9SAndroid Build Coastguard Worker        _, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype)
1311*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1312*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1313*da0073e9SAndroid Build Coastguard Worker            "chunk expects `self` to be contiguous.",
1314*da0073e9SAndroid Build Coastguard Worker            lambda: torch.chunk(nt_noncontiguous, 5, dim=-1),
1315*da0073e9SAndroid Build Coastguard Worker        )
1316*da0073e9SAndroid Build Coastguard Worker
1317*da0073e9SAndroid Build Coastguard Worker        # Failure when calling non divisible n_chunks
1318*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1319*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1320*da0073e9SAndroid Build Coastguard Worker            "Chunk for nested tensors is only supported for "
1321*da0073e9SAndroid Build Coastguard Worker            "nested tensors with trailing dimension divisible by chunks.",
1322*da0073e9SAndroid Build Coastguard Worker            lambda: torch.chunk(nt, 5, dim=-1),
1323*da0073e9SAndroid Build Coastguard Worker        )
1324*da0073e9SAndroid Build Coastguard Worker
1325*da0073e9SAndroid Build Coastguard Worker        # Failure when calling backward on a chunk
1326*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(3, 3 * 4, device=device, dtype=dtype, requires_grad=True)
1327*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(2, 3 * 4, device=device, dtype=dtype, requires_grad=True)
1328*da0073e9SAndroid Build Coastguard Worker        nt_grad = torch.nested.as_nested_tensor([a, b])
1329*da0073e9SAndroid Build Coastguard Worker        chunked = torch.chunk(nt_grad, 2, dim=-1)
1330*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1331*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1332*da0073e9SAndroid Build Coastguard Worker            "Nested Strided Tensor doesn't support chunk backward.",
1333*da0073e9SAndroid Build Coastguard Worker            lambda: chunked[0].backward(chunked[0].clone()),
1334*da0073e9SAndroid Build Coastguard Worker        )
1335*da0073e9SAndroid Build Coastguard Worker
1336*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_types_and_half())
1337*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_split_with_sizes(self, device, dtype):
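        # split_with_sizes along the last dim should match nested tensors built
        # from splitting each component, with strides matching the dense
        # splits; the results are non-contiguous.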
1338*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(3, 20, device=device, dtype=dtype)
1339*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(2, 20, device=device, dtype=dtype)
1340*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(1, 20, device=device, dtype=dtype)
1341*da0073e9SAndroid Build Coastguard Worker
1342*da0073e9SAndroid Build Coastguard Worker        split_sizes = [4, 6, 10]
1343*da0073e9SAndroid Build Coastguard Worker        a_splits = a.split_with_sizes(split_sizes, dim=-1)
1344*da0073e9SAndroid Build Coastguard Worker        b_splits = b.split_with_sizes(split_sizes, dim=-1)
1345*da0073e9SAndroid Build Coastguard Worker        c_splits = c.split_with_sizes(split_sizes, dim=-1)
1346*da0073e9SAndroid Build Coastguard Worker
1347*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor([a, b, c])
1348*da0073e9SAndroid Build Coastguard Worker        nt_splits = nt.split_with_sizes(split_sizes, dim=-1)
1349*da0073e9SAndroid Build Coastguard Worker
1350*da0073e9SAndroid Build Coastguard Worker        for i, nt_split in enumerate(nt_splits):
1351*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
1352*da0073e9SAndroid Build Coastguard Worker                nt_split,
1353*da0073e9SAndroid Build Coastguard Worker                torch.nested.nested_tensor([a_splits[i], b_splits[i], c_splits[i]]),
1354*da0073e9SAndroid Build Coastguard Worker            )
1355*da0073e9SAndroid Build Coastguard Worker            dense_strides = torch.stack(
1356*da0073e9SAndroid Build Coastguard Worker                [
1357*da0073e9SAndroid Build Coastguard Worker                    torch.tensor(a_splits[i].stride()),
1358*da0073e9SAndroid Build Coastguard Worker                    torch.tensor(b_splits[i].stride()),
1359*da0073e9SAndroid Build Coastguard Worker                    torch.tensor(c_splits[i].stride()),
1360*da0073e9SAndroid Build Coastguard Worker                ]
1361*da0073e9SAndroid Build Coastguard Worker            )
1362*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt_split._nested_tensor_strides(), dense_strides)
1363*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(nt_split.is_contiguous())
1364*da0073e9SAndroid Build Coastguard Worker
1365*da0073e9SAndroid Build Coastguard Worker        # Failure calling on ragged dimensions
1366*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1367*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1368*da0073e9SAndroid Build Coastguard Worker            "split_with_sizes for nested tensors is currently only supported for the last dimension.",
1369*da0073e9SAndroid Build Coastguard Worker            lambda: torch.split_with_sizes(nt, split_sizes, dim=1),
1370*da0073e9SAndroid Build Coastguard Worker        )
1371*da0073e9SAndroid Build Coastguard Worker
1372*da0073e9SAndroid Build Coastguard Worker        # Failure calling on the batch (nested) dimension
1373*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1374*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1375*da0073e9SAndroid Build Coastguard Worker            "split_with_sizes for nested tensors is currently only supported for the last dimension.",
1376*da0073e9SAndroid Build Coastguard Worker            lambda: torch.split_with_sizes(nt, split_sizes, dim=0),
1377*da0073e9SAndroid Build Coastguard Worker        )
1378*da0073e9SAndroid Build Coastguard Worker
1379*da0073e9SAndroid Build Coastguard Worker        # Failure on non-contiguous nt
1380*da0073e9SAndroid Build Coastguard Worker        _, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype)
1381*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1382*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1383*da0073e9SAndroid Build Coastguard Worker            "split_with_sizes expects `self` to be contiguous.",
1384*da0073e9SAndroid Build Coastguard Worker            lambda: torch.split_with_sizes(nt_noncontiguous, split_sizes, dim=-1),
1385*da0073e9SAndroid Build Coastguard Worker        )
1386*da0073e9SAndroid Build Coastguard Worker
1387*da0073e9SAndroid Build Coastguard Worker        # Failure when calling with split_sizes that don't cover the full dim size
1388*da0073e9SAndroid Build Coastguard Worker        bad_split_sizes = [4, 6, 9]  # don't add up to 20
1389*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1390*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1391*da0073e9SAndroid Build Coastguard Worker            "split_with_sizes expects split_sizes to sum exactly to 20",
1392*da0073e9SAndroid Build Coastguard Worker            lambda: torch.split_with_sizes(nt, bad_split_sizes, dim=-1),
1393*da0073e9SAndroid Build Coastguard Worker        )
1394*da0073e9SAndroid Build Coastguard Worker
1395*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.float16, torch.double)
1396*da0073e9SAndroid Build Coastguard Worker    @torch.inference_mode()
1397*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_indexing_noncontiguous(self, device, dtype):
1398*da0073e9SAndroid Build Coastguard Worker        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
1399*da0073e9SAndroid Build Coastguard Worker            (2, 3, 6, 7), device, dtype
1400*da0073e9SAndroid Build Coastguard Worker        )
1401*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt_contiguous.size(0), nt_noncontiguous.size(0))
1402*da0073e9SAndroid Build Coastguard Worker        n = nt_contiguous.size(0)
1403*da0073e9SAndroid Build Coastguard Worker        for i in range(n):
1404*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt_contiguous[i], nt_noncontiguous[i])
1405*da0073e9SAndroid Build Coastguard Worker
1406*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.float16)
1407*da0073e9SAndroid Build Coastguard Worker    @skipMeta
1408*da0073e9SAndroid Build Coastguard Worker    @torch.inference_mode()
1409*da0073e9SAndroid Build Coastguard Worker    @parametrize("transpose", [True, False])
1410*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_add(self, device, dtype, transpose):
1411*da0073e9SAndroid Build Coastguard Worker        if transpose:
1412*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(2, 2, 2, device=device, dtype=dtype)
1413*da0073e9SAndroid Build Coastguard Worker            b = torch.rand(2, 2, 2, device=device, dtype=dtype)
1414*da0073e9SAndroid Build Coastguard Worker            c = a.transpose(-1, -2).contiguous()
1415*da0073e9SAndroid Build Coastguard Worker            d = b.transpose(-1, -2).contiguous()
1416*da0073e9SAndroid Build Coastguard Worker            nt1 = torch.nested.nested_tensor([a, b, a, b])
1417*da0073e9SAndroid Build Coastguard Worker            nt2 = torch.nested.nested_tensor([c, d, c, d]).transpose(-1, -2)
1418*da0073e9SAndroid Build Coastguard Worker        else:
1419*da0073e9SAndroid Build Coastguard Worker            (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
1420*da0073e9SAndroid Build Coastguard Worker        ref = torch.nested.nested_tensor(
1421*da0073e9SAndroid Build Coastguard Worker            [t1 + t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]
1422*da0073e9SAndroid Build Coastguard Worker        )
1423*da0073e9SAndroid Build Coastguard Worker        out = nt1 + nt2
1424*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, out)
1425*da0073e9SAndroid Build Coastguard Worker
1426*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.float16)
1427*da0073e9SAndroid Build Coastguard Worker    @skipMeta
1428*da0073e9SAndroid Build Coastguard Worker    @torch.inference_mode()
1429*da0073e9SAndroid Build Coastguard Worker    @parametrize("transpose", [True, False])
1430*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_sub(self, device, dtype, transpose):
1431*da0073e9SAndroid Build Coastguard Worker        if transpose:
1432*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(2, 2, 2, device=device, dtype=dtype)
1433*da0073e9SAndroid Build Coastguard Worker            b = torch.rand(2, 2, 2, device=device, dtype=dtype)
1434*da0073e9SAndroid Build Coastguard Worker            c = a.transpose(-1, -2).contiguous()
1435*da0073e9SAndroid Build Coastguard Worker            d = b.transpose(-1, -2).contiguous()
1436*da0073e9SAndroid Build Coastguard Worker            nt1 = torch.nested.nested_tensor([a, b, a, b])
1437*da0073e9SAndroid Build Coastguard Worker            nt2 = torch.nested.nested_tensor([c, d, c, d]).transpose(-1, -2)
1438*da0073e9SAndroid Build Coastguard Worker        else:
1439*da0073e9SAndroid Build Coastguard Worker            (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
1440*da0073e9SAndroid Build Coastguard Worker        ref = torch.nested.nested_tensor(
1441*da0073e9SAndroid Build Coastguard Worker            [t1 - t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]
1442*da0073e9SAndroid Build Coastguard Worker        )
1443*da0073e9SAndroid Build Coastguard Worker        out = nt1 - nt2
1444*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, out)
1445*da0073e9SAndroid Build Coastguard Worker
1446*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1447*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.float16)
1448*da0073e9SAndroid Build Coastguard Worker    @torch.inference_mode()
1449*da0073e9SAndroid Build Coastguard Worker    @parametrize("embedding_dim", [8, 128, 256, 384])
1450*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_dense_elementwise(self, device, dtype, embedding_dim):
1451*da0073e9SAndroid Build Coastguard Worker        def _test_add_mul(nt, t):
1452*da0073e9SAndroid Build Coastguard Worker            ref_add = torch.nested.nested_tensor(
1453*da0073e9SAndroid Build Coastguard Worker                [t1 + t2 for (t1, t2) in zip(nt.unbind(), t.unbind())]
1454*da0073e9SAndroid Build Coastguard Worker            )
1455*da0073e9SAndroid Build Coastguard Worker            ref_mul = torch.nested.nested_tensor(
1456*da0073e9SAndroid Build Coastguard Worker                [t1 * t2 for (t1, t2) in zip(nt.unbind(), t.unbind())]
1457*da0073e9SAndroid Build Coastguard Worker            )
1458*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt.add(t), ref_add)
1459*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt.mul(t), ref_mul)
1460*da0073e9SAndroid Build Coastguard Worker
1461*da0073e9SAndroid Build Coastguard Worker        batch_size = 32
1462*da0073e9SAndroid Build Coastguard Worker        seq_lens = torch.randint(low=0, high=10, size=(batch_size,))
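        # per-sample lengths are drawn from [0, 10), so zero-length components are
        # possible and exercise the empty-tensor case for the broadcasted add/mul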
1463*da0073e9SAndroid Build Coastguard Worker
1464*da0073e9SAndroid Build Coastguard Worker        # [B, *, D], [B, 1, D] case
1465*da0073e9SAndroid Build Coastguard Worker        ts = [torch.randn((seq_len, embedding_dim)) for seq_len in seq_lens]
1466*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
1467*da0073e9SAndroid Build Coastguard Worker        t = torch.randn((batch_size, 1, embedding_dim), device=device, dtype=dtype)
1468*da0073e9SAndroid Build Coastguard Worker        _test_add_mul(nt, t)
1469*da0073e9SAndroid Build Coastguard Worker
1470*da0073e9SAndroid Build Coastguard Worker        # [B, *], [B, 1] case
1471*da0073e9SAndroid Build Coastguard Worker        ts = [torch.randn(seq_len) for seq_len in seq_lens]
1472*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
1473*da0073e9SAndroid Build Coastguard Worker        t = torch.randn((batch_size, 1), device=device, dtype=dtype)
1474*da0073e9SAndroid Build Coastguard Worker        _test_add_mul(nt, t)
1475*da0073e9SAndroid Build Coastguard Worker
1476*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.float16)
1477*da0073e9SAndroid Build Coastguard Worker    @skipMeta
1478*da0073e9SAndroid Build Coastguard Worker    @torch.inference_mode()
1479*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_mul(self, device, dtype):
1480*da0073e9SAndroid Build Coastguard Worker        # nested tensor * nested tensor
1481*da0073e9SAndroid Build Coastguard Worker        (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
1482*da0073e9SAndroid Build Coastguard Worker        ref = torch.nested.nested_tensor(
1483*da0073e9SAndroid Build Coastguard Worker            [t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]
1484*da0073e9SAndroid Build Coastguard Worker        )
1485*da0073e9SAndroid Build Coastguard Worker        out = nt1 * nt2
1486*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, out)
1487*da0073e9SAndroid Build Coastguard Worker        # nested tensor * scalar
1488*da0073e9SAndroid Build Coastguard Worker        number = 10.0
1489*da0073e9SAndroid Build Coastguard Worker        scalar = torch.tensor(number).to(dtype).to(device)
1490*da0073e9SAndroid Build Coastguard Worker        ref = torch.nested.nested_tensor([t * number for t in nt1.unbind()])
1491*da0073e9SAndroid Build Coastguard Worker        out_number0 = nt1 * number
1492*da0073e9SAndroid Build Coastguard Worker        out_number1 = number * nt1
1493*da0073e9SAndroid Build Coastguard Worker        out_scalar0 = nt1 * scalar
1494*da0073e9SAndroid Build Coastguard Worker        out_scalar1 = scalar * nt1
1495*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_number0, ref)
1496*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_number1, ref)
1497*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_scalar0, ref)
1498*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_scalar1, ref)
1499*da0073e9SAndroid Build Coastguard Worker        # error case: numel == 1 but dim > 0
1500*da0073e9SAndroid Build Coastguard Worker        vector = torch.tensor([number]).to(dtype).to(device)
1501*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1502*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1503*da0073e9SAndroid Build Coastguard Worker            "Expected both self and other to be nested, but got a nested self and non-nested other",
1504*da0073e9SAndroid Build Coastguard Worker            lambda: nt1.mul(vector),
1505*da0073e9SAndroid Build Coastguard Worker        )
1506*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1507*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1508*da0073e9SAndroid Build Coastguard Worker            "Expected both self and other to be nested, but got a non-nested self and nested other",
1509*da0073e9SAndroid Build Coastguard Worker            lambda: vector.mul(nt1),
1510*da0073e9SAndroid Build Coastguard Worker        )
1511*da0073e9SAndroid Build Coastguard Worker
1512*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.float16)
1513*da0073e9SAndroid Build Coastguard Worker    @skipMeta
1514*da0073e9SAndroid Build Coastguard Worker    @torch.inference_mode()
1515*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_div(self, device, dtype):
1516*da0073e9SAndroid Build Coastguard Worker        nt, nt2 = self.random_nt_pair(device, dtype, 4, (4, 4))
1517*da0073e9SAndroid Build Coastguard Worker        scale = 4.0
1518*da0073e9SAndroid Build Coastguard Worker        ref = torch.nested.nested_tensor([t / scale for t in nt.unbind()])
1519*da0073e9SAndroid Build Coastguard Worker        out = nt / 4.0
1520*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, out)
1521*da0073e9SAndroid Build Coastguard Worker        ref_transposed = ref.transpose(1, 2)
1522*da0073e9SAndroid Build Coastguard Worker        out = nt.transpose(1, 2) / 4.0
1523*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref_transposed, out)
1524*da0073e9SAndroid Build Coastguard Worker
1525*da0073e9SAndroid Build Coastguard Worker        ref = torch.nested.nested_tensor(
1526*da0073e9SAndroid Build Coastguard Worker            [t / t2 for (t, t2) in zip(nt.unbind(), nt2.unbind())]
1527*da0073e9SAndroid Build Coastguard Worker        )
1528*da0073e9SAndroid Build Coastguard Worker        out = nt / nt2
1529*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, out)
1530*da0073e9SAndroid Build Coastguard Worker
1531*da0073e9SAndroid Build Coastguard Worker        out = nt.transpose(1, 2) / nt2.transpose(1, 2)
1532*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref.transpose(1, 2), out)
1533*da0073e9SAndroid Build Coastguard Worker
1534*da0073e9SAndroid Build Coastguard Worker        nt_transpose_copy = torch.nested.nested_tensor(
1535*da0073e9SAndroid Build Coastguard Worker            [t.transpose(0, 1) for t in nt.unbind()]
1536*da0073e9SAndroid Build Coastguard Worker        )
1537*da0073e9SAndroid Build Coastguard Worker
1538*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1539*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1540*da0073e9SAndroid Build Coastguard Worker            "div requires strides to match when given NestedTensors",
1541*da0073e9SAndroid Build Coastguard Worker            lambda: nt_transpose_copy.transpose(1, 2) / nt2,
1542*da0073e9SAndroid Build Coastguard Worker        )
1543*da0073e9SAndroid Build Coastguard Worker
1544*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor(
1545*da0073e9SAndroid Build Coastguard Worker            [torch.randn(i, 4) for i in [3, 4, 5]], device=device, dtype=dtype
1546*da0073e9SAndroid Build Coastguard Worker        )
1547*da0073e9SAndroid Build Coastguard Worker        nt_chunks = nt.chunk(2, -1)
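        # the two chunks are non-contiguous views into the same buffer with
        # different per-component storage offsets, so elementwise div between
        # them is expected to be rejected (see the assertion below)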
1548*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1549*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1550*da0073e9SAndroid Build Coastguard Worker            "div requires offsets to match when given NestedTensors",
1551*da0073e9SAndroid Build Coastguard Worker            lambda: nt_chunks[0] / nt_chunks[1],
1552*da0073e9SAndroid Build Coastguard Worker        )
1553*da0073e9SAndroid Build Coastguard Worker
1554*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.float16)
1555*da0073e9SAndroid Build Coastguard Worker    @skipMeta
1556*da0073e9SAndroid Build Coastguard Worker    @torch.inference_mode()
1557*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_add_in_place(self, device, dtype):
1558*da0073e9SAndroid Build Coastguard Worker        (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
1559*da0073e9SAndroid Build Coastguard Worker        ref = torch.nested.nested_tensor(
1560*da0073e9SAndroid Build Coastguard Worker            [t1 + t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]
1561*da0073e9SAndroid Build Coastguard Worker        )
1562*da0073e9SAndroid Build Coastguard Worker        nt1 += nt2
1563*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, nt1)
1564*da0073e9SAndroid Build Coastguard Worker
1565*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.float16)
1566*da0073e9SAndroid Build Coastguard Worker    @skipMeta
1567*da0073e9SAndroid Build Coastguard Worker    @torch.inference_mode()
1568*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_mul_in_place(self, device, dtype):
1569*da0073e9SAndroid Build Coastguard Worker        # nested tensor * nested tensor
1570*da0073e9SAndroid Build Coastguard Worker        (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
1571*da0073e9SAndroid Build Coastguard Worker        ref = torch.nested.nested_tensor(
1572*da0073e9SAndroid Build Coastguard Worker            [t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]
1573*da0073e9SAndroid Build Coastguard Worker        )
1574*da0073e9SAndroid Build Coastguard Worker        nt1 *= nt2
1575*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, nt1)
1576*da0073e9SAndroid Build Coastguard Worker        # nested tensor * scalar
1577*da0073e9SAndroid Build Coastguard Worker        number = 10.0
1578*da0073e9SAndroid Build Coastguard Worker        scalar = torch.tensor(number).to(dtype).to(device)
1579*da0073e9SAndroid Build Coastguard Worker        ref = torch.nested.nested_tensor([t * number for t in nt1.unbind()])
1580*da0073e9SAndroid Build Coastguard Worker        out_number = nt1.clone()
1581*da0073e9SAndroid Build Coastguard Worker        out_number *= number
1582*da0073e9SAndroid Build Coastguard Worker        out_scalar = nt1.clone()
1583*da0073e9SAndroid Build Coastguard Worker        out_scalar *= scalar
1584*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_number, ref)
1585*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_scalar, ref)
1586*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1587*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1588*da0073e9SAndroid Build Coastguard Worker            r"output with shape \[.*\] doesn't match the broadcast shape \[.*\]",
1589*da0073e9SAndroid Build Coastguard Worker            lambda: scalar.mul_(nt1),
1590*da0073e9SAndroid Build Coastguard Worker        )
1591*da0073e9SAndroid Build Coastguard Worker        # error case: numel == 1 but dim > 0
1592*da0073e9SAndroid Build Coastguard Worker        vector = torch.tensor([number]).to(dtype).to(device)
1593*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1594*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1595*da0073e9SAndroid Build Coastguard Worker            "Expected both self and other to be nested, but got a nested self and non-nested other",
1596*da0073e9SAndroid Build Coastguard Worker            lambda: nt1.mul_(vector),
1597*da0073e9SAndroid Build Coastguard Worker        )
1598*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1599*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1600*da0073e9SAndroid Build Coastguard Worker            "Expected both self and other to be nested, but got a non-nested self and nested other",
1601*da0073e9SAndroid Build Coastguard Worker            lambda: vector.mul_(nt1),
1602*da0073e9SAndroid Build Coastguard Worker        )
1603*da0073e9SAndroid Build Coastguard Worker
1604*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
1605*da0073e9SAndroid Build Coastguard Worker    @skipMeta
1606*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
1607*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_sum_dim(self, device, dtype):
1608*da0073e9SAndroid Build Coastguard Worker        params = ((2, (1, 1)), (4, (4, 4)), (10, (3, 5, 7)))
1609*da0073e9SAndroid Build Coastguard Worker
1610*da0073e9SAndroid Build Coastguard Worker        def test_sum(device, dtype, ntensors, max_sizes, dim, keepdim=True):
1611*da0073e9SAndroid Build Coastguard Worker            nt = random_nt(device, dtype, ntensors, max_sizes, require_non_empty=False)
1612*da0073e9SAndroid Build Coastguard Worker            nt2 = nt.clone()
1613*da0073e9SAndroid Build Coastguard Worker            ub2 = nt2.unbind()
1614*da0073e9SAndroid Build Coastguard Worker            nt.requires_grad_(True)
1615*da0073e9SAndroid Build Coastguard Worker            [t.requires_grad_(True) for t in ub2]
1616*da0073e9SAndroid Build Coastguard Worker            nt_sum = nt.sum(dim=dim, keepdim=keepdim)
1617*da0073e9SAndroid Build Coastguard Worker            ub2_sum = [t.sum(-1, keepdim=keepdim) for t in ub2]
1618*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt_sum, torch.nested.nested_tensor(ub2_sum))
1619*da0073e9SAndroid Build Coastguard Worker
1620*da0073e9SAndroid Build Coastguard Worker            # test backward
1621*da0073e9SAndroid Build Coastguard Worker            # generate gradient tensor that has the same size as the output
1622*da0073e9SAndroid Build Coastguard Worker            size = nt_sum._nested_tensor_size()
1623*da0073e9SAndroid Build Coastguard Worker            gt2 = []
1624*da0073e9SAndroid Build Coastguard Worker            for i in range(ntensors):
1625*da0073e9SAndroid Build Coastguard Worker                gt2.append(torch.randn(size[i].tolist(), device=device, dtype=dtype))
1626*da0073e9SAndroid Build Coastguard Worker            gt = torch.nested.nested_tensor(gt2).clone()
1627*da0073e9SAndroid Build Coastguard Worker            nt_sum.backward(gt)
1628*da0073e9SAndroid Build Coastguard Worker            for t2, g2 in zip(ub2_sum, gt2):
1629*da0073e9SAndroid Build Coastguard Worker                t2.backward(g2)
1630*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt.grad, torch.nested.nested_tensor([t.grad for t in ub2]))
1631*da0073e9SAndroid Build Coastguard Worker            return
1632*da0073e9SAndroid Build Coastguard Worker
1633*da0073e9SAndroid Build Coastguard Worker        for ntensors, max_sizes in params:
1634*da0073e9SAndroid Build Coastguard Worker            test_sum(device, dtype, ntensors, max_sizes, len(max_sizes))
1635*da0073e9SAndroid Build Coastguard Worker
1636*da0073e9SAndroid Build Coastguard Worker        # Test error inputs
1637*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1638*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "NestedTensor can only be reduced across the last"
1639*da0073e9SAndroid Build Coastguard Worker        ):
1640*da0073e9SAndroid Build Coastguard Worker            torch.nested.nested_tensor(
1641*da0073e9SAndroid Build Coastguard Worker                [torch.tensor([3, 4, 5]), torch.tensor([1, 2])]
1642*da0073e9SAndroid Build Coastguard Worker            ).sum(0, keepdim=True)
1643*da0073e9SAndroid Build Coastguard Worker
1644*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1645*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "NestedTensor only allows reduction of a single"
1646*da0073e9SAndroid Build Coastguard Worker        ):
1647*da0073e9SAndroid Build Coastguard Worker            torch.nested.nested_tensor(
1648*da0073e9SAndroid Build Coastguard Worker                [torch.tensor([[3, 4, 5]]), torch.tensor([[1, 2]])]
1649*da0073e9SAndroid Build Coastguard Worker            ).sum([0, 1], keepdim=True)
1650*da0073e9SAndroid Build Coastguard Worker
1651*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1652*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "NestedTensor always requires keepdim=True for now."
1653*da0073e9SAndroid Build Coastguard Worker        ):
1654*da0073e9SAndroid Build Coastguard Worker            torch.nested.nested_tensor(
1655*da0073e9SAndroid Build Coastguard Worker                [torch.tensor([3, 4, 5]), torch.tensor([1, 2])]
1656*da0073e9SAndroid Build Coastguard Worker            ).sum(-1)
1657*da0073e9SAndroid Build Coastguard Worker
1658*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.float16)
1659*da0073e9SAndroid Build Coastguard Worker    def test_contiguous(self, device, dtype):
1660*da0073e9SAndroid Build Coastguard Worker        # Since we don't have access to the buffer in Python, this is harder to
1661*da0073e9SAndroid Build Coastguard Worker        # demonstrate directly. When we call chunk on a consistent dim of an NT
1662*da0073e9SAndroid Build Coastguard Worker        # with more than one chunk, the resulting tensors are views of the original
1663*da0073e9SAndroid Build Coastguard Worker        # NT whose numel is now less than the size of the buffer. Clone was
1664*da0073e9SAndroid Build Coastguard Worker        # previously creating a new NT with a buffer that was the same size as the
1665*da0073e9SAndroid Build Coastguard Worker        # original instead of one sized to the view.
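        # Concretely, for the tensors below the original buffer holds
        # 2 * 20 + 4 * 20 = 120 elements, while each of the 5 chunks views only
        # 2 * 4 + 4 * 4 = 24 of them; a clone/contiguous call on a chunk should
        # therefore allocate a 24-element buffer, not a 120-element one.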
1666*da0073e9SAndroid Build Coastguard Worker        nt_contiguous = torch.nested.nested_tensor(
1667*da0073e9SAndroid Build Coastguard Worker            [
1668*da0073e9SAndroid Build Coastguard Worker                torch.randn(2, 20, device=device, dtype=dtype),
1669*da0073e9SAndroid Build Coastguard Worker                torch.randn(4, 20, device=device, dtype=dtype),
1670*da0073e9SAndroid Build Coastguard Worker            ]
1671*da0073e9SAndroid Build Coastguard Worker        )
1672*da0073e9SAndroid Build Coastguard Worker        # Split up the last dimension which has a consistent size of 20 into 5 chunks
1673*da0073e9SAndroid Build Coastguard Worker        chunks = nt_contiguous.chunk(5, dim=-1)
1674*da0073e9SAndroid Build Coastguard Worker
1675*da0073e9SAndroid Build Coastguard Worker        # Check that each chunk is non-contiguous and becomes contiguous after calling contiguous()
1676*da0073e9SAndroid Build Coastguard Worker        for chunk in chunks:
1677*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(chunk.is_contiguous())
1678*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(chunk.contiguous().is_contiguous())
1679*da0073e9SAndroid Build Coastguard Worker
1680*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.float16)
1681*da0073e9SAndroid Build Coastguard Worker    @skipMeta
1682*da0073e9SAndroid Build Coastguard Worker    def test_clone(self, device, dtype):
1683*da0073e9SAndroid Build Coastguard Worker        nt1 = random_nt(device, dtype, 4, (4, 4), (1, 1))
1684*da0073e9SAndroid Build Coastguard Worker        nt2 = nt1.clone()
1685*da0073e9SAndroid Build Coastguard Worker        # Verify the values match
1686*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt1, nt2)
1687*da0073e9SAndroid Build Coastguard Worker        # Verify modifying nt2 doesn't affect nt1
1688*da0073e9SAndroid Build Coastguard Worker        nt2.mul_(nt1)
1689*da0073e9SAndroid Build Coastguard Worker        ub1 = nt1.unbind()
1690*da0073e9SAndroid Build Coastguard Worker        ub2 = nt2.unbind()
1691*da0073e9SAndroid Build Coastguard Worker        for i in range(len(ub1)):
1692*da0073e9SAndroid Build Coastguard Worker            self.assertNotEqual(ub1[i], ub2[i])
1693*da0073e9SAndroid Build Coastguard Worker
1694*da0073e9SAndroid Build Coastguard Worker        nt1.clone(memory_format=torch.preserve_format)
1695*da0073e9SAndroid Build Coastguard Worker        msg = "Nested tensor clone supports Preserve and Contiguous memory formats, called clone with memory format: ChannelsLast"
1696*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
1697*da0073e9SAndroid Build Coastguard Worker            nt1.clone(memory_format=torch.channels_last)
1698*da0073e9SAndroid Build Coastguard Worker
1699*da0073e9SAndroid Build Coastguard Worker    # cannot test torch.float16 because: RuntimeError: "bernoulli_scalar_cpu_" not implemented for 'Half'
1700*da0073e9SAndroid Build Coastguard Worker    @decorateIf(xfailIfTorchDynamo, lambda params: params["layout"] == torch.jagged)
1701*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
1702*da0073e9SAndroid Build Coastguard Worker    @parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name)
1703*da0073e9SAndroid Build Coastguard Worker    def test_dropout(self, device, dtype, layout):
1704*da0073e9SAndroid Build Coastguard Worker        # edge case: empty nested tensor
1705*da0073e9SAndroid Build Coastguard Worker        # TODO: support empty NT in jagged layout
1706*da0073e9SAndroid Build Coastguard Worker        if layout == torch.strided:
1707*da0073e9SAndroid Build Coastguard Worker            nt0 = torch.nested.nested_tensor([], layout=layout)
1708*da0073e9SAndroid Build Coastguard Worker            y = torch.nn.functional.dropout(nt0, 0.5)
1709*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt0, y)
1710*da0073e9SAndroid Build Coastguard Worker        # normal nested tensor
1711*da0073e9SAndroid Build Coastguard Worker        ntensors = 4
1712*da0073e9SAndroid Build Coastguard Worker        if layout == torch.jagged:
1713*da0073e9SAndroid Build Coastguard Worker            nt = random_nt(device, dtype, ntensors, (4, 4), (0, 3), layout=layout)
1714*da0073e9SAndroid Build Coastguard Worker        else:
1715*da0073e9SAndroid Build Coastguard Worker            nt = random_nt(device, dtype, ntensors, (4, 4), layout=layout)
1716*da0073e9SAndroid Build Coastguard Worker        # edge case: invalid dropout
1717*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, lambda: torch.nn.Dropout(-0.1))
1718*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, lambda: torch.nn.Dropout(1.1))
1719*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, -0.1))
1720*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, 1.1))
1721*da0073e9SAndroid Build Coastguard Worker        # edge case: no dropout
1722*da0073e9SAndroid Build Coastguard Worker        dropouter = torch.nn.Dropout(0.0)
1723*da0073e9SAndroid Build Coastguard Worker        y0 = dropouter(nt)
1724*da0073e9SAndroid Build Coastguard Worker        y1 = torch.nn.functional.dropout(nt, 0.0)
1725*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt, y0)
1726*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt, y1)
1727*da0073e9SAndroid Build Coastguard Worker        # edge case: all dropout
1728*da0073e9SAndroid Build Coastguard Worker        dropouter = torch.nn.Dropout(1.0)
1729*da0073e9SAndroid Build Coastguard Worker        y0 = dropouter(nt)
1730*da0073e9SAndroid Build Coastguard Worker        y1 = torch.nn.functional.dropout(nt, 1.0)
1731*da0073e9SAndroid Build Coastguard Worker        nt0 = torch.zeros_like(nt)
1732*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt0, y0)
1733*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt0, y1)
1734*da0073e9SAndroid Build Coastguard Worker        # normal case: normal dropout
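        # dropout zeroes each element with probability p and scales survivors by
        # 1 / (1 - p) (inverted dropout), so the expected value is preserved; the
        # checks below reconstruct that expectation from the observed zero mask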
1735*da0073e9SAndroid Build Coastguard Worker        p = 0.2
1736*da0073e9SAndroid Build Coastguard Worker        y = torch.nn.functional.dropout(nt, p)
1737*da0073e9SAndroid Build Coastguard Worker        expect = nt.clone()
1738*da0073e9SAndroid Build Coastguard Worker        if layout == torch.jagged:
1739*da0073e9SAndroid Build Coastguard Worker            expect = torch.where(y == 0.0, y, nt)
1740*da0073e9SAndroid Build Coastguard Worker            expect /= 1.0 - p
1741*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y, expect)
1742*da0073e9SAndroid Build Coastguard Worker        else:
1743*da0073e9SAndroid Build Coastguard Worker            expect = nt.clone()
1744*da0073e9SAndroid Build Coastguard Worker            for i in range(ntensors):
1745*da0073e9SAndroid Build Coastguard Worker                actual_tensor = y[i].view(-1)
1746*da0073e9SAndroid Build Coastguard Worker                expect_tensor = expect[i].view(-1)
1747*da0073e9SAndroid Build Coastguard Worker                for j in range(actual_tensor.shape[0]):
1748*da0073e9SAndroid Build Coastguard Worker                    if actual_tensor[j].item() == 0.0:
1749*da0073e9SAndroid Build Coastguard Worker                        expect_tensor[j] = 0.0
1750*da0073e9SAndroid Build Coastguard Worker                    else:
1751*da0073e9SAndroid Build Coastguard Worker                        expect_tensor[j] /= 1.0 - p
1752*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y, expect)
1753*da0073e9SAndroid Build Coastguard Worker        with freeze_rng_state():
1754*da0073e9SAndroid Build Coastguard Worker            dropouter = torch.nn.Dropout(p)
1755*da0073e9SAndroid Build Coastguard Worker            y0 = dropouter(nt)
1756*da0073e9SAndroid Build Coastguard Worker        with freeze_rng_state():
1757*da0073e9SAndroid Build Coastguard Worker            y1 = torch.nn.functional.dropout(nt, p)
1758*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y0, y1)
1759*da0073e9SAndroid Build Coastguard Worker
1760*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
1761*da0073e9SAndroid Build Coastguard Worker    def test_dropout_noncontiguous(self, device, dtype):
1762*da0073e9SAndroid Build Coastguard Worker        ntensors = 4
1763*da0073e9SAndroid Build Coastguard Worker        nt0 = random_nt(device, dtype, ntensors, (4, 4))
1764*da0073e9SAndroid Build Coastguard Worker        nt1 = nt0.transpose(-1, -2)
1765*da0073e9SAndroid Build Coastguard Worker        p = 0.3
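        # with the RNG state frozen, dropout should draw the same mask for the
        # contiguous NT and its transposed (non-contiguous) view, so transposing
        # the second result back should reproduce the first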
1766*da0073e9SAndroid Build Coastguard Worker        with freeze_rng_state():
1767*da0073e9SAndroid Build Coastguard Worker            dropouter = torch.nn.Dropout(p)
1768*da0073e9SAndroid Build Coastguard Worker            y0 = dropouter(nt0)
1769*da0073e9SAndroid Build Coastguard Worker        with freeze_rng_state():
1770*da0073e9SAndroid Build Coastguard Worker            y1 = torch.nn.functional.dropout(nt1, p).transpose(-1, -2)
1771*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y0, y1)
1772*da0073e9SAndroid Build Coastguard Worker
1773*da0073e9SAndroid Build Coastguard Worker    # cannot test torch.float16 because: RuntimeError: "softmax_kernel_impl" not implemented for 'Half'
1774*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
1775*da0073e9SAndroid Build Coastguard Worker    def test_softmax(self, device, dtype):
1776*da0073e9SAndroid Build Coastguard Worker        # normal nested tensor
1777*da0073e9SAndroid Build Coastguard Worker        ntensors = 4
1778*da0073e9SAndroid Build Coastguard Worker        nt = random_nt(device, dtype, ntensors, (4, 4))
1779*da0073e9SAndroid Build Coastguard Worker        # error case: softmax across nested dimension
1780*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1781*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1782*da0073e9SAndroid Build Coastguard Worker            "Cannot apply softmax across nested dimension 0",
1783*da0073e9SAndroid Build Coastguard Worker            lambda: torch.nn.functional.softmax(nt, 0),
1784*da0073e9SAndroid Build Coastguard Worker        )
1785*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1786*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1787*da0073e9SAndroid Build Coastguard Worker            "Cannot apply softmax across nested dimension 0",
1788*da0073e9SAndroid Build Coastguard Worker            lambda: torch.nn.functional.softmax(nt, -3),
1789*da0073e9SAndroid Build Coastguard Worker        )
1790*da0073e9SAndroid Build Coastguard Worker        # error case: dimension out of range
1791*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt, 3))
1792*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt, -4))
1793*da0073e9SAndroid Build Coastguard Worker        # normal case: should match softmax over the tensor padded with -inf
1794*da0073e9SAndroid Build Coastguard Worker        softmaxer = torch.nn.Softmax(1)
1795*da0073e9SAndroid Build Coastguard Worker        y0 = softmaxer(nt)
1796*da0073e9SAndroid Build Coastguard Worker        y1 = torch.nn.functional.softmax(nt, 1)
1797*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y0, y1)
1798*da0073e9SAndroid Build Coastguard Worker        pt = torch.nested.to_padded_tensor(nt, float("-inf"))
1799*da0073e9SAndroid Build Coastguard Worker        # if an entire slice is padded, then softmax will return 0.0 / 0.0 = nan
1800*da0073e9SAndroid Build Coastguard Worker        # however, semantically those padded entries should be 0.0
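        # e.g. softmax over a fully padded row [-inf, -inf] yields
        # exp(-inf) / sum(exp(-inf)) = 0 / 0 = nan; nan_to_num_(0.0) maps those
        # rows back to 0.0 so the padded reference matches the nested result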
1801*da0073e9SAndroid Build Coastguard Worker        expect = torch.nn.functional.softmax(pt, 1).nan_to_num_(0.0)
1802*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.nested.to_padded_tensor(y0, 0.0), expect)
1803*da0073e9SAndroid Build Coastguard Worker        # edge case: empty nested tensor
1804*da0073e9SAndroid Build Coastguard Worker        nt0 = torch.nested.nested_tensor([])
1805*da0073e9SAndroid Build Coastguard Worker        y = torch.nn.functional.softmax(nt0, 1)
1806*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt0, y)
1807*da0073e9SAndroid Build Coastguard Worker        # edge case: nesting scalars
1808*da0073e9SAndroid Build Coastguard Worker        nt1 = torch.nested.nested_tensor([torch.tensor(0.0), torch.tensor(1.0)])
1809*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.nn.functional.softmax(nt1, 0))
1810*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt1, 1))
1811*da0073e9SAndroid Build Coastguard Worker
1812*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
1813*da0073e9SAndroid Build Coastguard Worker    @torch.inference_mode()
1814*da0073e9SAndroid Build Coastguard Worker    def test_softmax_noncontiguous(self, device, dtype):
1815*da0073e9SAndroid Build Coastguard Worker        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
1816*da0073e9SAndroid Build Coastguard Worker            (2, 3, 6, 7), device, dtype
1817*da0073e9SAndroid Build Coastguard Worker        )
1818*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1819*da0073e9SAndroid Build Coastguard Worker            torch.nn.functional.softmax(nt_contiguous, -1),
1820*da0073e9SAndroid Build Coastguard Worker            torch.nn.functional.softmax(nt_noncontiguous, -1),
1821*da0073e9SAndroid Build Coastguard Worker        )
1822*da0073e9SAndroid Build Coastguard Worker
1823*da0073e9SAndroid Build Coastguard Worker    def _test_bmm(self, device, dtype):
1824*da0073e9SAndroid Build Coastguard Worker        # error case: not 3D tensors
1825*da0073e9SAndroid Build Coastguard Worker        nt0 = torch.nested.nested_tensor([], device=device, dtype=dtype)
1826*da0073e9SAndroid Build Coastguard Worker        nt1 = torch.nested.nested_tensor(
1827*da0073e9SAndroid Build Coastguard Worker            [torch.randn(2), torch.randn(3)], device=device, dtype=dtype
1828*da0073e9SAndroid Build Coastguard Worker        )
1829*da0073e9SAndroid Build Coastguard Worker        nt2 = torch.nested.nested_tensor(
1830*da0073e9SAndroid Build Coastguard Worker            [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
1831*da0073e9SAndroid Build Coastguard Worker        )
1832*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1833*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt0)
1834*da0073e9SAndroid Build Coastguard Worker        )
1835*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1836*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt1)
1837*da0073e9SAndroid Build Coastguard Worker        )
1838*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1839*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt2)
1840*da0073e9SAndroid Build Coastguard Worker        )
1841*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1842*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt0)
1843*da0073e9SAndroid Build Coastguard Worker        )
1844*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1845*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt1)
1846*da0073e9SAndroid Build Coastguard Worker        )
1847*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1848*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt2)
1849*da0073e9SAndroid Build Coastguard Worker        )
1850*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1851*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "batch2 must be a 3D tensor", lambda: nt2.bmm(nt0)
1852*da0073e9SAndroid Build Coastguard Worker        )
1853*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1854*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "batch2 must be a 3D tensor", lambda: nt2.bmm(nt1)
1855*da0073e9SAndroid Build Coastguard Worker        )
1856*da0073e9SAndroid Build Coastguard Worker        # error case: incompatible batch size
1857*da0073e9SAndroid Build Coastguard Worker        nt0 = torch.nested.nested_tensor(
1858*da0073e9SAndroid Build Coastguard Worker            [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
1859*da0073e9SAndroid Build Coastguard Worker        )
1860*da0073e9SAndroid Build Coastguard Worker        nt1 = torch.nested.nested_tensor(
1861*da0073e9SAndroid Build Coastguard Worker            [torch.randn((4, 6)), torch.randn((4, 5)), torch.randn((4, 7))],
1862*da0073e9SAndroid Build Coastguard Worker            device=device,
1863*da0073e9SAndroid Build Coastguard Worker            dtype=dtype,
1864*da0073e9SAndroid Build Coastguard Worker        )
1865*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1866*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1867*da0073e9SAndroid Build Coastguard Worker            "Expected size for the 1st dimension of batch2 tensor to be: 2 but got: 3.",
1868*da0073e9SAndroid Build Coastguard Worker            lambda: nt0.bmm(nt1),
1869*da0073e9SAndroid Build Coastguard Worker        )
1870*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1871*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1872*da0073e9SAndroid Build Coastguard Worker            "Expected size for the 1st dimension of batch2 tensor to be: 3 but got: 2.",
1873*da0073e9SAndroid Build Coastguard Worker            lambda: nt1.bmm(nt0),
1874*da0073e9SAndroid Build Coastguard Worker        )
1875*da0073e9SAndroid Build Coastguard Worker        # error case: underlying matrices cannot be multiplied
1876*da0073e9SAndroid Build Coastguard Worker        nt0 = torch.nested.nested_tensor(
1877*da0073e9SAndroid Build Coastguard Worker            [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
1878*da0073e9SAndroid Build Coastguard Worker        )
1879*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1880*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1881*da0073e9SAndroid Build Coastguard Worker            r"0-th nested matrices in batch cannot be multiplied \(2x4 and 2x4\)",
1882*da0073e9SAndroid Build Coastguard Worker            lambda: nt0.bmm(nt0),
1883*da0073e9SAndroid Build Coastguard Worker        )
1884*da0073e9SAndroid Build Coastguard Worker        # normal nested tensor
1885*da0073e9SAndroid Build Coastguard Worker        nt0 = torch.nested.nested_tensor(
1886*da0073e9SAndroid Build Coastguard Worker            [torch.randn((2, 4)), torch.randn((3, 7))], device=device, dtype=dtype
1887*da0073e9SAndroid Build Coastguard Worker        )
1888*da0073e9SAndroid Build Coastguard Worker        nt1 = torch.nested.nested_tensor(
1889*da0073e9SAndroid Build Coastguard Worker            [torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype
1890*da0073e9SAndroid Build Coastguard Worker        )
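        # comparison against the zero-padded dense equivalents is exact here:
        # padded rows/columns contribute only zeros to the matrix products, so
        # padding the nested result with 0.0 matches bmm of the padded inputs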
1891*da0073e9SAndroid Build Coastguard Worker        actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
1892*da0073e9SAndroid Build Coastguard Worker        expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(
1893*da0073e9SAndroid Build Coastguard Worker            torch.nested.to_padded_tensor(nt1, 0.0)
1894*da0073e9SAndroid Build Coastguard Worker        )
1895*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.float16:
1896*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
1897*da0073e9SAndroid Build Coastguard Worker        else:
1898*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expect)
1899*da0073e9SAndroid Build Coastguard Worker
1900*da0073e9SAndroid Build Coastguard Worker        # nested tensor bmm normal tensor
1901*da0073e9SAndroid Build Coastguard Worker        nt0 = torch.nested.nested_tensor(
1902*da0073e9SAndroid Build Coastguard Worker            [torch.randn((2, 7)), torch.randn((3, 7))], device=device, dtype=dtype
1903*da0073e9SAndroid Build Coastguard Worker        )
1904*da0073e9SAndroid Build Coastguard Worker        nt1 = torch.rand(2, 7, 5, dtype=dtype, device=device)
1905*da0073e9SAndroid Build Coastguard Worker        actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
1906*da0073e9SAndroid Build Coastguard Worker        expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(nt1)
1907*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.float16:
1908*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
1909*da0073e9SAndroid Build Coastguard Worker        else:
1910*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expect)
1911*da0073e9SAndroid Build Coastguard Worker
1912*da0073e9SAndroid Build Coastguard Worker        # nested tensor bmm normal tensor with non-contiguous view
1913*da0073e9SAndroid Build Coastguard Worker        nt1 = torch.rand(2, 5, 7, dtype=dtype, device=device)
1914*da0073e9SAndroid Build Coastguard Worker        nt1 = nt1.transpose(1, 2)
1915*da0073e9SAndroid Build Coastguard Worker        actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
1916*da0073e9SAndroid Build Coastguard Worker        expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(nt1)
1917*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.float16:
1918*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
1919*da0073e9SAndroid Build Coastguard Worker        else:
1920*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expect)
1921*da0073e9SAndroid Build Coastguard Worker
1922*da0073e9SAndroid Build Coastguard Worker        # normal tensor bmm nested tensor
1923*da0073e9SAndroid Build Coastguard Worker        nt0 = torch.rand(2, 5, 7, dtype=dtype, device=device)
1924*da0073e9SAndroid Build Coastguard Worker        nt1 = torch.nested.nested_tensor(
1925*da0073e9SAndroid Build Coastguard Worker            [torch.randn((7, 6)), torch.randn((7, 5))], device=device, dtype=dtype
1926*da0073e9SAndroid Build Coastguard Worker        )
1927*da0073e9SAndroid Build Coastguard Worker        actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
1928*da0073e9SAndroid Build Coastguard Worker        expect = nt0.bmm(torch.nested.to_padded_tensor(nt1, 0.0))
1929*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.float16:
1930*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
1931*da0073e9SAndroid Build Coastguard Worker        else:
1932*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expect)
1933*da0073e9SAndroid Build Coastguard Worker
1934*da0073e9SAndroid Build Coastguard Worker        # test tensorcore path
1935*da0073e9SAndroid Build Coastguard Worker        nt0 = torch.nested.nested_tensor(
1936*da0073e9SAndroid Build Coastguard Worker            [torch.randn((2, 8)), torch.randn((3, 16))], device=device, dtype=dtype
1937*da0073e9SAndroid Build Coastguard Worker        )
1938*da0073e9SAndroid Build Coastguard Worker        nt1 = torch.nested.nested_tensor(
1939*da0073e9SAndroid Build Coastguard Worker            [torch.randn((8, 8)), torch.randn((16, 8))], device=device, dtype=dtype
1940*da0073e9SAndroid Build Coastguard Worker        )
1941*da0073e9SAndroid Build Coastguard Worker        actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
1942*da0073e9SAndroid Build Coastguard Worker        expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(
1943*da0073e9SAndroid Build Coastguard Worker            torch.nested.to_padded_tensor(nt1, 0.0)
1944*da0073e9SAndroid Build Coastguard Worker        )
1945*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.float16:
1946*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
1947*da0073e9SAndroid Build Coastguard Worker        else:
1948*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expect)
1949*da0073e9SAndroid Build Coastguard Worker
1950*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1951*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double, torch.float16)
1952*da0073e9SAndroid Build Coastguard Worker    def test_bmm_cuda(self, device, dtype):
1953*da0073e9SAndroid Build Coastguard Worker        self._test_bmm(device, dtype)
1954*da0073e9SAndroid Build Coastguard Worker
1955*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
1956*da0073e9SAndroid Build Coastguard Worker    # cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
1957*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
1958*da0073e9SAndroid Build Coastguard Worker    def test_bmm_cpu(self, device, dtype):
1959*da0073e9SAndroid Build Coastguard Worker        self._test_bmm(device, dtype)
1960*da0073e9SAndroid Build Coastguard Worker
1961*da0073e9SAndroid Build Coastguard Worker    # cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
1962*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
1963*da0073e9SAndroid Build Coastguard Worker    def test_bmm_noncontiguous(self, device, dtype):
1964*da0073e9SAndroid Build Coastguard Worker        nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair(
1965*da0073e9SAndroid Build Coastguard Worker            (2, 3), device, dtype
1966*da0073e9SAndroid Build Coastguard Worker        )
1967*da0073e9SAndroid Build Coastguard Worker        nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair(
1968*da0073e9SAndroid Build Coastguard Worker            (6, 7), device, dtype
1969*da0073e9SAndroid Build Coastguard Worker        )
1970*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1971*da0073e9SAndroid Build Coastguard Worker            nt0_contiguous.transpose(-1, -2).bmm(nt1_contiguous),
1972*da0073e9SAndroid Build Coastguard Worker            nt0_noncontiguous.transpose(-1, -2).bmm(nt1_noncontiguous),
1973*da0073e9SAndroid Build Coastguard Worker        )
1974*da0073e9SAndroid Build Coastguard Worker
1975*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
1976*da0073e9SAndroid Build Coastguard Worker    def test_matmul_with_bmm_path(self, device, dtype):
1977*da0073e9SAndroid Build Coastguard Worker        def unbind_rebind_matmul(nt1, nt2):
1978*da0073e9SAndroid Build Coastguard Worker            t1s = nt1.unbind()
1979*da0073e9SAndroid Build Coastguard Worker            t2s = nt2.unbind()
1980*da0073e9SAndroid Build Coastguard Worker            out_ts = [t1.matmul(t2) for t1, t2 in zip(t1s, t2s)]
1981*da0073e9SAndroid Build Coastguard Worker            return torch.nested.nested_tensor(out_ts)
1982*da0073e9SAndroid Build Coastguard Worker
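        # unbind_rebind_matmul is the per-component reference: matmul on nested
        # tensors should agree with matmul applied to each constituent pair and
        # re-nesting the results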
1983*da0073e9SAndroid Build Coastguard Worker        # [N, n_head, *, head_dim], [N, n_head, head_dim, *]
1984*da0073e9SAndroid Build Coastguard Worker        Ns = [1, 2, 5]
1985*da0073e9SAndroid Build Coastguard Worker        n_heads = np.random.randint(2, 5)
1986*da0073e9SAndroid Build Coastguard Worker        head_dim = 3
1987*da0073e9SAndroid Build Coastguard Worker        t1s = []
1988*da0073e9SAndroid Build Coastguard Worker        t2s = []
1989*da0073e9SAndroid Build Coastguard Worker        for N in Ns:
1990*da0073e9SAndroid Build Coastguard Worker            for _ in range(N):
1991*da0073e9SAndroid Build Coastguard Worker                seq_len1 = np.random.randint(2, 5)
1992*da0073e9SAndroid Build Coastguard Worker                seq_len2 = np.random.randint(2, 5)
1993*da0073e9SAndroid Build Coastguard Worker                t1s.append(torch.randn(n_heads, seq_len1, head_dim))
1994*da0073e9SAndroid Build Coastguard Worker                t2s.append(torch.randn(n_heads, head_dim, seq_len2))
1995*da0073e9SAndroid Build Coastguard Worker            nt1 = torch.nested.nested_tensor(t1s, device=device, dtype=dtype)
1996*da0073e9SAndroid Build Coastguard Worker            nt2 = torch.nested.nested_tensor(t2s, device=device, dtype=dtype)
1997*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.matmul(nt1, nt2), unbind_rebind_matmul(nt1, nt2))
1998*da0073e9SAndroid Build Coastguard Worker
1999*da0073e9SAndroid Build Coastguard Worker        # test with noncontiguous
2000*da0073e9SAndroid Build Coastguard Worker        t3s = []
2001*da0073e9SAndroid Build Coastguard Worker        t4s = []
2002*da0073e9SAndroid Build Coastguard Worker        for _ in range(N):
2003*da0073e9SAndroid Build Coastguard Worker            seq_len = np.random.randint(2, 5)
2004*da0073e9SAndroid Build Coastguard Worker            t3s.append(torch.randn(seq_len, n_heads, head_dim))
2005*da0073e9SAndroid Build Coastguard Worker            t4s.append(torch.randn(seq_len, n_heads, head_dim))
2006*da0073e9SAndroid Build Coastguard Worker        nt3 = torch.nested.nested_tensor(t3s, device=device, dtype=dtype).transpose(
2007*da0073e9SAndroid Build Coastguard Worker            1, 2
2008*da0073e9SAndroid Build Coastguard Worker        )
2009*da0073e9SAndroid Build Coastguard Worker        nt4 = (
2010*da0073e9SAndroid Build Coastguard Worker            torch.nested.nested_tensor(t4s, device=device, dtype=dtype)
2011*da0073e9SAndroid Build Coastguard Worker            .transpose(1, 2)
2012*da0073e9SAndroid Build Coastguard Worker            .transpose(2, 3)
2013*da0073e9SAndroid Build Coastguard Worker        )
2014*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.matmul(nt3, nt4), unbind_rebind_matmul(nt3, nt4))
2015*da0073e9SAndroid Build Coastguard Worker
2016*da0073e9SAndroid Build Coastguard Worker    # cannot test torch.float16 because: RuntimeError: "bmm" not implemented for 'Half'
2017*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
2018*da0073e9SAndroid Build Coastguard Worker    def test_matmul(self, device, dtype):
2019*da0073e9SAndroid Build Coastguard Worker        # error case: one is nested but the other is not
2020*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor(
2021*da0073e9SAndroid Build Coastguard Worker            [torch.randn(2), torch.randn(3)], device=device, dtype=dtype
2022*da0073e9SAndroid Build Coastguard Worker        )
2023*da0073e9SAndroid Build Coastguard Worker        t = torch.randn(4, device=device, dtype=dtype)
2024*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2025*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2026*da0073e9SAndroid Build Coastguard Worker            "Expected both to be nested, but got a nested self and non-nested other",
2027*da0073e9SAndroid Build Coastguard Worker            lambda: torch.matmul(nt, t),
2028*da0073e9SAndroid Build Coastguard Worker        )
2029*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2030*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2031*da0073e9SAndroid Build Coastguard Worker            "Expected both to be nested, but got a non-nested self and nested other",
2032*da0073e9SAndroid Build Coastguard Worker            lambda: torch.matmul(t, nt),
2033*da0073e9SAndroid Build Coastguard Worker        )
2034*da0073e9SAndroid Build Coastguard Worker        # error case: not 3+D tensors
2035*da0073e9SAndroid Build Coastguard Worker        nt0 = torch.nested.nested_tensor([], device=device, dtype=dtype)
2036*da0073e9SAndroid Build Coastguard Worker        nt1 = torch.nested.nested_tensor(
2037*da0073e9SAndroid Build Coastguard Worker            [torch.randn(2), torch.randn(3)], device=device, dtype=dtype
2038*da0073e9SAndroid Build Coastguard Worker        )
2039*da0073e9SAndroid Build Coastguard Worker        nt2 = torch.nested.nested_tensor(
2040*da0073e9SAndroid Build Coastguard Worker            [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
2041*da0073e9SAndroid Build Coastguard Worker        )
2042*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2043*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2044*da0073e9SAndroid Build Coastguard Worker            r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
2045*da0073e9SAndroid Build Coastguard Worker            lambda: torch.matmul(nt0, nt0),
2046*da0073e9SAndroid Build Coastguard Worker        )
2047*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2048*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2049*da0073e9SAndroid Build Coastguard Worker            r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
2050*da0073e9SAndroid Build Coastguard Worker            lambda: torch.matmul(nt0, nt1),
2051*da0073e9SAndroid Build Coastguard Worker        )
2052*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2053*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2054*da0073e9SAndroid Build Coastguard Worker            r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
2055*da0073e9SAndroid Build Coastguard Worker            lambda: torch.matmul(nt0, nt2),
2056*da0073e9SAndroid Build Coastguard Worker        )
2057*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2058*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2059*da0073e9SAndroid Build Coastguard Worker            r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
2060*da0073e9SAndroid Build Coastguard Worker            lambda: torch.matmul(nt1, nt0),
2061*da0073e9SAndroid Build Coastguard Worker        )
2062*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2063*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2064*da0073e9SAndroid Build Coastguard Worker            r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
2065*da0073e9SAndroid Build Coastguard Worker            lambda: torch.matmul(nt1, nt1),
2066*da0073e9SAndroid Build Coastguard Worker        )
2067*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2068*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2069*da0073e9SAndroid Build Coastguard Worker            r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
2070*da0073e9SAndroid Build Coastguard Worker            lambda: torch.matmul(nt1, nt2),
2071*da0073e9SAndroid Build Coastguard Worker        )
2072*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2073*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2074*da0073e9SAndroid Build Coastguard Worker            r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 2nd input has rank: [0-9]+",
2075*da0073e9SAndroid Build Coastguard Worker            lambda: torch.matmul(nt2, nt0),
2076*da0073e9SAndroid Build Coastguard Worker        )
2077*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2078*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2079*da0073e9SAndroid Build Coastguard Worker            r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 2nd input has rank: [0-9]+",
2080*da0073e9SAndroid Build Coastguard Worker            lambda: torch.matmul(nt2, nt1),
2081*da0073e9SAndroid Build Coastguard Worker        )
2082*da0073e9SAndroid Build Coastguard Worker        # error case: incompatible batch size
2083*da0073e9SAndroid Build Coastguard Worker        nt0 = torch.nested.nested_tensor(
2084*da0073e9SAndroid Build Coastguard Worker            [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
2085*da0073e9SAndroid Build Coastguard Worker        )
2086*da0073e9SAndroid Build Coastguard Worker        nt1 = torch.nested.nested_tensor(
2087*da0073e9SAndroid Build Coastguard Worker            [torch.randn((4, 6)), torch.randn((4, 5)), torch.randn((4, 7))],
2088*da0073e9SAndroid Build Coastguard Worker            device=device,
2089*da0073e9SAndroid Build Coastguard Worker            dtype=dtype,
2090*da0073e9SAndroid Build Coastguard Worker        )
2091*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2092*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2093*da0073e9SAndroid Build Coastguard Worker            r"matmul: Expected size for the 1st dimension of 2nd input tensor to be: [0-9]+ but got: [0-9]+.",
2094*da0073e9SAndroid Build Coastguard Worker            lambda: torch.matmul(nt0, nt1),
2095*da0073e9SAndroid Build Coastguard Worker        )
2096*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2097*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2098*da0073e9SAndroid Build Coastguard Worker            r"matmul: Expected size for the 1st dimension of 2nd input tensor to be: [0-9]+ but got: [0-9]+.",
2099*da0073e9SAndroid Build Coastguard Worker            lambda: torch.matmul(nt1, nt0),
2100*da0073e9SAndroid Build Coastguard Worker        )
2101*da0073e9SAndroid Build Coastguard Worker        # error case: incompatible batch sizes that would not even broadcast
2102*da0073e9SAndroid Build Coastguard Worker        nt0 = torch.nested.nested_tensor(
2103*da0073e9SAndroid Build Coastguard Worker            [torch.randn((2, 2, 4)), torch.randn((2, 3, 4))], device=device, dtype=dtype
2104*da0073e9SAndroid Build Coastguard Worker        )
2105*da0073e9SAndroid Build Coastguard Worker        nt1 = torch.nested.nested_tensor(
2106*da0073e9SAndroid Build Coastguard Worker            [torch.randn((3, 4, 6)), torch.randn((3, 4, 5))], device=device, dtype=dtype
2107*da0073e9SAndroid Build Coastguard Worker        )
2108*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2109*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2110*da0073e9SAndroid Build Coastguard Worker            "matmul(): For nested tensors, batch dimensions must have the same sizes,",
2111*da0073e9SAndroid Build Coastguard Worker            lambda: torch.matmul(nt0, nt1),
2112*da0073e9SAndroid Build Coastguard Worker        )
2113*da0073e9SAndroid Build Coastguard Worker        # error case: incompatible batch sizes that should technically broadcast
2114*da0073e9SAndroid Build Coastguard Worker        nt0 = torch.nested.nested_tensor(
2115*da0073e9SAndroid Build Coastguard Worker            [torch.randn((2, 2, 4)), torch.randn((1, 3, 4))], device=device, dtype=dtype
2116*da0073e9SAndroid Build Coastguard Worker        )
2117*da0073e9SAndroid Build Coastguard Worker        nt1 = torch.nested.nested_tensor(
2118*da0073e9SAndroid Build Coastguard Worker            [torch.randn((1, 4, 6)), torch.randn((3, 4, 5))], device=device, dtype=dtype
2119*da0073e9SAndroid Build Coastguard Worker        )
2120*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2121*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2122*da0073e9SAndroid Build Coastguard Worker            "matmul(): For nested tensors, batch dimensions must have the same sizes,",
2123*da0073e9SAndroid Build Coastguard Worker            lambda: torch.matmul(nt0, nt1),
2124*da0073e9SAndroid Build Coastguard Worker        )
2125*da0073e9SAndroid Build Coastguard Worker        # error case: underlying matrices cannot be multiplied
2126*da0073e9SAndroid Build Coastguard Worker        nt0 = torch.nested.nested_tensor(
2127*da0073e9SAndroid Build Coastguard Worker            [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
2128*da0073e9SAndroid Build Coastguard Worker        )
2129*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2130*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2131*da0073e9SAndroid Build Coastguard Worker            "matmul(): Nested tensors cannot be matrix multiplied",
2132*da0073e9SAndroid Build Coastguard Worker            lambda: torch.matmul(nt0, nt0),
2133*da0073e9SAndroid Build Coastguard Worker        )
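        # The case above fails because, per component, the inner dimensions do not line
        # up: e.g. (2, 4) @ (2, 4) would need the first operand's last dim (4) to match
        # the second operand's second-to-last dim (2).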
2134*da0073e9SAndroid Build Coastguard Worker        # normal nested tensor: 3D
2135*da0073e9SAndroid Build Coastguard Worker        nt0 = torch.nested.nested_tensor(
2136*da0073e9SAndroid Build Coastguard Worker            [torch.randn((2, 4)), torch.randn((3, 7))], device=device, dtype=dtype
2137*da0073e9SAndroid Build Coastguard Worker        )
2138*da0073e9SAndroid Build Coastguard Worker        nt1 = torch.nested.nested_tensor(
2139*da0073e9SAndroid Build Coastguard Worker            [torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype
2140*da0073e9SAndroid Build Coastguard Worker        )
2141*da0073e9SAndroid Build Coastguard Worker        actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0)
2142*da0073e9SAndroid Build Coastguard Worker        expect = torch.matmul(
2143*da0073e9SAndroid Build Coastguard Worker            torch.nested.to_padded_tensor(nt0, 0.0),
2144*da0073e9SAndroid Build Coastguard Worker            torch.nested.to_padded_tensor(nt1, 0.0),
2145*da0073e9SAndroid Build Coastguard Worker        )
2146*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual, expect)
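        # Comparing against matmul of the zero-padded tensors (here and in the 4D/5D
        # cases below) is valid: padded-out rows/columns are all zeros, so they add
        # nothing inside each component's valid region and leave the padding region of
        # the product at zero as well.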
2147*da0073e9SAndroid Build Coastguard Worker        # normal nested tensor: 4D (with testing for batch_size=1)
2148*da0073e9SAndroid Build Coastguard Worker        nt0 = torch.nested.nested_tensor(
2149*da0073e9SAndroid Build Coastguard Worker            [torch.randn((1, 2, 4)), torch.randn((8, 3, 7))], device=device, dtype=dtype
2150*da0073e9SAndroid Build Coastguard Worker        )
2151*da0073e9SAndroid Build Coastguard Worker        nt1 = torch.nested.nested_tensor(
2152*da0073e9SAndroid Build Coastguard Worker            [torch.randn((1, 4, 6)), torch.randn((8, 7, 5))], device=device, dtype=dtype
2153*da0073e9SAndroid Build Coastguard Worker        )
2154*da0073e9SAndroid Build Coastguard Worker        actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0)
2155*da0073e9SAndroid Build Coastguard Worker        expect = torch.matmul(
2156*da0073e9SAndroid Build Coastguard Worker            torch.nested.to_padded_tensor(nt0, 0.0),
2157*da0073e9SAndroid Build Coastguard Worker            torch.nested.to_padded_tensor(nt1, 0.0),
2158*da0073e9SAndroid Build Coastguard Worker        )
2159*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual, expect)
2160*da0073e9SAndroid Build Coastguard Worker        # normal nested tensor: 5D
2161*da0073e9SAndroid Build Coastguard Worker        nt0 = torch.nested.nested_tensor(
2162*da0073e9SAndroid Build Coastguard Worker            [torch.randn((8, 9, 2, 4)), torch.randn((8, 9, 3, 7))],
2163*da0073e9SAndroid Build Coastguard Worker            device=device,
2164*da0073e9SAndroid Build Coastguard Worker            dtype=dtype,
2165*da0073e9SAndroid Build Coastguard Worker        )
2166*da0073e9SAndroid Build Coastguard Worker        nt1 = torch.nested.nested_tensor(
2167*da0073e9SAndroid Build Coastguard Worker            [torch.randn((8, 9, 4, 6)), torch.randn((8, 9, 7, 5))],
2168*da0073e9SAndroid Build Coastguard Worker            device=device,
2169*da0073e9SAndroid Build Coastguard Worker            dtype=dtype,
2170*da0073e9SAndroid Build Coastguard Worker        )
2171*da0073e9SAndroid Build Coastguard Worker        actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0)
2172*da0073e9SAndroid Build Coastguard Worker        expect = torch.matmul(
2173*da0073e9SAndroid Build Coastguard Worker            torch.nested.to_padded_tensor(nt0, 0.0),
2174*da0073e9SAndroid Build Coastguard Worker            torch.nested.to_padded_tensor(nt1, 0.0),
2175*da0073e9SAndroid Build Coastguard Worker        )
2176*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual, expect)
2177*da0073e9SAndroid Build Coastguard Worker
2178*da0073e9SAndroid Build Coastguard Worker    # only supported on CUDA for now
2179*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
2180*da0073e9SAndroid Build Coastguard Worker    def test_matmul_nt_with_broadcasted_t(self, device, dtype):
2181*da0073e9SAndroid Build Coastguard Worker        # NT (B, *, C, D) with T (D, E) broadcasting case
2182*da0073e9SAndroid Build Coastguard Worker        nt = random_nt_from_dims([3, None, 4, 5], device=device, dtype=dtype)
2183*da0073e9SAndroid Build Coastguard Worker        t = torch.randn(5, 6, device=device, dtype=dtype)
2184*da0073e9SAndroid Build Coastguard Worker        output = torch.matmul(nt, t)
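        # Per component this computes (n_i, 4, 5) @ (5, 6) -> (n_i, 4, 6); the dense t
        # is broadcast across both the nested batch dim and each component's leading dim.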
2185*da0073e9SAndroid Build Coastguard Worker
2186*da0073e9SAndroid Build Coastguard Worker        # should be equivalent to matmul-ing each component with the dense tensor
2187*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt.size(0), output.size(0))
2188*da0073e9SAndroid Build Coastguard Worker        for component, out_component in zip(nt, output):
2189*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out_component, torch.matmul(component, t))
2190*da0073e9SAndroid Build Coastguard Worker
2191*da0073e9SAndroid Build Coastguard Worker    # cannot test torch.float16 because: RuntimeError: "bmm" not implemented for 'Half'
2192*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
2193*da0073e9SAndroid Build Coastguard Worker    def test_matmul_noncontiguous(self, device, dtype):
2194*da0073e9SAndroid Build Coastguard Worker        nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair(
2195*da0073e9SAndroid Build Coastguard Worker            (2, 3), device, dtype
2196*da0073e9SAndroid Build Coastguard Worker        )
2197*da0073e9SAndroid Build Coastguard Worker        nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair(
2198*da0073e9SAndroid Build Coastguard Worker            (6, 7), device, dtype
2199*da0073e9SAndroid Build Coastguard Worker        )
2200*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2201*da0073e9SAndroid Build Coastguard Worker            torch.matmul(nt0_contiguous.transpose(-1, -2), nt1_contiguous),
2202*da0073e9SAndroid Build Coastguard Worker            torch.matmul(nt0_noncontiguous.transpose(-1, -2), nt1_noncontiguous),
2203*da0073e9SAndroid Build Coastguard Worker        )
2204*da0073e9SAndroid Build Coastguard Worker
2205*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
2206*da0073e9SAndroid Build Coastguard Worker    def test_linear(self, device, dtype):
2207*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(1, 2, device=device, dtype=dtype)
2208*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(2, 2, device=device, dtype=dtype)
2209*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(3, 2, device=device, dtype=dtype)
2210*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor([a, b, c])
2211*da0073e9SAndroid Build Coastguard Worker
2212*da0073e9SAndroid Build Coastguard Worker        weight = torch.randn(2, 2, device=device, dtype=dtype)
2213*da0073e9SAndroid Build Coastguard Worker        bias = torch.randn(2, device=device, dtype=dtype)
2214*da0073e9SAndroid Build Coastguard Worker        # success case
2215*da0073e9SAndroid Build Coastguard Worker        torch.functional.F.linear(nt, weight, bias)
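        # A rough per-component sketch of what the success case computes (shapes only,
        # not asserted here; assumes the standard F.linear semantics out = x @ W^T + b):
        #
        #   expected = torch.nested.nested_tensor(
        #       [x @ weight.t() + bias for x in (a, b, c)]
        #   )
        #
        # so components of shapes (1, 2), (2, 2), (3, 2) map to outputs of the same shapes.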
2216*da0073e9SAndroid Build Coastguard Worker
2217*da0073e9SAndroid Build Coastguard Worker        # invalid nested tensor dimension
2218*da0073e9SAndroid Build Coastguard Worker        msg = r"Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 2. Dense tensor dim: 2"
2219*da0073e9SAndroid Build Coastguard Worker        nt1 = torch.nested.nested_tensor(
2220*da0073e9SAndroid Build Coastguard Worker            [
2221*da0073e9SAndroid Build Coastguard Worker                torch.randn(1, device=device, dtype=dtype),
2222*da0073e9SAndroid Build Coastguard Worker                torch.randn(2, device=device, dtype=dtype),
2223*da0073e9SAndroid Build Coastguard Worker            ]
2224*da0073e9SAndroid Build Coastguard Worker        )
2225*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
2226*da0073e9SAndroid Build Coastguard Worker            torch.functional.F.linear(nt1, weight, bias)
2227*da0073e9SAndroid Build Coastguard Worker
2228*da0073e9SAndroid Build Coastguard Worker        # invalid weight shape
2229*da0073e9SAndroid Build Coastguard Worker        msg = r"Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 3. Dense tensor dim: 3"
2230*da0073e9SAndroid Build Coastguard Worker        weight1 = torch.randn(2, 2, 3, device=device, dtype=dtype)
2231*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
2232*da0073e9SAndroid Build Coastguard Worker            torch.functional.F.linear(nt, weight1, bias)
2233*da0073e9SAndroid Build Coastguard Worker
2234*da0073e9SAndroid Build Coastguard Worker        # inconsistent last dim of nested tensor
2235*da0073e9SAndroid Build Coastguard Worker        msg = r"Expected all tensors in nested tensor to have the same trailing dimension, instead last dimension equals:"
2236*da0073e9SAndroid Build Coastguard Worker        nt2 = torch.nested.nested_tensor(
2237*da0073e9SAndroid Build Coastguard Worker            [
2238*da0073e9SAndroid Build Coastguard Worker                torch.randn(1, 2, device=device, dtype=dtype),
2239*da0073e9SAndroid Build Coastguard Worker                torch.randn(2, 3, device=device, dtype=dtype),
2240*da0073e9SAndroid Build Coastguard Worker            ]
2241*da0073e9SAndroid Build Coastguard Worker        )
2242*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
2243*da0073e9SAndroid Build Coastguard Worker            torch.functional.F.linear(nt2, weight, bias)
2244*da0073e9SAndroid Build Coastguard Worker
2245*da0073e9SAndroid Build Coastguard Worker        # Mismatch of nested tensor last dim and weight dimension
2246*da0073e9SAndroid Build Coastguard Worker        weight2 = torch.randn(2, 4, device=device, dtype=dtype)
2247*da0073e9SAndroid Build Coastguard Worker        msg = (
2248*da0073e9SAndroid Build Coastguard Worker            r"Shape mismatch for NestedTensor Linear: Expected input's \(a nested tensor\) 'last_dim'"
2249*da0073e9SAndroid Build Coastguard Worker            r" to equal 'weight.size\(1\), but got: last_dim = 2, and weight.size\(1\) = 4"
2250*da0073e9SAndroid Build Coastguard Worker        )
2251*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
2252*da0073e9SAndroid Build Coastguard Worker            torch.functional.F.linear(nt, weight2, bias)
2253*da0073e9SAndroid Build Coastguard Worker
2254*da0073e9SAndroid Build Coastguard Worker        # Nested tensor input and nested weight
2255*da0073e9SAndroid Build Coastguard Worker        nt_weight = nt.clone()
2256*da0073e9SAndroid Build Coastguard Worker        msg = r"Linear does not support nested weight when input is a nested tensor."
2257*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
2258*da0073e9SAndroid Build Coastguard Worker            torch.functional.F.linear(nt, nt_weight, bias)
2259*da0073e9SAndroid Build Coastguard Worker
2260*da0073e9SAndroid Build Coastguard Worker    # TODO: test noncontiguous linear
2261*da0073e9SAndroid Build Coastguard Worker    # For now this only checks the error message of linear,
2262*da0073e9SAndroid Build Coastguard Worker    # since linear does not yet support a noncontiguous nested buffer
2263*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
2264*da0073e9SAndroid Build Coastguard Worker    def test_linear_noncontiguous(self, device, dtype):
2265*da0073e9SAndroid Build Coastguard Worker        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
2266*da0073e9SAndroid Build Coastguard Worker            (2, 3, 6, 7), device, dtype
2267*da0073e9SAndroid Build Coastguard Worker        )
2268*da0073e9SAndroid Build Coastguard Worker        weight = torch.randn((8, 5), device=device, dtype=dtype)
2269*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2270*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2271*da0073e9SAndroid Build Coastguard Worker            r"for now linear only supports contiguous nested tensor",
2272*da0073e9SAndroid Build Coastguard Worker            lambda: torch.nn.functional.linear(nt_noncontiguous, weight),
2273*da0073e9SAndroid Build Coastguard Worker        )
2274*da0073e9SAndroid Build Coastguard Worker
2275*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.float16, torch.double)
2276*da0073e9SAndroid Build Coastguard Worker    def test_to_padded_tensor_zero_numel_errors(self, device, dtype):
2277*da0073e9SAndroid Build Coastguard Worker        ts = [torch.ones(1, 0), torch.ones(0, 0)]
2278*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor(
2279*da0073e9SAndroid Build Coastguard Worker            ts, device=device, dtype=dtype, layout=torch.strided
2280*da0073e9SAndroid Build Coastguard Worker        )
2281*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2282*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2283*da0073e9SAndroid Build Coastguard Worker            r"at least one constituent tensor should have non-zero numel",
2284*da0073e9SAndroid Build Coastguard Worker            lambda: torch.nested.to_padded_tensor(nt, 0.0),
2285*da0073e9SAndroid Build Coastguard Worker        )
2286*da0073e9SAndroid Build Coastguard Worker
2287*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.float16, torch.double)
2288*da0073e9SAndroid Build Coastguard Worker    def test_transpose(self, device, dtype):
2289*da0073e9SAndroid Build Coastguard Worker        nt = random_nt(device, dtype, 4, (4, 4))
2290*da0073e9SAndroid Build Coastguard Worker        # error case: transpose nested dimension
2291*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2292*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2293*da0073e9SAndroid Build Coastguard Worker            "Nested tensor dimension 0 cannot be transposed",
2294*da0073e9SAndroid Build Coastguard Worker            lambda: nt.transpose(0, 1),
2295*da0073e9SAndroid Build Coastguard Worker        )
2296*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2297*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2298*da0073e9SAndroid Build Coastguard Worker            "Nested tensor dimension 0 cannot be transposed",
2299*da0073e9SAndroid Build Coastguard Worker            lambda: nt.transpose(1, -3),
2300*da0073e9SAndroid Build Coastguard Worker        )
2301*da0073e9SAndroid Build Coastguard Worker        # error case: dimension out of range
2302*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(IndexError, lambda: nt.transpose(1, 3))
2303*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(IndexError, lambda: nt.transpose(-4, -1))
2304*da0073e9SAndroid Build Coastguard Worker        # normal case
2305*da0073e9SAndroid Build Coastguard Worker        ntT = nt.transpose(-1, -2)
2306*da0073e9SAndroid Build Coastguard Worker        ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
2307*da0073e9SAndroid Build Coastguard Worker        pt = torch.nested.to_padded_tensor(nt, 0.0)
2308*da0073e9SAndroid Build Coastguard Worker        ptT = pt.transpose(-1, -2)
2309*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ptT, ptT_from_ntT)
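        # The transposed nested tensor is noncontiguous, which is why the dedicated
        # noncontiguous_to_padded_tensor helper is used for the comparison above.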
2310*da0073e9SAndroid Build Coastguard Worker
2311*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.float16, torch.double)
2312*da0073e9SAndroid Build Coastguard Worker    def test_squeeze_unsqueeze(self, device, dtype):
2313*da0073e9SAndroid Build Coastguard Worker        a = torch.arange(6).reshape(2, 3)
2314*da0073e9SAndroid Build Coastguard Worker        b = torch.arange(15).reshape(5, 3)
2315*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor([a, b], device=device, dtype=dtype)
2316*da0073e9SAndroid Build Coastguard Worker        # error case: squeeze no dimension
2317*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2318*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2319*da0073e9SAndroid Build Coastguard Worker            "For nested tensors, squeeze without the dim argument",
2320*da0073e9SAndroid Build Coastguard Worker            lambda: nt.squeeze(),
2321*da0073e9SAndroid Build Coastguard Worker        )
2322*da0073e9SAndroid Build Coastguard Worker        # error case: squeeze nested dimension
2323*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2324*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2325*da0073e9SAndroid Build Coastguard Worker            "For nested tensors, squeezing dimension 0",
2326*da0073e9SAndroid Build Coastguard Worker            lambda: nt.squeeze(0),
2327*da0073e9SAndroid Build Coastguard Worker        )
2328*da0073e9SAndroid Build Coastguard Worker        # error case: dimension out of range
2329*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(IndexError, lambda: nt.squeeze(3))
2330*da0073e9SAndroid Build Coastguard Worker        # error case: squeeze nested tensor of singleton tensors
2331*da0073e9SAndroid Build Coastguard Worker        c = torch.ones(1)
2332*da0073e9SAndroid Build Coastguard Worker        nt_singleton = torch.nested.nested_tensor([c, c], device=device, dtype=dtype)
2333*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2334*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2335*da0073e9SAndroid Build Coastguard Worker            "For nested tensors, squeezing a nested tensor of singleton",
2336*da0073e9SAndroid Build Coastguard Worker            lambda: nt_singleton.squeeze(1),
2337*da0073e9SAndroid Build Coastguard Worker        )
2338*da0073e9SAndroid Build Coastguard Worker
2339*da0073e9SAndroid Build Coastguard Worker        # squeezing a dim which does not have size 1 should be a no-op
2340*da0073e9SAndroid Build Coastguard Worker        nt2 = nt.squeeze(-1)
2341*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt, nt2)
2342*da0073e9SAndroid Build Coastguard Worker
2343*da0073e9SAndroid Build Coastguard Worker        # test cases that should work
2344*da0073e9SAndroid Build Coastguard Worker        nt_sizes = nt._nested_tensor_size()
2345*da0073e9SAndroid Build Coastguard Worker        nt_strides = nt._nested_tensor_strides()
2346*da0073e9SAndroid Build Coastguard Worker        for i in range(-2, 4):
2347*da0073e9SAndroid Build Coastguard Worker            if i == 0:
2348*da0073e9SAndroid Build Coastguard Worker                # cannot unsqueeze batch dim
2349*da0073e9SAndroid Build Coastguard Worker                continue
2350*da0073e9SAndroid Build Coastguard Worker            nt_unsqueezed = nt.unsqueeze(i)
2351*da0073e9SAndroid Build Coastguard Worker            # negative dim will correspond to unsqueeze() applied at dim = dim + nt.dim() + 1
2352*da0073e9SAndroid Build Coastguard Worker            wrapped_i = i + nt.dim() + 1 if i < 0 else i
2353*da0073e9SAndroid Build Coastguard Worker            # the column index into the nt size tensor requires subtracting 1 to ignore the batch dim
2354*da0073e9SAndroid Build Coastguard Worker            size_idx = wrapped_i - 1
2355*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
2356*da0073e9SAndroid Build Coastguard Worker                nt_unsqueezed._nested_tensor_size()[:, size_idx],
2357*da0073e9SAndroid Build Coastguard Worker                torch.ones(2, dtype=torch.long),
2358*da0073e9SAndroid Build Coastguard Worker            )
2359*da0073e9SAndroid Build Coastguard Worker            unsqueezed_stride = nt_unsqueezed._nested_tensor_strides()[:, size_idx]
2360*da0073e9SAndroid Build Coastguard Worker            if i == nt.ndim or i == -1:
2361*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(unsqueezed_stride, torch.ones(2, dtype=torch.long))
2362*da0073e9SAndroid Build Coastguard Worker            else:
2363*da0073e9SAndroid Build Coastguard Worker                stride_col_after = nt_strides[:, size_idx]
2364*da0073e9SAndroid Build Coastguard Worker                size_col_after = nt_sizes[:, size_idx]
2365*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(unsqueezed_stride, stride_col_after * size_col_after)
2366*da0073e9SAndroid Build Coastguard Worker            nt_squeezed = nt_unsqueezed.squeeze(i)
2367*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt_squeezed, nt)
2368*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt_squeezed._nested_tensor_size(), nt_sizes)
2369*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt_squeezed._nested_tensor_strides(), nt_strides)
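        # In short: an unsqueezed trailing dim gets stride 1, an interior unsqueezed dim
        # inherits stride * size of the dim it displaces, and squeeze(i) restores the
        # original size/stride metadata exactly, per the checks above.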
2370*da0073e9SAndroid Build Coastguard Worker
2371*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.float16, torch.double)
2372*da0073e9SAndroid Build Coastguard Worker    def test_transpose_inference_mode_interaction(self, device, dtype):
2373*da0073e9SAndroid Build Coastguard Worker        nt = random_nt(device, dtype, 4, (4, 4))
2374*da0073e9SAndroid Build Coastguard Worker        # Construct in default mode and transpose while in inference mode
2375*da0073e9SAndroid Build Coastguard Worker        with torch.inference_mode():
2376*da0073e9SAndroid Build Coastguard Worker            ntT = nt.transpose(-1, -2)
2377*da0073e9SAndroid Build Coastguard Worker            ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
2378*da0073e9SAndroid Build Coastguard Worker            pt = torch.nested.to_padded_tensor(nt, 0.0)
2379*da0073e9SAndroid Build Coastguard Worker            ptT = pt.transpose(-1, -2)
2380*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ptT, ptT_from_ntT)
2381*da0073e9SAndroid Build Coastguard Worker
2382*da0073e9SAndroid Build Coastguard Worker        # Construct and transpose while in inference mode
2383*da0073e9SAndroid Build Coastguard Worker        with torch.inference_mode():
2384*da0073e9SAndroid Build Coastguard Worker            nt = random_nt(device, dtype, 4, (4, 4))
2385*da0073e9SAndroid Build Coastguard Worker            ntT = nt.transpose(-1, -2)
2386*da0073e9SAndroid Build Coastguard Worker            ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
2387*da0073e9SAndroid Build Coastguard Worker            pt = torch.nested.to_padded_tensor(nt, 0.0)
2388*da0073e9SAndroid Build Coastguard Worker            ptT = pt.transpose(-1, -2)
2389*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ptT, ptT_from_ntT)
2390*da0073e9SAndroid Build Coastguard Worker
2391*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.float16, torch.double)
2392*da0073e9SAndroid Build Coastguard Worker    def test_view(self, device, dtype):
2393*da0073e9SAndroid Build Coastguard Worker        nt = random_nt(device, dtype, 4, (4, 4))
2394*da0073e9SAndroid Build Coastguard Worker        # error case: empty shape
2395*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2396*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2397*da0073e9SAndroid Build Coastguard Worker            r"shape '\[\]' is invalid for a nested tensor",
2398*da0073e9SAndroid Build Coastguard Worker            lambda: nt.view(()),
2399*da0073e9SAndroid Build Coastguard Worker        )
2400*da0073e9SAndroid Build Coastguard Worker        # error case: empty nested tensor
2401*da0073e9SAndroid Build Coastguard Worker        nt_empty = torch.nested.nested_tensor([])
2402*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2403*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2404*da0073e9SAndroid Build Coastguard Worker            "empty nested tensor cannot be reshaped",
2405*da0073e9SAndroid Build Coastguard Worker            lambda: nt_empty.view(-1),
2406*da0073e9SAndroid Build Coastguard Worker        )
2407*da0073e9SAndroid Build Coastguard Worker        # error case: -1 for batch size
2408*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2409*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2410*da0073e9SAndroid Build Coastguard Worker            r"view: For now nested view cannot change or infer the implicit batch dimension",
2411*da0073e9SAndroid Build Coastguard Worker            lambda: nt.view(-1, 2, 3),
2412*da0073e9SAndroid Build Coastguard Worker        )
2413*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2414*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2415*da0073e9SAndroid Build Coastguard Worker            r"shape '\[.*\]' is invalid for input of size [0-9]+",
2416*da0073e9SAndroid Build Coastguard Worker            lambda: nt.view(4, 2, 3),
2417*da0073e9SAndroid Build Coastguard Worker        )
2418*da0073e9SAndroid Build Coastguard Worker        # normal case
2419*da0073e9SAndroid Build Coastguard Worker        x0 = torch.randn((2, 20), device=device, dtype=dtype)
2420*da0073e9SAndroid Build Coastguard Worker        x1 = torch.randn((3, 20), device=device, dtype=dtype)
2421*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor([x0, x1])
2422*da0073e9SAndroid Build Coastguard Worker        pt = torch.nested.to_padded_tensor(nt, 0.0)
2423*da0073e9SAndroid Build Coastguard Worker        # error case: trying to view the batch dim into an otherwise legit shape
2424*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2425*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2426*da0073e9SAndroid Build Coastguard Worker            r"For now nested view cannot change or infer the implicit batch dimension",
2427*da0073e9SAndroid Build Coastguard Worker            lambda: nt.transpose(-1, -2).view(40, -1),
2428*da0073e9SAndroid Build Coastguard Worker        )
2429*da0073e9SAndroid Build Coastguard Worker        # inherit only the ragged dimension
2430*da0073e9SAndroid Build Coastguard Worker        # (2, 20) -> (2, 5, 4)
2431*da0073e9SAndroid Build Coastguard Worker        # (3, 20) -> (3, 5, 4)
2432*da0073e9SAndroid Build Coastguard Worker        nt1 = nt.view(2, -1, 5, 4)
2433*da0073e9SAndroid Build Coastguard Worker        # (2, 3, 20) -> (2, 3, 5, 4)
2434*da0073e9SAndroid Build Coastguard Worker        pt1 = pt.view(2, -1, 5, 4)
2435*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1)
2436*da0073e9SAndroid Build Coastguard Worker
2437*da0073e9SAndroid Build Coastguard Worker        # more than one -1 (even for "old" dims) should fail
2438*da0073e9SAndroid Build Coastguard Worker        # this attempts to do (2, (2, 3), 5, 4) -> (2, (2, 3), 5, 2, 2)
2439*da0073e9SAndroid Build Coastguard Worker        # but we ban "inherit old behavior" for >1 dimension
2440*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2441*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2442*da0073e9SAndroid Build Coastguard Worker            r"only one dimension can be inferred",
2443*da0073e9SAndroid Build Coastguard Worker            lambda: nt1.view(2, -1, -1, 2, 2),
2444*da0073e9SAndroid Build Coastguard Worker        )
2445*da0073e9SAndroid Build Coastguard Worker
2446*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.float16, torch.double)
2447*da0073e9SAndroid Build Coastguard Worker    def test_view_inference_mode_interaction(self, device, dtype):
2448*da0073e9SAndroid Build Coastguard Worker        # Construct in default mode and view while in inference mode
2449*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor(
2450*da0073e9SAndroid Build Coastguard Worker            [torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype
2451*da0073e9SAndroid Build Coastguard Worker        )
2452*da0073e9SAndroid Build Coastguard Worker        with torch.inference_mode():
2453*da0073e9SAndroid Build Coastguard Worker            ntT = nt.view(2, -1, 4, 5)
2454*da0073e9SAndroid Build Coastguard Worker            ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
2455*da0073e9SAndroid Build Coastguard Worker            pt = torch.nested.to_padded_tensor(nt, 0.0)
2456*da0073e9SAndroid Build Coastguard Worker            ptT = pt.view(2, -1, 4, 5)
2457*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ptT, ptT_from_ntT)
2458*da0073e9SAndroid Build Coastguard Worker        # Construct and view while in inference mode
2459*da0073e9SAndroid Build Coastguard Worker        with torch.inference_mode():
2460*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor(
2461*da0073e9SAndroid Build Coastguard Worker                [torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype
2462*da0073e9SAndroid Build Coastguard Worker            )
2463*da0073e9SAndroid Build Coastguard Worker            ntT = nt.view(2, -1, 4, 5)
2464*da0073e9SAndroid Build Coastguard Worker            ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
2465*da0073e9SAndroid Build Coastguard Worker            pt = torch.nested.to_padded_tensor(nt, 0.0)
2466*da0073e9SAndroid Build Coastguard Worker            ptT = pt.view(2, -1, 4, 5)
2467*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ptT, ptT_from_ntT)
2468*da0073e9SAndroid Build Coastguard Worker
2469*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.float16, torch.double)
2470*da0073e9SAndroid Build Coastguard Worker    def test_reshape(self, device, dtype):
2471*da0073e9SAndroid Build Coastguard Worker        nt = random_nt(device, dtype, 4, (4, 4))
2472*da0073e9SAndroid Build Coastguard Worker        # error case: empty shape
2473*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2474*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2475*da0073e9SAndroid Build Coastguard Worker            r"shape '\[\]' is invalid for a nested tensor",
2476*da0073e9SAndroid Build Coastguard Worker            lambda: nt.reshape(()),
2477*da0073e9SAndroid Build Coastguard Worker        )
2478*da0073e9SAndroid Build Coastguard Worker        # error case: empty nested tensor
2479*da0073e9SAndroid Build Coastguard Worker        nt_empty = torch.nested.nested_tensor([])
2480*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2481*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2482*da0073e9SAndroid Build Coastguard Worker            "empty nested tensor cannot be reshaped",
2483*da0073e9SAndroid Build Coastguard Worker            lambda: nt_empty.reshape(-1),
2484*da0073e9SAndroid Build Coastguard Worker        )
2485*da0073e9SAndroid Build Coastguard Worker        # error case: -1 for batch size
2486*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2487*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2488*da0073e9SAndroid Build Coastguard Worker            r"reshape: For now nested reshape cannot change or infer the implicit batch dimension",
2489*da0073e9SAndroid Build Coastguard Worker            lambda: nt.reshape(-1, 2, 3),
2490*da0073e9SAndroid Build Coastguard Worker        )
2491*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2492*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2493*da0073e9SAndroid Build Coastguard Worker            r"shape '\[.*\]' is invalid for input of size [0-9]+",
2494*da0073e9SAndroid Build Coastguard Worker            lambda: nt.reshape(4, 2, 3),
2495*da0073e9SAndroid Build Coastguard Worker        )
2496*da0073e9SAndroid Build Coastguard Worker        # normal case
2497*da0073e9SAndroid Build Coastguard Worker        x0 = torch.randn((2, 20), device=device, dtype=dtype)
2498*da0073e9SAndroid Build Coastguard Worker        x1 = torch.randn((3, 20), device=device, dtype=dtype)
2499*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor([x0, x1])  # (2, (2, 3), 20)
2500*da0073e9SAndroid Build Coastguard Worker        pt = torch.nested.to_padded_tensor(nt, 0.0)
2501*da0073e9SAndroid Build Coastguard Worker        # error case: trying to reshape the batch dim into an otherwise legit shape
2502*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2503*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2504*da0073e9SAndroid Build Coastguard Worker            r"reshape: For now nested reshape cannot change or infer the implicit batch dimension",
2505*da0073e9SAndroid Build Coastguard Worker            lambda: nt.transpose(-1, -2).reshape(40, -1),
2506*da0073e9SAndroid Build Coastguard Worker        )
2507*da0073e9SAndroid Build Coastguard Worker        # inherit only the ragged dimension
2508*da0073e9SAndroid Build Coastguard Worker        # (2, 20) -> (2, 5, 4)
2509*da0073e9SAndroid Build Coastguard Worker        # (3, 20) -> (3, 5, 4)
2510*da0073e9SAndroid Build Coastguard Worker        nt1 = nt.reshape(2, -1, 5, 4)
2511*da0073e9SAndroid Build Coastguard Worker        # (2, 3, 20) -> (2, 3, 5, 4)
2512*da0073e9SAndroid Build Coastguard Worker        pt1 = pt.reshape(2, -1, 5, 4)
2513*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1)
2514*da0073e9SAndroid Build Coastguard Worker
2515*da0073e9SAndroid Build Coastguard Worker        # more than one -1 (even for "old" dims) should fail
2516*da0073e9SAndroid Build Coastguard Worker        # this attempts to do (2, (2, 3), 5, 4) -> (2, (2, 3), 5, 2, 2)
2517*da0073e9SAndroid Build Coastguard Worker        # but we ban "inherit old behavior" for >1 dimension
2518*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2519*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2520*da0073e9SAndroid Build Coastguard Worker            r"only one dimension can be inferred",
2521*da0073e9SAndroid Build Coastguard Worker            lambda: nt1.reshape(2, -1, -1, 2, 2),
2522*da0073e9SAndroid Build Coastguard Worker        )
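        # This mirrors test_view above; for regular tensors the difference is that
        # reshape may copy when no view is possible, while view must alias its input.
        # Presumably the nested implementations follow the same contract, though this
        # test only exercises shapes and error messages.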
2523*da0073e9SAndroid Build Coastguard Worker
2524*da0073e9SAndroid Build Coastguard Worker    def test_nested_masked_select(self, device):
2525*da0073e9SAndroid Build Coastguard Worker        t = torch.randn([3, 3], device=device)
2526*da0073e9SAndroid Build Coastguard Worker        mask = torch.tensor([False], device=device)
2527*da0073e9SAndroid Build Coastguard Worker
2528*da0073e9SAndroid Build Coastguard Worker        njt = torch.nested.masked_select(t, mask)
2529*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(njt.values(), torch.tensor([], device=device))
2530*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(njt.offsets(), torch.tensor([0, 0, 0, 0], device=device))
2531*da0073e9SAndroid Build Coastguard Worker
2532*da0073e9SAndroid Build Coastguard Worker        mask = torch.tensor([[False], [False], [True]], device=device)
2533*da0073e9SAndroid Build Coastguard Worker        njt = torch.nested.masked_select(t, mask)
2534*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(njt.values(), t[-1], atol=0.1, rtol=0.1)
2535*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(njt.offsets(), torch.tensor([0, 0, 0, 3], device=device))
2536*da0073e9SAndroid Build Coastguard Worker
2537*da0073e9SAndroid Build Coastguard Worker        mask = torch.tensor(
2538*da0073e9SAndroid Build Coastguard Worker            [[False, False, True], [True, False, True], [False, False, True]],
2539*da0073e9SAndroid Build Coastguard Worker            device=device,
2540*da0073e9SAndroid Build Coastguard Worker        )
2541*da0073e9SAndroid Build Coastguard Worker        njt = torch.nested.masked_select(t, mask)
2542*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(njt.values(), t.masked_select(mask))
2543*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(njt.offsets(), torch.tensor([0, 1, 3, 4], device=device))
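        # offsets are the per-row prefix sums of selected elements: the three mask rows
        # above keep 1, 2, and 1 values respectively, giving offsets [0, 1, 3, 4].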
2544*da0073e9SAndroid Build Coastguard Worker
2545*da0073e9SAndroid Build Coastguard Worker        t = torch.randn([2, 3, 3, 1], device=device)
2546*da0073e9SAndroid Build Coastguard Worker        mask = torch.tensor(
2547*da0073e9SAndroid Build Coastguard Worker            [
2548*da0073e9SAndroid Build Coastguard Worker                [
2549*da0073e9SAndroid Build Coastguard Worker                    [[True], [False], [True]],
2550*da0073e9SAndroid Build Coastguard Worker                    [[True], [False], [True]],
2551*da0073e9SAndroid Build Coastguard Worker                    [[True], [False], [True]],
2552*da0073e9SAndroid Build Coastguard Worker                ],
2553*da0073e9SAndroid Build Coastguard Worker                [
2554*da0073e9SAndroid Build Coastguard Worker                    [[False], [True], [True]],
2555*da0073e9SAndroid Build Coastguard Worker                    [[False], [True], [True]],
2556*da0073e9SAndroid Build Coastguard Worker                    [[True], [True], [True]],
2557*da0073e9SAndroid Build Coastguard Worker                ],
2558*da0073e9SAndroid Build Coastguard Worker            ],
2559*da0073e9SAndroid Build Coastguard Worker            device=device,
2560*da0073e9SAndroid Build Coastguard Worker        )
2561*da0073e9SAndroid Build Coastguard Worker        njt = torch.nested.masked_select(t, mask)
2562*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(njt.values(), t.masked_select(mask))
2563*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2564*da0073e9SAndroid Build Coastguard Worker            njt.offsets(),
2565*da0073e9SAndroid Build Coastguard Worker            torch.tensor(
2566*da0073e9SAndroid Build Coastguard Worker                [0, 1, 1, 2, 3, 3, 4, 5, 5, 6, 6, 7, 8, 8, 9, 10, 11, 12, 13],
2567*da0073e9SAndroid Build Coastguard Worker                device=device,
2568*da0073e9SAndroid Build Coastguard Worker            ),
2569*da0073e9SAndroid Build Coastguard Worker        )
2570*da0073e9SAndroid Build Coastguard Worker
2571*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.float16, torch.double)
2572*da0073e9SAndroid Build Coastguard Worker    def test_narrow(self, device, dtype):
2573*da0073e9SAndroid Build Coastguard Worker        nt = random_nt_from_dims([5, None, None, None], device=device, dtype=dtype)
2574*da0073e9SAndroid Build Coastguard Worker
2575*da0073e9SAndroid Build Coastguard Worker        # narrow on dim=0 from start to end
2576*da0073e9SAndroid Build Coastguard Worker        bounds = [(0, 5), (0, 3), (1, 2), (1, 5), (2, 4)]
2577*da0073e9SAndroid Build Coastguard Worker        for start, end in bounds:
2578*da0073e9SAndroid Build Coastguard Worker            length = end - start
2579*da0073e9SAndroid Build Coastguard Worker            narrowed = nt.narrow(dim=0, start=start, length=length)
2580*da0073e9SAndroid Build Coastguard Worker            # ensure output is a view
2581*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(narrowed._base is nt)
2582*da0073e9SAndroid Build Coastguard Worker            for nc, c in zip(narrowed.unbind(), nt.unbind()[start:end]):
2583*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(nc, c)
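        # e.g. the (2, 4) bound becomes nt.narrow(dim=0, start=2, length=2), selecting
        # components 2 and 3 as a view into nt (no copy), which the _base check verifies.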
2584*da0073e9SAndroid Build Coastguard Worker
2585*da0073e9SAndroid Build Coastguard Worker        # dim != 0 is not supported
2586*da0073e9SAndroid Build Coastguard Worker        for dim in range(1, nt.dim()):
2587*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
2588*da0073e9SAndroid Build Coastguard Worker                RuntimeError, "only dim=0 supported for nested tensors"
2589*da0073e9SAndroid Build Coastguard Worker            ):
2590*da0073e9SAndroid Build Coastguard Worker                nt.narrow(dim=dim, start=0, length=1)
2591*da0073e9SAndroid Build Coastguard Worker
2592*da0073e9SAndroid Build Coastguard Worker        # error case: non-contiguous NT
2593*da0073e9SAndroid Build Coastguard Worker        _, nt_noncont = random_nt_noncontiguous_pair((2, 3, 4))
2594*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
2595*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "only contiguous nested tensors supported"
2596*da0073e9SAndroid Build Coastguard Worker        ):
2597*da0073e9SAndroid Build Coastguard Worker            nt_noncont.narrow(dim=0, start=0, length=1)
2598*da0073e9SAndroid Build Coastguard Worker
2599*da0073e9SAndroid Build Coastguard Worker    @parametrize("input_dim", [3, 4])
2600*da0073e9SAndroid Build Coastguard Worker    def test_scaled_dot_product_attention(self, device, input_dim):
2601*da0073e9SAndroid Build Coastguard Worker        def rand_tensor(*shape):
2602*da0073e9SAndroid Build Coastguard Worker            return torch.randn(shape, device=device)
2603*da0073e9SAndroid Build Coastguard Worker
2604*da0073e9SAndroid Build Coastguard Worker        E = 8
2605*da0073e9SAndroid Build Coastguard Worker        if input_dim == 3:
2606*da0073e9SAndroid Build Coastguard Worker            # Shape: (N, L, E); ragged L
2607*da0073e9SAndroid Build Coastguard Worker            query = torch.nested.nested_tensor(
2608*da0073e9SAndroid Build Coastguard Worker                [rand_tensor(2, E), rand_tensor(3, E), rand_tensor(4, E)]
2609*da0073e9SAndroid Build Coastguard Worker            )
2610*da0073e9SAndroid Build Coastguard Worker
2611*da0073e9SAndroid Build Coastguard Worker            # Shape: (N, S, E); ragged S
2612*da0073e9SAndroid Build Coastguard Worker            key = torch.nested.nested_tensor(
2613*da0073e9SAndroid Build Coastguard Worker                [rand_tensor(3, E), rand_tensor(4, E), rand_tensor(5, E)]
2614*da0073e9SAndroid Build Coastguard Worker            )
2615*da0073e9SAndroid Build Coastguard Worker            value = torch.nested.nested_tensor(
2616*da0073e9SAndroid Build Coastguard Worker                [rand_tensor(3, E), rand_tensor(4, E), rand_tensor(5, E)]
2617*da0073e9SAndroid Build Coastguard Worker            )
2618*da0073e9SAndroid Build Coastguard Worker        elif input_dim == 4:
2619*da0073e9SAndroid Build Coastguard Worker            # In the 4D case, L and S are ragged
2620*da0073e9SAndroid Build Coastguard Worker            # Shape: (N, N', L, E); ragged N' and L
2621*da0073e9SAndroid Build Coastguard Worker            query = torch.nested.nested_tensor(
2622*da0073e9SAndroid Build Coastguard Worker                [rand_tensor(2, 2, E), rand_tensor(3, 3, E), rand_tensor(4, 4, E)]
2623*da0073e9SAndroid Build Coastguard Worker            )
2624*da0073e9SAndroid Build Coastguard Worker            # Shape: (N, N', S, E); ragged N' and S
2625*da0073e9SAndroid Build Coastguard Worker            key = torch.nested.nested_tensor(
2626*da0073e9SAndroid Build Coastguard Worker                [rand_tensor(2, 3, E), rand_tensor(3, 4, E), rand_tensor(4, 5, E)]
2627*da0073e9SAndroid Build Coastguard Worker            )
2628*da0073e9SAndroid Build Coastguard Worker            value = torch.nested.nested_tensor(
2629*da0073e9SAndroid Build Coastguard Worker                [rand_tensor(2, 3, E), rand_tensor(3, 4, E), rand_tensor(4, 5, E)]
2630*da0073e9SAndroid Build Coastguard Worker            )
2631*da0073e9SAndroid Build Coastguard Worker        else:
2632*da0073e9SAndroid Build Coastguard Worker            self.fail(f"Invalid input_dim {input_dim} encountered in SDP test")
2633*da0073e9SAndroid Build Coastguard Worker
2634*da0073e9SAndroid Build Coastguard Worker        def rand_mask(size):
2635*da0073e9SAndroid Build Coastguard Worker            return torch.randint(0, 2, size=size, dtype=torch.bool, device=device)
2636*da0073e9SAndroid Build Coastguard Worker
2637*da0073e9SAndroid Build Coastguard Worker        # Shape: (N, L, S); ragged L and S matching above
2638*da0073e9SAndroid Build Coastguard Worker        attn_mask = torch.nested.nested_tensor(
2639*da0073e9SAndroid Build Coastguard Worker            [rand_mask((2, 3)), rand_mask((3, 4)), rand_mask((4, 5))]
2640*da0073e9SAndroid Build Coastguard Worker        )
2641*da0073e9SAndroid Build Coastguard Worker
2642*da0073e9SAndroid Build Coastguard Worker        dropout_p = 0.0  # no dropout for reproducibility
2643*da0073e9SAndroid Build Coastguard Worker
2644*da0073e9SAndroid Build Coastguard Worker        # Success case: no attn_mask set and is_causal=False.
2645*da0073e9SAndroid Build Coastguard Worker        actual = torch.nn.functional.scaled_dot_product_attention(
2646*da0073e9SAndroid Build Coastguard Worker            query, key, value, attn_mask=None, is_causal=False, dropout_p=dropout_p
2647*da0073e9SAndroid Build Coastguard Worker        )
2648*da0073e9SAndroid Build Coastguard Worker
2649*da0073e9SAndroid Build Coastguard Worker        expected_outputs = []
2650*da0073e9SAndroid Build Coastguard Worker        for q, k, v in zip(query.unbind(), key.unbind(), value.unbind()):
2651*da0073e9SAndroid Build Coastguard Worker            output = torch.nn.functional.scaled_dot_product_attention(
2652*da0073e9SAndroid Build Coastguard Worker                q.unsqueeze(0),
2653*da0073e9SAndroid Build Coastguard Worker                k.unsqueeze(0),
2654*da0073e9SAndroid Build Coastguard Worker                v.unsqueeze(0),
2655*da0073e9SAndroid Build Coastguard Worker                attn_mask=None,
2656*da0073e9SAndroid Build Coastguard Worker                dropout_p=dropout_p,
2657*da0073e9SAndroid Build Coastguard Worker            )
2658*da0073e9SAndroid Build Coastguard Worker            expected_outputs.append(output.squeeze(0))
2659*da0073e9SAndroid Build Coastguard Worker        expected_output_nested = torch.nested.nested_tensor(expected_outputs)
2660*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual, expected_output_nested)
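        # The reference above runs SDPA per component: each q/k/v constituent is
        # unbound, given a leading singleton batch dim, attended over densely, and the
        # results are re-nested for comparison.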
2661*da0073e9SAndroid Build Coastguard Worker
2662*da0073e9SAndroid Build Coastguard Worker        # Error case: explicit attn_mask set.
2663*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
2664*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "not supported when an explicit attn_mask is set"
2665*da0073e9SAndroid Build Coastguard Worker        ):
2666*da0073e9SAndroid Build Coastguard Worker            torch.nn.functional.scaled_dot_product_attention(
2667*da0073e9SAndroid Build Coastguard Worker                query, key, value, attn_mask=attn_mask, dropout_p=dropout_p
2668*da0073e9SAndroid Build Coastguard Worker            )
2669*da0073e9SAndroid Build Coastguard Worker
2670*da0073e9SAndroid Build Coastguard Worker        # Error case: is_causal=True.
2671*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "not supported when is_causal=True"):
2672*da0073e9SAndroid Build Coastguard Worker            torch.nn.functional.scaled_dot_product_attention(
2673*da0073e9SAndroid Build Coastguard Worker                query, key, value, dropout_p=dropout_p, is_causal=True
2674*da0073e9SAndroid Build Coastguard Worker            )
2675*da0073e9SAndroid Build Coastguard Worker
2676*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.float16, torch.double)
2677*da0073e9SAndroid Build Coastguard Worker    def test_empty_like(self, device, dtype):
2678*da0073e9SAndroid Build Coastguard Worker        ntensors = 4
2679*da0073e9SAndroid Build Coastguard Worker        nt = random_nt(device, dtype, ntensors, (4, 4))
2680*da0073e9SAndroid Build Coastguard Worker
2681*da0073e9SAndroid Build Coastguard Worker        # Create empty on same device as original nested tensor
2682*da0073e9SAndroid Build Coastguard Worker        nt_empty = torch.empty_like(nt)
2683*da0073e9SAndroid Build Coastguard Worker        assert nt.is_same_size(nt_empty)
2684*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt.dtype, nt_empty.dtype)
2685*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt.device, nt_empty.device)
2686*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt.layout, nt_empty.layout)
2687*da0073e9SAndroid Build Coastguard Worker
2688*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
2689*da0073e9SAndroid Build Coastguard Worker            if device == "cpu":
2690*da0073e9SAndroid Build Coastguard Worker                nt_cuda = torch.empty_like(nt, device="cuda")
2691*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(torch.device("cuda").type, nt_cuda.device.type)
2692*da0073e9SAndroid Build Coastguard Worker            else:
2693*da0073e9SAndroid Build Coastguard Worker                nt_cpu = torch.empty_like(nt, device="cpu")
2694*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(torch.device("cpu").type, nt_cpu.device.type)
2695*da0073e9SAndroid Build Coastguard Worker
2696*da0073e9SAndroid Build Coastguard Worker        # Check changing dtype of empty_like nested tensor output
2697*da0073e9SAndroid Build Coastguard Worker        dtype_set = {torch.float, torch.float16, torch.double}
2698*da0073e9SAndroid Build Coastguard Worker        for other_dtype in dtype_set - {dtype}:
2699*da0073e9SAndroid Build Coastguard Worker            nt_empty_other_dtype = torch.empty_like(nt, dtype=other_dtype)
2700*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt.dtype, dtype)
2701*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt_empty_other_dtype.dtype, other_dtype)
2702*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt.device, nt_empty_other_dtype.device)
2703*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt.layout, nt_empty_other_dtype.layout)
2704*da0073e9SAndroid Build Coastguard Worker
2705*da0073e9SAndroid Build Coastguard Worker        # Check that requires_grad is propagated to the empty_like output
2706*da0073e9SAndroid Build Coastguard Worker        nt_empty_req_grad = torch.empty_like(nt, requires_grad=True)
2707*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt_empty_req_grad.requires_grad, True)
2708*da0073e9SAndroid Build Coastguard Worker
2709*da0073e9SAndroid Build Coastguard Worker        # Test that empty_like on a noncontiguous nested tensor does not fail
2710*da0073e9SAndroid Build Coastguard Worker        nt_cont, nt_noncont = random_nt_noncontiguous_pair((2, 3, 6, 7))
2711*da0073e9SAndroid Build Coastguard Worker        nt_empty = torch.empty_like(nt_cont)
2712*da0073e9SAndroid Build Coastguard Worker        assert nt_cont.is_same_size(nt_empty)
2713*da0073e9SAndroid Build Coastguard Worker        nt_empty_non_contig = torch.empty_like(nt_noncont)
2714*da0073e9SAndroid Build Coastguard Worker        assert nt_noncont.is_same_size(nt_empty_non_contig)
2715*da0073e9SAndroid Build Coastguard Worker
2716*da0073e9SAndroid Build Coastguard Worker        # Test the contiguous memory format option
2717*da0073e9SAndroid Build Coastguard Worker        nt_empty_contig = torch.empty_like(
2718*da0073e9SAndroid Build Coastguard Worker            nt_cont, memory_format=torch.contiguous_format
2719*da0073e9SAndroid Build Coastguard Worker        )
2720*da0073e9SAndroid Build Coastguard Worker        assert nt_cont.is_same_size(nt_empty_contig)
2721*da0073e9SAndroid Build Coastguard Worker        assert nt_empty_contig.is_contiguous()
2722*da0073e9SAndroid Build Coastguard Worker
2723*da0073e9SAndroid Build Coastguard Worker        nt_empty_non_contig = torch.empty_like(
2724*da0073e9SAndroid Build Coastguard Worker            nt_noncont, memory_format=torch.contiguous_format
2725*da0073e9SAndroid Build Coastguard Worker        )
2726*da0073e9SAndroid Build Coastguard Worker        assert nt_noncont.is_same_size(nt_empty_non_contig)
2727*da0073e9SAndroid Build Coastguard Worker        assert nt_empty_non_contig.is_contiguous()
2728*da0073e9SAndroid Build Coastguard Worker
2729*da0073e9SAndroid Build Coastguard Worker        # Test other memory formats fail
2730*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(
2731*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2732*da0073e9SAndroid Build Coastguard Worker            lambda: torch.empty_like(nt_cont, memory_format=torch.channels_last),
2733*da0073e9SAndroid Build Coastguard Worker        )
2734*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(
2735*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2736*da0073e9SAndroid Build Coastguard Worker            lambda: torch.empty_like(nt_noncont, memory_format=torch.channels_last),
2737*da0073e9SAndroid Build Coastguard Worker        )
2738*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(
2739*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2740*da0073e9SAndroid Build Coastguard Worker            lambda: torch.empty_like(nt_cont, memory_format=torch.channels_last_3d),
2741*da0073e9SAndroid Build Coastguard Worker        )
2742*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(
2743*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2744*da0073e9SAndroid Build Coastguard Worker            lambda: torch.empty_like(nt_noncont, memory_format=torch.channels_last_3d),
2745*da0073e9SAndroid Build Coastguard Worker        )
2746*da0073e9SAndroid Build Coastguard Worker
2747*da0073e9SAndroid Build Coastguard Worker
2748*da0073e9SAndroid Build Coastguard Worker@markDynamoStrictTest
2749*da0073e9SAndroid Build Coastguard Workerclass TestNestedTensorAutograd(NestedTensorTestCase):
2750*da0073e9SAndroid Build Coastguard Worker    # Note [Gradcheck args check_batched_grad=False]: the common_utils testing version of gradcheck
2751*da0073e9SAndroid Build Coastguard Worker    # includes the default parameters used for testing ops with gradcheck. However, nested tensors
2752*da0073e9SAndroid Build Coastguard Worker    # do not support the stack op, so we turn batched-grad checking off for these tests.
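    # A minimal sketch (not itself a test) of the pattern the gradcheck-based tests below follow,
    # assuming a padded-tensor reduction as the differentiable output:
    #
    #     def grad_test_func(*leaves):
    #         nt = torch.nested.as_nested_tensor(list(leaves))
    #         return torch.nested.to_padded_tensor(nt, 0)
    #
    #     gradcheck(grad_test_func, inputs=leaves, check_batched_grad=False)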
2753*da0073e9SAndroid Build Coastguard Worker    def _create_leaf_nested_tensor_from_list(self, tensor_device, requires_grad=False):
2754*da0073e9SAndroid Build Coastguard Worker        return torch.nested.nested_tensor(
2755*da0073e9SAndroid Build Coastguard Worker            [torch.randn(1, 2), torch.randn(7, 8)],
2756*da0073e9SAndroid Build Coastguard Worker            requires_grad=requires_grad,
2757*da0073e9SAndroid Build Coastguard Worker            device=tensor_device,
2758*da0073e9SAndroid Build Coastguard Worker        )
2759*da0073e9SAndroid Build Coastguard Worker
2760*da0073e9SAndroid Build Coastguard Worker    def _create_nested_tensor_from_list(self, tensor_device, requires_grad=False):
2761*da0073e9SAndroid Build Coastguard Worker        return torch.nested.as_nested_tensor(
2762*da0073e9SAndroid Build Coastguard Worker            [
2763*da0073e9SAndroid Build Coastguard Worker                torch.randn(1, 2, requires_grad=requires_grad),
2764*da0073e9SAndroid Build Coastguard Worker                torch.randn(7, 8, requires_grad=requires_grad),
2765*da0073e9SAndroid Build Coastguard Worker            ],
2766*da0073e9SAndroid Build Coastguard Worker            device=tensor_device,
2767*da0073e9SAndroid Build Coastguard Worker        )
2768*da0073e9SAndroid Build Coastguard Worker
2769*da0073e9SAndroid Build Coastguard Worker    def _create_nested_tensor_from_mask(self, tensor_device, requires_grad=False):
2770*da0073e9SAndroid Build Coastguard Worker        data = torch.randn(2, 3, 4, requires_grad=requires_grad, device=tensor_device)
2771*da0073e9SAndroid Build Coastguard Worker        mask = torch.ones_like(data[:, :, 0]).bool()
2772*da0073e9SAndroid Build Coastguard Worker        return torch._nested_tensor_from_mask(data, mask)
2773*da0073e9SAndroid Build Coastguard Worker
2774*da0073e9SAndroid Build Coastguard Worker    def test_as_nested_tensor_propagates_gradients(self, device):
2775*da0073e9SAndroid Build Coastguard Worker        a = torch.arange(3, dtype=torch.float, device=device)
2776*da0073e9SAndroid Build Coastguard Worker        b = torch.arange(5, dtype=torch.float, device=device)
2777*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.as_nested_tensor([a, b])
2778*da0073e9SAndroid Build Coastguard Worker        # tensors with requires_grad=False are leaves
2779*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(nt.is_leaf)
2780*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(not nt.requires_grad)
2781*da0073e9SAndroid Build Coastguard Worker
2782*da0073e9SAndroid Build Coastguard Worker        a = torch.arange(3, dtype=torch.float, requires_grad=True, device=device)
2783*da0073e9SAndroid Build Coastguard Worker        b = torch.arange(5, dtype=torch.float, requires_grad=True, device=device)
2784*da0073e9SAndroid Build Coastguard Worker        nt2 = torch.nested.as_nested_tensor([a, b])
2785*da0073e9SAndroid Build Coastguard Worker        fake_grad = torch.nested.nested_tensor(
2786*da0073e9SAndroid Build Coastguard Worker            [torch.ones_like(a), torch.zeros_like(b)], device=device
2787*da0073e9SAndroid Build Coastguard Worker        )
2788*da0073e9SAndroid Build Coastguard Worker        nt2.backward(fake_grad)
2789*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.grad, fake_grad[0])
2790*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b.grad, fake_grad[1])
2791*da0073e9SAndroid Build Coastguard Worker
2792*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_generates_leaf(self, device):
2793*da0073e9SAndroid Build Coastguard Worker        a = torch.arange(3, dtype=torch.float, requires_grad=True, device=device)
2794*da0073e9SAndroid Build Coastguard Worker        b = torch.arange(5, dtype=torch.float, requires_grad=True, device=device)
2795*da0073e9SAndroid Build Coastguard Worker
2796*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor([a, b], requires_grad=False)
2797*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(nt.is_leaf)
2798*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(not nt.requires_grad)
2799*da0073e9SAndroid Build Coastguard Worker
2800*da0073e9SAndroid Build Coastguard Worker        nt2 = torch.nested.nested_tensor([a, b], requires_grad=True)
2801*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(nt2.is_leaf)
2802*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(nt2.requires_grad)
2803*da0073e9SAndroid Build Coastguard Worker
2804*da0073e9SAndroid Build Coastguard Worker        fake_grad = torch.nested.nested_tensor(
2805*da0073e9SAndroid Build Coastguard Worker            [torch.ones_like(a), torch.zeros_like(b)], device=device
2806*da0073e9SAndroid Build Coastguard Worker        )
2807*da0073e9SAndroid Build Coastguard Worker        nt2.backward(fake_grad)
2808*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt2.grad, fake_grad)
2809*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.grad, None)
2810*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b.grad, None)
2811*da0073e9SAndroid Build Coastguard Worker
2812*da0073e9SAndroid Build Coastguard Worker    def test_set_requires_grad_from_list(self, device):
2813*da0073e9SAndroid Build Coastguard Worker        nt = self._create_nested_tensor_from_list(device)
2814*da0073e9SAndroid Build Coastguard Worker        nt.requires_grad_()
2815*da0073e9SAndroid Build Coastguard Worker        assert nt.requires_grad
2816*da0073e9SAndroid Build Coastguard Worker
2817*da0073e9SAndroid Build Coastguard Worker    def test_set_requires_grad_from_mask(self, device):
2818*da0073e9SAndroid Build Coastguard Worker        nt = self._create_nested_tensor_from_mask(device)
2819*da0073e9SAndroid Build Coastguard Worker        nt.requires_grad_()
2820*da0073e9SAndroid Build Coastguard Worker        assert nt.requires_grad
2821*da0073e9SAndroid Build Coastguard Worker
2822*da0073e9SAndroid Build Coastguard Worker    def test_backward_for_add_op(self, device):
2823*da0073e9SAndroid Build Coastguard Worker        nt_1 = self._create_nested_tensor_from_mask(device)
2824*da0073e9SAndroid Build Coastguard Worker        nt_2 = self._create_nested_tensor_from_mask(device)
2825*da0073e9SAndroid Build Coastguard Worker
2826*da0073e9SAndroid Build Coastguard Worker        nt_1.requires_grad_()
2827*da0073e9SAndroid Build Coastguard Worker        c = nt_1 + nt_2
2828*da0073e9SAndroid Build Coastguard Worker
2829*da0073e9SAndroid Build Coastguard Worker        assert nt_1.requires_grad
2830*da0073e9SAndroid Build Coastguard Worker        assert c.requires_grad
2831*da0073e9SAndroid Build Coastguard Worker        grad_output = self._create_nested_tensor_from_mask(device)
2832*da0073e9SAndroid Build Coastguard Worker        c.backward(grad_output)
2833*da0073e9SAndroid Build Coastguard Worker
2834*da0073e9SAndroid Build Coastguard Worker        # Gradcheck doesn't work with nested tensors yet, so check the gradient manually.
2835*da0073e9SAndroid Build Coastguard Worker        # d/dnt_1 (nt_1 + nt_2) = 1 * grad_output
2836*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt_1.grad, grad_output)
2837*da0073e9SAndroid Build Coastguard Worker
2838*da0073e9SAndroid Build Coastguard Worker    def test_backward_for_sub_op(self, device):
2839*da0073e9SAndroid Build Coastguard Worker        nt_1 = self._create_nested_tensor_from_mask(device)
2840*da0073e9SAndroid Build Coastguard Worker        nt_2 = self._create_nested_tensor_from_mask(device)
2841*da0073e9SAndroid Build Coastguard Worker
2842*da0073e9SAndroid Build Coastguard Worker        nt_1.requires_grad_()
2843*da0073e9SAndroid Build Coastguard Worker        nt_2.requires_grad_()
2844*da0073e9SAndroid Build Coastguard Worker        c = nt_1 - nt_2
2845*da0073e9SAndroid Build Coastguard Worker
2846*da0073e9SAndroid Build Coastguard Worker        assert nt_1.requires_grad
2847*da0073e9SAndroid Build Coastguard Worker        assert nt_2.requires_grad
2848*da0073e9SAndroid Build Coastguard Worker        assert c.requires_grad
2849*da0073e9SAndroid Build Coastguard Worker        grad_output = self._create_nested_tensor_from_mask(device)
2850*da0073e9SAndroid Build Coastguard Worker        c.backward(grad_output)
2851*da0073e9SAndroid Build Coastguard Worker
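        # d/dnt_1 (nt_1 - nt_2) = grad_output and d/dnt_2 (nt_1 - nt_2) = -grad_output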
2852*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt_1.grad, grad_output)
2853*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt_2.grad, -1 * grad_output)
2854*da0073e9SAndroid Build Coastguard Worker
2855*da0073e9SAndroid Build Coastguard Worker    def test_backward_sub_strided(self, device):
2856*da0073e9SAndroid Build Coastguard Worker        a = torch.nested.nested_tensor(
2857*da0073e9SAndroid Build Coastguard Worker            [torch.randn(9, 2, 4), torch.randn(12, 2, 4)],
2858*da0073e9SAndroid Build Coastguard Worker            requires_grad=True,
2859*da0073e9SAndroid Build Coastguard Worker            device=device,
2860*da0073e9SAndroid Build Coastguard Worker        )
2861*da0073e9SAndroid Build Coastguard Worker        b = torch.nested.nested_tensor(
2862*da0073e9SAndroid Build Coastguard Worker            [torch.randn(9, 4, 2), torch.randn(12, 4, 2)],
2863*da0073e9SAndroid Build Coastguard Worker            requires_grad=True,
2864*da0073e9SAndroid Build Coastguard Worker            device=device,
2865*da0073e9SAndroid Build Coastguard Worker        )
2866*da0073e9SAndroid Build Coastguard Worker        c = a - b.transpose(-1, -2)
2867*da0073e9SAndroid Build Coastguard Worker        grad_output = c.clone()
2868*da0073e9SAndroid Build Coastguard Worker        c.backward(grad_output)
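        # The transpose reapplies to the incoming gradient, so b receives -grad_output
        # with its last two dims swapped.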
2869*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.grad, grad_output)
2870*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b.grad, -1 * grad_output.transpose(-1, -2))
2871*da0073e9SAndroid Build Coastguard Worker
2872*da0073e9SAndroid Build Coastguard Worker    def test_backward_add_strided(self, device):
2873*da0073e9SAndroid Build Coastguard Worker        a = torch.nested.nested_tensor(
2874*da0073e9SAndroid Build Coastguard Worker            [torch.randn(9, 2, 4), torch.randn(12, 2, 4)],
2875*da0073e9SAndroid Build Coastguard Worker            requires_grad=True,
2876*da0073e9SAndroid Build Coastguard Worker            device=device,
2877*da0073e9SAndroid Build Coastguard Worker        )
2878*da0073e9SAndroid Build Coastguard Worker        b = torch.nested.nested_tensor(
2879*da0073e9SAndroid Build Coastguard Worker            [torch.randn(9, 4, 2), torch.randn(12, 4, 2)],
2880*da0073e9SAndroid Build Coastguard Worker            requires_grad=True,
2881*da0073e9SAndroid Build Coastguard Worker            device=device,
2882*da0073e9SAndroid Build Coastguard Worker        )
2883*da0073e9SAndroid Build Coastguard Worker        c = a + b.transpose(-1, -2)
2884*da0073e9SAndroid Build Coastguard Worker        grad_output = c.clone()
2885*da0073e9SAndroid Build Coastguard Worker        c.backward(grad_output)
2886*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.grad, grad_output)
2887*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b.grad, grad_output.transpose(-1, -2))
2888*da0073e9SAndroid Build Coastguard Worker
2889*da0073e9SAndroid Build Coastguard Worker    # Test Factory Functions
2890*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_to_padded_tensor(self, device):
2891*da0073e9SAndroid Build Coastguard Worker        for padding_val in [0, 1]:
2892*da0073e9SAndroid Build Coastguard Worker            nt = self._create_leaf_nested_tensor_from_list(
2893*da0073e9SAndroid Build Coastguard Worker                tensor_device=device, requires_grad=True
2894*da0073e9SAndroid Build Coastguard Worker            )
2895*da0073e9SAndroid Build Coastguard Worker
2896*da0073e9SAndroid Build Coastguard Worker            out = torch.nested.to_padded_tensor(nt, padding_val)
2897*da0073e9SAndroid Build Coastguard Worker            grad_output = torch.ones(out.shape, device=device)
2898*da0073e9SAndroid Build Coastguard Worker            out.backward(grad_output)
2899*da0073e9SAndroid Build Coastguard Worker
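            # Only the unpadded positions belong to nt, so its grad is all ones with the
            # original per-constituent shapes, regardless of padding_val.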
2900*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
2901*da0073e9SAndroid Build Coastguard Worker                nt.grad,
2902*da0073e9SAndroid Build Coastguard Worker                torch.nested.nested_tensor(
2903*da0073e9SAndroid Build Coastguard Worker                    [torch.ones(1, 2), torch.ones(7, 8)], device=device
2904*da0073e9SAndroid Build Coastguard Worker                ),
2905*da0073e9SAndroid Build Coastguard Worker            )
2906*da0073e9SAndroid Build Coastguard Worker
2907*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_from_mask_and_to_padded(self, device):
2908*da0073e9SAndroid Build Coastguard Worker        N, L, D = 2, 4, 4
2909*da0073e9SAndroid Build Coastguard Worker        mask = torch.ones(N, L, device=device)
2910*da0073e9SAndroid Build Coastguard Worker        for i in range(1, N):
2911*da0073e9SAndroid Build Coastguard Worker            end = torch.randint(1, L - 1, (1,), device=device)
2912*da0073e9SAndroid Build Coastguard Worker            mask[i, end:] = 0
2913*da0073e9SAndroid Build Coastguard Worker
2914*da0073e9SAndroid Build Coastguard Worker        mask[0, :] = 1
2915*da0073e9SAndroid Build Coastguard Worker        mask = mask.bool()
2916*da0073e9SAndroid Build Coastguard Worker
2917*da0073e9SAndroid Build Coastguard Worker        data = torch.randn(
2918*da0073e9SAndroid Build Coastguard Worker            N, L, D, requires_grad=True, dtype=torch.float64, device=device
2919*da0073e9SAndroid Build Coastguard Worker        )
2920*da0073e9SAndroid Build Coastguard Worker
2921*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(inpt):
2922*da0073e9SAndroid Build Coastguard Worker            nt = torch._nested_tensor_from_mask(inpt, mask)
2923*da0073e9SAndroid Build Coastguard Worker            # This implicitly tests to_padded_tensor grads
2924*da0073e9SAndroid Build Coastguard Worker            return torch.nested.to_padded_tensor(nt, 0)
2925*da0073e9SAndroid Build Coastguard Worker
2926*da0073e9SAndroid Build Coastguard Worker        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
2927*da0073e9SAndroid Build Coastguard Worker
2928*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_from_padded(self, device):
2929*da0073e9SAndroid Build Coastguard Worker        nested_size = torch.tensor([[1, 2], [2, 2]])
2930*da0073e9SAndroid Build Coastguard Worker        padded_tensor = torch.randn(2, 2, 2, dtype=torch.float64, device=device)
2931*da0073e9SAndroid Build Coastguard Worker        padded_tensor[0, 1, :] = 0
2932*da0073e9SAndroid Build Coastguard Worker        padded_tensor.requires_grad_()
2933*da0073e9SAndroid Build Coastguard Worker
2934*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(tensor, nested_size):
2935*da0073e9SAndroid Build Coastguard Worker            nt = torch._nested_from_padded(
2936*da0073e9SAndroid Build Coastguard Worker                tensor, nested_size, fuse_transform_0213=False
2937*da0073e9SAndroid Build Coastguard Worker            )
2938*da0073e9SAndroid Build Coastguard Worker            # This implicitly tests to_padded_tensor grads
2939*da0073e9SAndroid Build Coastguard Worker            return torch.nested.to_padded_tensor(nt, 0)
2940*da0073e9SAndroid Build Coastguard Worker
2941*da0073e9SAndroid Build Coastguard Worker        data = (padded_tensor, nested_size)
2942*da0073e9SAndroid Build Coastguard Worker        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
2943*da0073e9SAndroid Build Coastguard Worker
2944*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_from_padded_fused(self, device):
2945*da0073e9SAndroid Build Coastguard Worker        nested_size = torch.tensor([[1, 8], [2, 8]])
2946*da0073e9SAndroid Build Coastguard Worker        padded_tensor = torch.randn(2, 2, 2, 4, dtype=torch.float64, device=device)
2947*da0073e9SAndroid Build Coastguard Worker        padded_tensor[0, 1, :] = 0
2948*da0073e9SAndroid Build Coastguard Worker        padded_tensor.requires_grad_()
2949*da0073e9SAndroid Build Coastguard Worker
2950*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(tensor, nested_size):
2951*da0073e9SAndroid Build Coastguard Worker            nt = torch._nested_from_padded(
2952*da0073e9SAndroid Build Coastguard Worker                tensor, nested_size, fuse_transform_0213=True
2953*da0073e9SAndroid Build Coastguard Worker            )
2954*da0073e9SAndroid Build Coastguard Worker            # This implicitly tests to_padded_tensor grads
2955*da0073e9SAndroid Build Coastguard Worker            return torch.nested.to_padded_tensor(nt, 0)
2956*da0073e9SAndroid Build Coastguard Worker
2957*da0073e9SAndroid Build Coastguard Worker        data = (padded_tensor, nested_size)
2958*da0073e9SAndroid Build Coastguard Worker        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
2959*da0073e9SAndroid Build Coastguard Worker
2960*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_from_list(self, device):
2961*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device)
2962*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device)
2963*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(10, 2, requires_grad=True, dtype=torch.float64, device=device)
2964*da0073e9SAndroid Build Coastguard Worker
2965*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(a, b, c):
2966*da0073e9SAndroid Build Coastguard Worker            c = torch.nested.as_nested_tensor([a, b, c])
2967*da0073e9SAndroid Build Coastguard Worker            # This implicitly tests to_padded_tensor grads
2968*da0073e9SAndroid Build Coastguard Worker            return torch.nested.to_padded_tensor(c, 0)
2969*da0073e9SAndroid Build Coastguard Worker
2970*da0073e9SAndroid Build Coastguard Worker        data = (a, b, c)
2971*da0073e9SAndroid Build Coastguard Worker        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
2972*da0073e9SAndroid Build Coastguard Worker
2973*da0073e9SAndroid Build Coastguard Worker    @parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name)
2974*da0073e9SAndroid Build Coastguard Worker    def test_dropout_backward(self, layout):
2975*da0073e9SAndroid Build Coastguard Worker        if layout == torch.jagged:
2976*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor(
2977*da0073e9SAndroid Build Coastguard Worker                [torch.randn((2, 5)), torch.randn((3, 5))],
2978*da0073e9SAndroid Build Coastguard Worker                requires_grad=True,
2979*da0073e9SAndroid Build Coastguard Worker                layout=layout,
2980*da0073e9SAndroid Build Coastguard Worker            )
2981*da0073e9SAndroid Build Coastguard Worker        else:
2982*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor(
2983*da0073e9SAndroid Build Coastguard Worker                [torch.randn((2, 5)), torch.randn((3, 4))],
2984*da0073e9SAndroid Build Coastguard Worker                requires_grad=True,
2985*da0073e9SAndroid Build Coastguard Worker                layout=layout,
2986*da0073e9SAndroid Build Coastguard Worker            )
2987*da0073e9SAndroid Build Coastguard Worker        p = 0.2
2988*da0073e9SAndroid Build Coastguard Worker        y = torch.nn.functional.dropout(nt, p)
2989*da0073e9SAndroid Build Coastguard Worker        y.backward(nt.clone().detach())
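        # Dropout scales kept elements by 1 / (1 - p), so with grad_output = nt the
        # backward pass reproduces y: nt.grad = nt * mask / (1 - p) = y.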
2990*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt.grad, y)
2991*da0073e9SAndroid Build Coastguard Worker
2992*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_bmm_gradcheck(self, device):
2993*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 6, requires_grad=True, dtype=torch.float64, device=device)
2994*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(3, 6, requires_grad=True, dtype=torch.float64, device=device)
2995*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(6, 4, requires_grad=True, dtype=torch.float64, device=device)
2996*da0073e9SAndroid Build Coastguard Worker        d = torch.randn(6, 5, requires_grad=True, dtype=torch.float64, device=device)
2997*da0073e9SAndroid Build Coastguard Worker
2998*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(a, b, c, d):
2999*da0073e9SAndroid Build Coastguard Worker            nt0 = torch.nested.as_nested_tensor([a, b])
3000*da0073e9SAndroid Build Coastguard Worker            nt1 = torch.nested.as_nested_tensor([c, d])
3001*da0073e9SAndroid Build Coastguard Worker            result = nt0.bmm(nt1)
3002*da0073e9SAndroid Build Coastguard Worker            return torch.nested.to_padded_tensor(result, 0.0)
3003*da0073e9SAndroid Build Coastguard Worker
3004*da0073e9SAndroid Build Coastguard Worker        data = (a, b, c, d)
3005*da0073e9SAndroid Build Coastguard Worker        assert torch.autograd.gradcheck(grad_test_func, inputs=data)
3006*da0073e9SAndroid Build Coastguard Worker
3007*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_bmm_backward(self, device):
3008*da0073e9SAndroid Build Coastguard Worker        nt0 = torch.nested.nested_tensor(
3009*da0073e9SAndroid Build Coastguard Worker            [torch.randn((2, 6)), torch.randn((3, 6))],
3010*da0073e9SAndroid Build Coastguard Worker            requires_grad=True,
3011*da0073e9SAndroid Build Coastguard Worker            device=device,
3012*da0073e9SAndroid Build Coastguard Worker        )
3013*da0073e9SAndroid Build Coastguard Worker        nt1 = torch.nested.nested_tensor(
3014*da0073e9SAndroid Build Coastguard Worker            [torch.randn((6, 4)), torch.randn((6, 5))],
3015*da0073e9SAndroid Build Coastguard Worker            requires_grad=True,
3016*da0073e9SAndroid Build Coastguard Worker            device=device,
3017*da0073e9SAndroid Build Coastguard Worker        )
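        # Dense reference: compare against padded tensors with independent grads; padded
        # regions contribute zero, so the re-padded nested grads should match.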
3018*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
3019*da0073e9SAndroid Build Coastguard Worker            pt0 = torch.nested.to_padded_tensor(nt0, 0.0).requires_grad_(True)
3020*da0073e9SAndroid Build Coastguard Worker            pt1 = torch.nested.to_padded_tensor(nt1, 0.0).requires_grad_(True)
3021*da0073e9SAndroid Build Coastguard Worker
3022*da0073e9SAndroid Build Coastguard Worker        ynt = nt0.bmm(nt1)
3023*da0073e9SAndroid Build Coastguard Worker        ypt = pt0.bmm(pt1)
3024*da0073e9SAndroid Build Coastguard Worker        ynt.backward(ynt.clone())
3025*da0073e9SAndroid Build Coastguard Worker        ypt.backward(ypt.clone())
3026*da0073e9SAndroid Build Coastguard Worker
3027*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.nested.to_padded_tensor(nt0.grad, 0.0), pt0.grad)
3028*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.nested.to_padded_tensor(nt1.grad, 0.0), pt1.grad)
3029*da0073e9SAndroid Build Coastguard Worker
3030*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_matmul_gradcheck(self, device):
3031*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 6, requires_grad=True, dtype=torch.float64, device=device)
3032*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(3, 6, requires_grad=True, dtype=torch.float64, device=device)
3033*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(6, 4, requires_grad=True, dtype=torch.float64, device=device)
3034*da0073e9SAndroid Build Coastguard Worker        d = torch.randn(6, 5, requires_grad=True, dtype=torch.float64, device=device)
3035*da0073e9SAndroid Build Coastguard Worker
3036*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(a, b, c, d):
3037*da0073e9SAndroid Build Coastguard Worker            nt0 = torch.nested.as_nested_tensor([a, b])
3038*da0073e9SAndroid Build Coastguard Worker            nt1 = torch.nested.as_nested_tensor([c, d])
3039*da0073e9SAndroid Build Coastguard Worker            result = torch.matmul(nt0, nt1)
3040*da0073e9SAndroid Build Coastguard Worker            return torch.nested.to_padded_tensor(result, 0.0)
3041*da0073e9SAndroid Build Coastguard Worker
3042*da0073e9SAndroid Build Coastguard Worker        data = (a, b, c, d)
3043*da0073e9SAndroid Build Coastguard Worker        assert torch.autograd.gradcheck(grad_test_func, inputs=data)
3044*da0073e9SAndroid Build Coastguard Worker
3045*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_matmul_backward(self, device):
3046*da0073e9SAndroid Build Coastguard Worker        nt0 = torch.nested.nested_tensor(
3047*da0073e9SAndroid Build Coastguard Worker            [torch.randn((7, 2, 6)), torch.randn((7, 3, 6))],
3048*da0073e9SAndroid Build Coastguard Worker            requires_grad=True,
3049*da0073e9SAndroid Build Coastguard Worker            device=device,
3050*da0073e9SAndroid Build Coastguard Worker        )
3051*da0073e9SAndroid Build Coastguard Worker        nt1 = torch.nested.nested_tensor(
3052*da0073e9SAndroid Build Coastguard Worker            [torch.randn((7, 6, 4)), torch.randn((7, 6, 5))],
3053*da0073e9SAndroid Build Coastguard Worker            requires_grad=True,
3054*da0073e9SAndroid Build Coastguard Worker            device=device,
3055*da0073e9SAndroid Build Coastguard Worker        )
3056*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
3057*da0073e9SAndroid Build Coastguard Worker            pt0 = torch.nested.to_padded_tensor(nt0, 0.0).requires_grad_(True)
3058*da0073e9SAndroid Build Coastguard Worker            pt1 = torch.nested.to_padded_tensor(nt1, 0.0).requires_grad_(True)
3059*da0073e9SAndroid Build Coastguard Worker
3060*da0073e9SAndroid Build Coastguard Worker        ynt = torch.matmul(nt0, nt1)
3061*da0073e9SAndroid Build Coastguard Worker        ypt = torch.matmul(pt0, pt1)
3062*da0073e9SAndroid Build Coastguard Worker        ynt.backward(ynt.clone())
3063*da0073e9SAndroid Build Coastguard Worker        ypt.backward(ypt.clone())
3064*da0073e9SAndroid Build Coastguard Worker
3065*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.nested.to_padded_tensor(nt0.grad, 0.0), pt0.grad)
3066*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.nested.to_padded_tensor(nt1.grad, 0.0), pt1.grad)
3067*da0073e9SAndroid Build Coastguard Worker
3068*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_transpose_gradcheck(self, device):
3069*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 5, requires_grad=True, device=device)
3070*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(3, 4, requires_grad=True, device=device)
3071*da0073e9SAndroid Build Coastguard Worker
3072*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(a, b):
3073*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.as_nested_tensor([a, b])
3074*da0073e9SAndroid Build Coastguard Worker            result = nt.transpose(-2, -1).transpose(-2, -1)
3075*da0073e9SAndroid Build Coastguard Worker            return torch.nested.to_padded_tensor(result, 0.0)
3076*da0073e9SAndroid Build Coastguard Worker
3077*da0073e9SAndroid Build Coastguard Worker        data = (a, b)
3078*da0073e9SAndroid Build Coastguard Worker        assert torch.autograd.gradcheck(grad_test_func, inputs=data, eps=1e-3)
3079*da0073e9SAndroid Build Coastguard Worker
3080*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_transpose_backward(self, device):
3081*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor(
3082*da0073e9SAndroid Build Coastguard Worker            [torch.randn((2, 5)), torch.randn((3, 4))],
3083*da0073e9SAndroid Build Coastguard Worker            requires_grad=True,
3084*da0073e9SAndroid Build Coastguard Worker            device=device,
3085*da0073e9SAndroid Build Coastguard Worker        )
3086*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
3087*da0073e9SAndroid Build Coastguard Worker            pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)
3088*da0073e9SAndroid Build Coastguard Worker
3089*da0073e9SAndroid Build Coastguard Worker        ynt = nt.transpose(-2, -1)
3090*da0073e9SAndroid Build Coastguard Worker        ypt = pt.transpose(-2, -1)
3091*da0073e9SAndroid Build Coastguard Worker        ynt.backward(ynt.clone())
3092*da0073e9SAndroid Build Coastguard Worker        ypt.backward(ypt.clone())
3093*da0073e9SAndroid Build Coastguard Worker
3094*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)
3095*da0073e9SAndroid Build Coastguard Worker
3096*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_reshape_gradcheck(self, device):
3097*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 6, requires_grad=True, device=device)
3098*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(3, 6, requires_grad=True, device=device)
3099*da0073e9SAndroid Build Coastguard Worker
3100*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(a, b):
3101*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.as_nested_tensor([a, b])
3102*da0073e9SAndroid Build Coastguard Worker            result = nt.reshape(2, -1, 2, 3)
3103*da0073e9SAndroid Build Coastguard Worker            return torch.nested.to_padded_tensor(result, 0.0)
3104*da0073e9SAndroid Build Coastguard Worker
3105*da0073e9SAndroid Build Coastguard Worker        data = (a, b)
3106*da0073e9SAndroid Build Coastguard Worker        assert torch.autograd.gradcheck(grad_test_func, inputs=data, eps=1e-3)
3107*da0073e9SAndroid Build Coastguard Worker
3108*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_reshape_backward(self):
3109*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor(
3110*da0073e9SAndroid Build Coastguard Worker            [torch.randn((2, 6)), torch.randn((3, 6))], requires_grad=True
3111*da0073e9SAndroid Build Coastguard Worker        )
3112*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
3113*da0073e9SAndroid Build Coastguard Worker            pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)
3114*da0073e9SAndroid Build Coastguard Worker
3115*da0073e9SAndroid Build Coastguard Worker        ynt = nt.reshape(2, -1, 2, 3)
3116*da0073e9SAndroid Build Coastguard Worker        ypt = pt.reshape(2, -1, 2, 3)
3117*da0073e9SAndroid Build Coastguard Worker        ynt.backward(ynt.clone())
3118*da0073e9SAndroid Build Coastguard Worker        ypt.backward(ypt.clone())
3119*da0073e9SAndroid Build Coastguard Worker
3120*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)
3121*da0073e9SAndroid Build Coastguard Worker
3122*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_squeeze_backward(self, device):
3123*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor(
3124*da0073e9SAndroid Build Coastguard Worker            [torch.randn((2, 6, 1)), torch.randn((3, 6, 1))],
3125*da0073e9SAndroid Build Coastguard Worker            requires_grad=True,
3126*da0073e9SAndroid Build Coastguard Worker            device=device,
3127*da0073e9SAndroid Build Coastguard Worker        )
3128*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
3129*da0073e9SAndroid Build Coastguard Worker            pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)
3130*da0073e9SAndroid Build Coastguard Worker
3131*da0073e9SAndroid Build Coastguard Worker        ynt = nt.squeeze(-1)
3132*da0073e9SAndroid Build Coastguard Worker        ypt = pt.squeeze(-1)
3133*da0073e9SAndroid Build Coastguard Worker        ynt.backward(ynt.clone())
3134*da0073e9SAndroid Build Coastguard Worker        ypt.backward(ypt.clone())
3135*da0073e9SAndroid Build Coastguard Worker
3136*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)
3137*da0073e9SAndroid Build Coastguard Worker
3138*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_squeeze_gradcheck(self, device):
3139*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(
3140*da0073e9SAndroid Build Coastguard Worker            (2, 6, 1), dtype=torch.float64, requires_grad=True, device=device
3141*da0073e9SAndroid Build Coastguard Worker        )
3142*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(
3143*da0073e9SAndroid Build Coastguard Worker            (3, 6, 1), dtype=torch.float64, requires_grad=True, device=device
3144*da0073e9SAndroid Build Coastguard Worker        )
3145*da0073e9SAndroid Build Coastguard Worker
3146*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(a, b):
3147*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.as_nested_tensor([a, b])
3148*da0073e9SAndroid Build Coastguard Worker            result = nt.squeeze(-1)
3149*da0073e9SAndroid Build Coastguard Worker            return torch.nested.to_padded_tensor(result, 0.0)
3150*da0073e9SAndroid Build Coastguard Worker
3151*da0073e9SAndroid Build Coastguard Worker        assert torch.autograd.gradcheck(grad_test_func, inputs=(a, b), eps=1e-3)
3152*da0073e9SAndroid Build Coastguard Worker
3153*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_unsqueeze_backward(self, device):
3154*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor(
3155*da0073e9SAndroid Build Coastguard Worker            [torch.randn((2, 6)), torch.randn((3, 6))],
3156*da0073e9SAndroid Build Coastguard Worker            requires_grad=True,
3157*da0073e9SAndroid Build Coastguard Worker            device=device,
3158*da0073e9SAndroid Build Coastguard Worker        )
3159*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
3160*da0073e9SAndroid Build Coastguard Worker            pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)
3161*da0073e9SAndroid Build Coastguard Worker
3162*da0073e9SAndroid Build Coastguard Worker        ynt = nt.unsqueeze(2)
3163*da0073e9SAndroid Build Coastguard Worker        ypt = pt.unsqueeze(2)
3164*da0073e9SAndroid Build Coastguard Worker        ynt.backward(ynt.clone())
3165*da0073e9SAndroid Build Coastguard Worker        ypt.backward(ypt.clone())
3166*da0073e9SAndroid Build Coastguard Worker
3167*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)
3168*da0073e9SAndroid Build Coastguard Worker
3169*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_unsqueeze_gradcheck(self, device):
3170*da0073e9SAndroid Build Coastguard Worker        a = torch.randn((2, 6), dtype=torch.float64, requires_grad=True, device=device)
3171*da0073e9SAndroid Build Coastguard Worker        b = torch.randn((3, 6), dtype=torch.float64, requires_grad=True, device=device)
3172*da0073e9SAndroid Build Coastguard Worker
3173*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(a, b):
3174*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.as_nested_tensor([a, b])
3175*da0073e9SAndroid Build Coastguard Worker            result = nt.unsqueeze(-1)
3176*da0073e9SAndroid Build Coastguard Worker            return torch.nested.to_padded_tensor(result, 0.0)
3177*da0073e9SAndroid Build Coastguard Worker
3178*da0073e9SAndroid Build Coastguard Worker        assert torch.autograd.gradcheck(grad_test_func, inputs=(a, b), eps=1e-3)
3179*da0073e9SAndroid Build Coastguard Worker
3180*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_linear(self, device):
3181*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device)
3182*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device)
3183*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device)
3184*da0073e9SAndroid Build Coastguard Worker
3185*da0073e9SAndroid Build Coastguard Worker        weight = torch.randn(
3186*da0073e9SAndroid Build Coastguard Worker            2, 2, requires_grad=True, dtype=torch.float64, device=device
3187*da0073e9SAndroid Build Coastguard Worker        )
3188*da0073e9SAndroid Build Coastguard Worker        bias = torch.randn(2, requires_grad=True, dtype=torch.float64, device=device)
3189*da0073e9SAndroid Build Coastguard Worker
3190*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(a, b, c, weight, bias=None):
3191*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.as_nested_tensor([a, b, c])
3192*da0073e9SAndroid Build Coastguard Worker            # This implicitly tests to_padded_tensor grads
3193*da0073e9SAndroid Build Coastguard Worker            d = torch.functional.F.linear(nt, weight, bias)
3194*da0073e9SAndroid Build Coastguard Worker            return torch.nested.to_padded_tensor(d, 0)
3195*da0073e9SAndroid Build Coastguard Worker
3196*da0073e9SAndroid Build Coastguard Worker        data = (a, b, c, weight, bias)
3197*da0073e9SAndroid Build Coastguard Worker        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3198*da0073e9SAndroid Build Coastguard Worker
3199*da0073e9SAndroid Build Coastguard Worker        # Test linear with no bias added
3200*da0073e9SAndroid Build Coastguard Worker        data = (a, b, c, weight)
3201*da0073e9SAndroid Build Coastguard Worker        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3202*da0073e9SAndroid Build Coastguard Worker
3203*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_linear_plus_transpose(self, device):
3204*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device)
3205*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device)
3206*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device)
3207*da0073e9SAndroid Build Coastguard Worker
3208*da0073e9SAndroid Build Coastguard Worker        weight = torch.randn(
3209*da0073e9SAndroid Build Coastguard Worker            2, 2, requires_grad=True, dtype=torch.float64, device=device
3210*da0073e9SAndroid Build Coastguard Worker        )
3211*da0073e9SAndroid Build Coastguard Worker        bias = torch.randn(2, requires_grad=True, dtype=torch.float64, device=device)
3212*da0073e9SAndroid Build Coastguard Worker
3213*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(a, b, c, weight, bias=None):
3214*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.as_nested_tensor([a, b, c])
3215*da0073e9SAndroid Build Coastguard Worker            # This implicitly tests to_padded_tensor grads
3216*da0073e9SAndroid Build Coastguard Worker            d = torch.functional.F.linear(nt, weight, bias)
3217*da0073e9SAndroid Build Coastguard Worker            d = d.transpose(-1, -2).contiguous()
3218*da0073e9SAndroid Build Coastguard Worker            return torch.nested.to_padded_tensor(d, 0)
3219*da0073e9SAndroid Build Coastguard Worker
3220*da0073e9SAndroid Build Coastguard Worker        data = (a, b, c, weight, bias)
3221*da0073e9SAndroid Build Coastguard Worker        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3222*da0073e9SAndroid Build Coastguard Worker
3223*da0073e9SAndroid Build Coastguard Worker        # Test linear with no bias added
3224*da0073e9SAndroid Build Coastguard Worker        data = (a, b, c, weight)
3225*da0073e9SAndroid Build Coastguard Worker        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3226*da0073e9SAndroid Build Coastguard Worker
3227*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_softmax(self, device):
3228*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device)
3229*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device)
3230*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device)
3231*da0073e9SAndroid Build Coastguard Worker
3232*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(a, b, c, dim):
3233*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.as_nested_tensor([a, b, c])
3234*da0073e9SAndroid Build Coastguard Worker            # This implicitly tests to_padded_tensor grads
3235*da0073e9SAndroid Build Coastguard Worker            d = torch.functional.F.softmax(nt, dim=dim)
3236*da0073e9SAndroid Build Coastguard Worker            return torch.nested.to_padded_tensor(d, 0)
3237*da0073e9SAndroid Build Coastguard Worker
3238*da0073e9SAndroid Build Coastguard Worker        # softmax over last dim
3239*da0073e9SAndroid Build Coastguard Worker        data = (a, b, c, -1)
3240*da0073e9SAndroid Build Coastguard Worker        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3241*da0073e9SAndroid Build Coastguard Worker
3242*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_linear_backward(self, device):
3243*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(1, 2, requires_grad=False, device=device)
3244*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(2, 2, requires_grad=False, device=device)
3245*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(3, 2, requires_grad=False, device=device)
3246*da0073e9SAndroid Build Coastguard Worker
3247*da0073e9SAndroid Build Coastguard Worker        weight = torch.randn(2, 2, requires_grad=True, device=device)
3248*da0073e9SAndroid Build Coastguard Worker        bias = torch.randn(2, requires_grad=True, device=device)
3249*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.as_nested_tensor([a, b, c], device=device)
3250*da0073e9SAndroid Build Coastguard Worker
3251*da0073e9SAndroid Build Coastguard Worker        out = torch.functional.F.linear(nt, weight, bias)
3252*da0073e9SAndroid Build Coastguard Worker
3253*da0073e9SAndroid Build Coastguard Worker        out.backward(out.clone())
3254*da0073e9SAndroid Build Coastguard Worker
3255*da0073e9SAndroid Build Coastguard Worker        assert weight.grad is not None
3256*da0073e9SAndroid Build Coastguard Worker        assert bias.grad is not None
3257*da0073e9SAndroid Build Coastguard Worker
3258*da0073e9SAndroid Build Coastguard Worker        assert a.grad is None
3259*da0073e9SAndroid Build Coastguard Worker        assert b.grad is None
3260*da0073e9SAndroid Build Coastguard Worker        assert c.grad is None
3261*da0073e9SAndroid Build Coastguard Worker
3262*da0073e9SAndroid Build Coastguard Worker    def test_values_grad_with_broadcast(self, device):
3263*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3264*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3265*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3266*da0073e9SAndroid Build Coastguard Worker
3267*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(a, b, c):
3268*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.as_nested_tensor([a, b, c])
3269*da0073e9SAndroid Build Coastguard Worker            buffer = nt.values()
3270*da0073e9SAndroid Build Coastguard Worker            return buffer.sum()
3271*da0073e9SAndroid Build Coastguard Worker
3272*da0073e9SAndroid Build Coastguard Worker        data = (a, b, c)
3273*da0073e9SAndroid Build Coastguard Worker        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3274*da0073e9SAndroid Build Coastguard Worker
3275*da0073e9SAndroid Build Coastguard Worker    def test_to_buffer_series_ops_grad_with_broadcast(self, device):
3276*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device)
3277*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device)
3278*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device)
3279*da0073e9SAndroid Build Coastguard Worker
3280*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(a, b, c):
3281*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.as_nested_tensor([a, b, c])
3282*da0073e9SAndroid Build Coastguard Worker            buffer = nt.values()
3283*da0073e9SAndroid Build Coastguard Worker            buffer = buffer * 2
3284*da0073e9SAndroid Build Coastguard Worker            return buffer.exp()
3285*da0073e9SAndroid Build Coastguard Worker
3286*da0073e9SAndroid Build Coastguard Worker        data = (a, b, c)
3287*da0073e9SAndroid Build Coastguard Worker        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3288*da0073e9SAndroid Build Coastguard Worker
3289*da0073e9SAndroid Build Coastguard Worker    def test_unbind_flow_through(self, device):
3290*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3291*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3292*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3293*da0073e9SAndroid Build Coastguard Worker
3294*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(a, b, c):
3295*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.as_nested_tensor([a, b, c])
3296*da0073e9SAndroid Build Coastguard Worker            ntT = nt.transpose(-1, -2)
3297*da0073e9SAndroid Build Coastguard Worker            unbound = ntT.unbind()
3298*da0073e9SAndroid Build Coastguard Worker            d = unbound[0]
3299*da0073e9SAndroid Build Coastguard Worker            d = torch.pow(d, 2)
3300*da0073e9SAndroid Build Coastguard Worker            return d
3301*da0073e9SAndroid Build Coastguard Worker
3302*da0073e9SAndroid Build Coastguard Worker        data = (a, b, c)
3303*da0073e9SAndroid Build Coastguard Worker        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3304*da0073e9SAndroid Build Coastguard Worker
3305*da0073e9SAndroid Build Coastguard Worker    def test_split_with_sizes_flow_through(self, device):
3306*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 5, requires_grad=True, dtype=torch.float64, device=device)
3307*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(3, 5, requires_grad=True, dtype=torch.float64, device=device)
3308*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(4, 5, requires_grad=True, dtype=torch.float64, device=device)
3309*da0073e9SAndroid Build Coastguard Worker
3310*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(a, b, c):
3311*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.as_nested_tensor([a, b, c])
3312*da0073e9SAndroid Build Coastguard Worker            splits = nt.split_with_sizes([2, 3], dim=-1)
3313*da0073e9SAndroid Build Coastguard Worker            unbound = splits[1].unbind()
3314*da0073e9SAndroid Build Coastguard Worker            d = unbound[0]
3315*da0073e9SAndroid Build Coastguard Worker            d = torch.pow(d, 2)
3316*da0073e9SAndroid Build Coastguard Worker            return d
3317*da0073e9SAndroid Build Coastguard Worker
3318*da0073e9SAndroid Build Coastguard Worker        data = (a, b, c)
3319*da0073e9SAndroid Build Coastguard Worker        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3320*da0073e9SAndroid Build Coastguard Worker
3321*da0073e9SAndroid Build Coastguard Worker    def test_indexing_backward(self, device):
3322*da0073e9SAndroid Build Coastguard Worker        x0 = torch.randn((2, 5))
3323*da0073e9SAndroid Build Coastguard Worker        x1 = torch.randn((3, 4))
3324*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor([x0, x1], device=device, requires_grad=True)
3325*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt[0], x0)
3326*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt[-1], x1)
3327*da0073e9SAndroid Build Coastguard Worker        grad_x0 = torch.randn((2, 5), device=device)
3328*da0073e9SAndroid Build Coastguard Worker        nt[0].backward(grad_x0)
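        # Only the indexed constituent receives gradient; the other entry's grad stays zero.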
3329*da0073e9SAndroid Build Coastguard Worker        expected_grad = torch.nested.nested_tensor(
3330*da0073e9SAndroid Build Coastguard Worker            [grad_x0, torch.zeros((3, 4), device=device)]
3331*da0073e9SAndroid Build Coastguard Worker        )
3332*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt.grad, expected_grad)
3333*da0073e9SAndroid Build Coastguard Worker
3334*da0073e9SAndroid Build Coastguard Worker    def test_masked_fill_backward(self, device):
3335*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3336*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3337*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3338*da0073e9SAndroid Build Coastguard Worker
3339*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(a, b, c):
3340*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.as_nested_tensor([a, b, c])
3341*da0073e9SAndroid Build Coastguard Worker            mask = nt.detach().clone().to(bool)
3342*da0073e9SAndroid Build Coastguard Worker            out = nt.masked_fill(mask, 0)
3343*da0073e9SAndroid Build Coastguard Worker            out = torch.nested.to_padded_tensor(out, 0)
3344*da0073e9SAndroid Build Coastguard Worker            return out
3345*da0073e9SAndroid Build Coastguard Worker
3346*da0073e9SAndroid Build Coastguard Worker        data = (a, b, c)
3347*da0073e9SAndroid Build Coastguard Worker        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3348*da0073e9SAndroid Build Coastguard Worker
3349*da0073e9SAndroid Build Coastguard Worker    def test_gelu_backward(self, device):
3350*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3351*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3352*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3353*da0073e9SAndroid Build Coastguard Worker
3354*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(a, b, c):
3355*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.as_nested_tensor([a, b, c])
3356*da0073e9SAndroid Build Coastguard Worker            nt_gelu = torch.nn.functional.gelu(nt)
3357*da0073e9SAndroid Build Coastguard Worker            return torch.nested.to_padded_tensor(nt_gelu, 0)
3358*da0073e9SAndroid Build Coastguard Worker
3359*da0073e9SAndroid Build Coastguard Worker        data = (a, b, c)
3360*da0073e9SAndroid Build Coastguard Worker        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3361*da0073e9SAndroid Build Coastguard Worker
3362*da0073e9SAndroid Build Coastguard Worker    def test_relu_backward(self, device):
3363*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3364*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3365*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3366*da0073e9SAndroid Build Coastguard Worker
3367*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(a, b, c):
3368*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.as_nested_tensor([a, b, c])
3369*da0073e9SAndroid Build Coastguard Worker            nt_relu = torch.nn.functional.relu(nt)
3370*da0073e9SAndroid Build Coastguard Worker            return torch.nested.to_padded_tensor(nt_relu, 0)
3371*da0073e9SAndroid Build Coastguard Worker
3372*da0073e9SAndroid Build Coastguard Worker        data = (a, b, c)
3373*da0073e9SAndroid Build Coastguard Worker        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3374*da0073e9SAndroid Build Coastguard Worker
3375*da0073e9SAndroid Build Coastguard Worker    def test_selu_backward(self, device):
3376*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3377*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3378*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3379*da0073e9SAndroid Build Coastguard Worker
3380*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(a, b, c):
3381*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.as_nested_tensor([a, b, c])
3382*da0073e9SAndroid Build Coastguard Worker            nt_silu = torch.nn.functional.silu(nt)
3383*da0073e9SAndroid Build Coastguard Worker            return torch.nested.to_padded_tensor(nt_silu, 0)
3384*da0073e9SAndroid Build Coastguard Worker
3385*da0073e9SAndroid Build Coastguard Worker        data = (a, b, c)
3386*da0073e9SAndroid Build Coastguard Worker        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3387*da0073e9SAndroid Build Coastguard Worker
3388*da0073e9SAndroid Build Coastguard Worker    def test_abs_backward(self, device):
3389*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3390*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3391*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
3392*da0073e9SAndroid Build Coastguard Worker
3393*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(a, b, c):
3394*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.as_nested_tensor([a, b, c])
3395*da0073e9SAndroid Build Coastguard Worker            nt_abs = torch.abs(nt)
3396*da0073e9SAndroid Build Coastguard Worker            return torch.nested.to_padded_tensor(nt_abs, 0)
3397*da0073e9SAndroid Build Coastguard Worker
3398*da0073e9SAndroid Build Coastguard Worker        data = (a, b, c)
3399*da0073e9SAndroid Build Coastguard Worker        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3400*da0073e9SAndroid Build Coastguard Worker
3401*da0073e9SAndroid Build Coastguard Worker    # Previously, this would error when the input NT doesn't require grad:
3402*da0073e9SAndroid Build Coastguard Worker    # NotImplementedError: Cannot access storage of UndefinedTensorImpl
3403*da0073e9SAndroid Build Coastguard Worker    def test_layer_norm_backward_edge_case(self, device):
3404*da0073e9SAndroid Build Coastguard Worker        size = 4
3405*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(
3406*da0073e9SAndroid Build Coastguard Worker            1, 2, size, requires_grad=False, dtype=torch.float64, device=device
3407*da0073e9SAndroid Build Coastguard Worker        )
3408*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor([a])
3409*da0073e9SAndroid Build Coastguard Worker        nt_layer_norm = torch.nn.LayerNorm(
3410*da0073e9SAndroid Build Coastguard Worker            nt.size(-1), device=device, dtype=torch.float64
3411*da0073e9SAndroid Build Coastguard Worker        )
3412*da0073e9SAndroid Build Coastguard Worker        out = nt_layer_norm(nt)
3413*da0073e9SAndroid Build Coastguard Worker        out.backward(out.clone())
3414*da0073e9SAndroid Build Coastguard Worker
3415*da0073e9SAndroid Build Coastguard Worker    def test_accumulate_grad_different_strides(self, device):
3416*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(1, 4, 2, requires_grad=True, dtype=torch.float64, device=device)
3417*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(1, 8, 2, requires_grad=True, dtype=torch.float64, device=device)
3418*da0073e9SAndroid Build Coastguard Worker
3419*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(a, b):
3420*da0073e9SAndroid Build Coastguard Worker            nt_1 = torch.nested.as_nested_tensor([a, b])
3421*da0073e9SAndroid Build Coastguard Worker            nt_2 = nt_1.clone()
3422*da0073e9SAndroid Build Coastguard Worker            out = torch.nn.functional.scaled_dot_product_attention(nt_1, nt_2, nt_2)
3423*da0073e9SAndroid Build Coastguard Worker            return torch.nested.to_padded_tensor(out, 0)
3424*da0073e9SAndroid Build Coastguard Worker
3425*da0073e9SAndroid Build Coastguard Worker        data = (a, b)
3426*da0073e9SAndroid Build Coastguard Worker        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3427*da0073e9SAndroid Build Coastguard Worker
3428*da0073e9SAndroid Build Coastguard Worker    # https://github.com/pytorch/pytorch/issues/95562
3429*da0073e9SAndroid Build Coastguard Worker    @skipIfSlowGradcheckEnv
3430*da0073e9SAndroid Build Coastguard Worker    @parametrize("size", [1024, 1023, 513, 512, 256, 128, 32, 4, 2])
3431*da0073e9SAndroid Build Coastguard Worker    def test_layer_norm_backward(self, device, size):
3432*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(
3433*da0073e9SAndroid Build Coastguard Worker            1, 2, size, requires_grad=True, dtype=torch.float64, device=device
3434*da0073e9SAndroid Build Coastguard Worker        )
3435*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(
3436*da0073e9SAndroid Build Coastguard Worker            2, 2, size, requires_grad=True, dtype=torch.float64, device=device
3437*da0073e9SAndroid Build Coastguard Worker        )
3438*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(
3439*da0073e9SAndroid Build Coastguard Worker            3, 2, size, requires_grad=True, dtype=torch.float64, device=device
3440*da0073e9SAndroid Build Coastguard Worker        )
3441*da0073e9SAndroid Build Coastguard Worker
3442*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(a, b, c):
3443*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.as_nested_tensor([a, b, c])
3444*da0073e9SAndroid Build Coastguard Worker            layer_norm = torch.nn.LayerNorm(
3445*da0073e9SAndroid Build Coastguard Worker                nt.size(-1), device=device, dtype=torch.float64
3446*da0073e9SAndroid Build Coastguard Worker            )
3447*da0073e9SAndroid Build Coastguard Worker            nt_layer_norm = layer_norm(nt)
3448*da0073e9SAndroid Build Coastguard Worker            return torch.nested.to_padded_tensor(nt_layer_norm, 0)
3449*da0073e9SAndroid Build Coastguard Worker
3450*da0073e9SAndroid Build Coastguard Worker        data = (a, b, c)
3451*da0073e9SAndroid Build Coastguard Worker        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3452*da0073e9SAndroid Build Coastguard Worker
3453*da0073e9SAndroid Build Coastguard Worker    # https://github.com/pytorch/pytorch/issues/95562
3454*da0073e9SAndroid Build Coastguard Worker    @skipIfSlowGradcheckEnv
3455*da0073e9SAndroid Build Coastguard Worker    # Could either mark this as slow or reduce the size
3456*da0073e9SAndroid Build Coastguard Worker    @parametrize("size", [128, 32, 4, 2])
3457*da0073e9SAndroid Build Coastguard Worker    def test_layer_norm_backward_5d(self, device, size):
3458*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(
3459*da0073e9SAndroid Build Coastguard Worker            4, size, size, 4, requires_grad=True, dtype=torch.float64, device=device
3460*da0073e9SAndroid Build Coastguard Worker        )
3461*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(
3462*da0073e9SAndroid Build Coastguard Worker            7, size, size, 4, requires_grad=True, dtype=torch.float64, device=device
3463*da0073e9SAndroid Build Coastguard Worker        )
3464*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(
3465*da0073e9SAndroid Build Coastguard Worker            10, size, size, 4, requires_grad=True, dtype=torch.float64, device=device
3466*da0073e9SAndroid Build Coastguard Worker        )
3467*da0073e9SAndroid Build Coastguard Worker
3468*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(a, b, c):
3469*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.as_nested_tensor([a, b, c])
3470*da0073e9SAndroid Build Coastguard Worker            layer_norm = torch.nn.LayerNorm(
3471*da0073e9SAndroid Build Coastguard Worker                (size, size, nt.size(-1)), device=device, dtype=torch.float64
3472*da0073e9SAndroid Build Coastguard Worker            )
3473*da0073e9SAndroid Build Coastguard Worker            nt_layer_norm = layer_norm(nt)
3474*da0073e9SAndroid Build Coastguard Worker            return torch.nested.to_padded_tensor(nt_layer_norm, 0)
3475*da0073e9SAndroid Build Coastguard Worker
3476*da0073e9SAndroid Build Coastguard Worker        data = (a, b, c)
3477*da0073e9SAndroid Build Coastguard Worker        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
3478*da0073e9SAndroid Build Coastguard Worker
3479*da0073e9SAndroid Build Coastguard Worker
3480*da0073e9SAndroid Build Coastguard Worker# Found in torch/testing/_comparison.py
3481*da0073e9SAndroid Build Coastguard Workerdefault_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float32: 1e-5}
3482*da0073e9SAndroid Build Coastguard Workerdefault_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float32: 1.3e-6}
3483*da0073e9SAndroid Build Coastguard Worker
3484*da0073e9SAndroid Build Coastguard Worker
3485*da0073e9SAndroid Build Coastguard Workerdef get_rtol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
3486*da0073e9SAndroid Build Coastguard Worker    deviation = true_value - computed_value
3487*da0073e9SAndroid Build Coastguard Worker    deviation = torch.abs(deviation / true_value)
3488*da0073e9SAndroid Build Coastguard Worker    # Fill in the nans with the default rtol
3489*da0073e9SAndroid Build Coastguard Worker    torch.nan_to_num_(deviation, nan=default_rtol[computed_value.dtype])
3490*da0073e9SAndroid Build Coastguard Worker    return deviation.max().item()
3491*da0073e9SAndroid Build Coastguard Worker
3492*da0073e9SAndroid Build Coastguard Worker
3493*da0073e9SAndroid Build Coastguard Workerdef get_atol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
3494*da0073e9SAndroid Build Coastguard Worker    deviation = true_value - computed_value
3495*da0073e9SAndroid Build Coastguard Worker    atol = torch.abs(deviation).max().item()
3496*da0073e9SAndroid Build Coastguard Worker    return atol
3497*da0073e9SAndroid Build Coastguard Worker
3498*da0073e9SAndroid Build Coastguard Worker
3499*da0073e9SAndroid Build Coastguard Workerdef get_tolerances(
3500*da0073e9SAndroid Build Coastguard Worker    true_value: torch.Tensor,
3501*da0073e9SAndroid Build Coastguard Worker    computed_value: torch.Tensor,
3502*da0073e9SAndroid Build Coastguard Worker    fudge_factor: Optional[float] = None,
3503*da0073e9SAndroid Build Coastguard Worker) -> Tuple[float, float]:
3504*da0073e9SAndroid Build Coastguard Worker    """Returns the absolute and relative tolerances for comparing two tensors."""
3505*da0073e9SAndroid Build Coastguard Worker    fudge_factor = fudge_factor if fudge_factor is not None else 1.0
3506*da0073e9SAndroid Build Coastguard Worker    atol = get_atol(true_value, computed_value)
3507*da0073e9SAndroid Build Coastguard Worker    rtol = get_rtol(true_value, computed_value)
3508*da0073e9SAndroid Build Coastguard Worker
3509*da0073e9SAndroid Build Coastguard Worker    atol = fudge_factor * max(atol, default_atol[computed_value.dtype])
3510*da0073e9SAndroid Build Coastguard Worker    rtol = fudge_factor * max(rtol, default_rtol[computed_value.dtype])
3511*da0073e9SAndroid Build Coastguard Worker    # torch.isclose() has weird behavior with very large rtol values; see:
3512*da0073e9SAndroid Build Coastguard Worker    # https://github.com/pytorch/pytorch/issues/102400
3513*da0073e9SAndroid Build Coastguard Worker    if rtol > 1e30:
3514*da0073e9SAndroid Build Coastguard Worker        rtol = default_rtol[computed_value.dtype]
3515*da0073e9SAndroid Build Coastguard Worker    return atol, rtol
3516*da0073e9SAndroid Build Coastguard Worker
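# Illustrative usage of the tolerance helpers above (a sketch, not executed by any
# test below; `ref_out`, `lowp_out`, and the fudge factor are hypothetical
# placeholders, not values used elsewhere in this file):
#
#   atol, rtol = get_tolerances(ref_out, lowp_out, fudge_factor=4.0)
#   torch.testing.assert_close(lowp_out, ref_out, atol=atol, rtol=rtol)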
3517*da0073e9SAndroid Build Coastguard Worker
3518*da0073e9SAndroid Build Coastguard Worker# We can probably parametrize existing tests instead of having a separate
3519*da0073e9SAndroid Build Coastguard Worker# test class as we begin to support more ops. Also maybe rewrite with OpInfos.
3520*da0073e9SAndroid Build Coastguard Worker@markDynamoStrictTest
3521*da0073e9SAndroid Build Coastguard Workerclass TestNestedTensorSubclass(NestedTensorTestCase):
3522*da0073e9SAndroid Build Coastguard Worker    # TODO: consolidate with the below
3523*da0073e9SAndroid Build Coastguard Worker    def _get_list_for_jagged_tensor(self, nested_size, device, requires_grad=True):
3524*da0073e9SAndroid Build Coastguard Worker        Ds = nested_size[1:]
3525*da0073e9SAndroid Build Coastguard Worker        out = []
3526*da0073e9SAndroid Build Coastguard Worker        for s in nested_size[0]:
3527*da0073e9SAndroid Build Coastguard Worker            out.append(
3528*da0073e9SAndroid Build Coastguard Worker                torch.randn(
3529*da0073e9SAndroid Build Coastguard Worker                    s,
3530*da0073e9SAndroid Build Coastguard Worker                    *Ds,
3531*da0073e9SAndroid Build Coastguard Worker                    requires_grad=requires_grad,
3532*da0073e9SAndroid Build Coastguard Worker                    device=device,
3533*da0073e9SAndroid Build Coastguard Worker                    dtype=torch.float64,
3534*da0073e9SAndroid Build Coastguard Worker                )
3535*da0073e9SAndroid Build Coastguard Worker            )
3536*da0073e9SAndroid Build Coastguard Worker        return out
3537*da0073e9SAndroid Build Coastguard Worker
3538*da0073e9SAndroid Build Coastguard Worker    def _get_example_tensor_lists(
3539*da0073e9SAndroid Build Coastguard Worker        self,
3540*da0073e9SAndroid Build Coastguard Worker        include_list_of_lists=True,
3541*da0073e9SAndroid Build Coastguard Worker        include_requires_grad=True,
3542*da0073e9SAndroid Build Coastguard Worker        include_inner_dim_size_1=False,
3543*da0073e9SAndroid Build Coastguard Worker        include_2d_tensor=False,
3544*da0073e9SAndroid Build Coastguard Worker    ):
3545*da0073e9SAndroid Build Coastguard Worker        def _make_tensor(
3546*da0073e9SAndroid Build Coastguard Worker            *shape, include_requires_grad=include_requires_grad, requires_grad=True
3547*da0073e9SAndroid Build Coastguard Worker        ):
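            # requires_grad on an individual tensor is only honored when
            # include_requires_grad=True; otherwise every tensor is created with
            # requires_grad=False.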
3548*da0073e9SAndroid Build Coastguard Worker            return torch.randn(
3549*da0073e9SAndroid Build Coastguard Worker                *shape,
3550*da0073e9SAndroid Build Coastguard Worker                requires_grad=(requires_grad if include_requires_grad else False),
3551*da0073e9SAndroid Build Coastguard Worker            )
3552*da0073e9SAndroid Build Coastguard Worker
3553*da0073e9SAndroid Build Coastguard Worker        # Purposefully introduce mixed requires_grad settings for the components
3554*da0073e9SAndroid Build Coastguard Worker        # when include_requires_grad=True.
3555*da0073e9SAndroid Build Coastguard Worker        example_lists = [
3556*da0073e9SAndroid Build Coastguard Worker            # (B, *, D) with B=4
3557*da0073e9SAndroid Build Coastguard Worker            [
3558*da0073e9SAndroid Build Coastguard Worker                _make_tensor(2, 5),
3559*da0073e9SAndroid Build Coastguard Worker                _make_tensor(3, 5, requires_grad=False),
3560*da0073e9SAndroid Build Coastguard Worker                _make_tensor(4, 5, requires_grad=False),
3561*da0073e9SAndroid Build Coastguard Worker                _make_tensor(6, 5),
3562*da0073e9SAndroid Build Coastguard Worker            ],
3563*da0073e9SAndroid Build Coastguard Worker            # (B, *, D_0, D_1) with B=5
3564*da0073e9SAndroid Build Coastguard Worker            [
3565*da0073e9SAndroid Build Coastguard Worker                _make_tensor(2, 5, 6),
3566*da0073e9SAndroid Build Coastguard Worker                _make_tensor(3, 5, 6),
3567*da0073e9SAndroid Build Coastguard Worker                _make_tensor(4, 5, 6, requires_grad=False),
3568*da0073e9SAndroid Build Coastguard Worker                _make_tensor(5, 5, 6),
3569*da0073e9SAndroid Build Coastguard Worker                _make_tensor(6, 5, 6),
3570*da0073e9SAndroid Build Coastguard Worker            ],
3571*da0073e9SAndroid Build Coastguard Worker            # (B, *, D_0, D_1, D_2) with B=6
3572*da0073e9SAndroid Build Coastguard Worker            [
3573*da0073e9SAndroid Build Coastguard Worker                _make_tensor(2, 5, 6, 7),
3574*da0073e9SAndroid Build Coastguard Worker                _make_tensor(3, 5, 6, 7),
3575*da0073e9SAndroid Build Coastguard Worker                _make_tensor(4, 5, 6, 7, requires_grad=False),
3576*da0073e9SAndroid Build Coastguard Worker                _make_tensor(5, 5, 6, 7),
3577*da0073e9SAndroid Build Coastguard Worker                _make_tensor(6, 5, 6, 7),
3578*da0073e9SAndroid Build Coastguard Worker                _make_tensor(7, 5, 6, 7),
3579*da0073e9SAndroid Build Coastguard Worker            ],
3580*da0073e9SAndroid Build Coastguard Worker        ]
3581*da0073e9SAndroid Build Coastguard Worker
3582*da0073e9SAndroid Build Coastguard Worker        if include_list_of_lists:
3583*da0073e9SAndroid Build Coastguard Worker            example_lists.append(
3584*da0073e9SAndroid Build Coastguard Worker                # (B, *, D) with B=3 in list form
3585*da0073e9SAndroid Build Coastguard Worker                [
3586*da0073e9SAndroid Build Coastguard Worker                    _make_tensor(2, 5, requires_grad=False).tolist(),
3587*da0073e9SAndroid Build Coastguard Worker                    _make_tensor(3, 5).tolist(),
3588*da0073e9SAndroid Build Coastguard Worker                    _make_tensor(4, 5).tolist(),
3589*da0073e9SAndroid Build Coastguard Worker                ]
3590*da0073e9SAndroid Build Coastguard Worker            )
3591*da0073e9SAndroid Build Coastguard Worker
3592*da0073e9SAndroid Build Coastguard Worker        if include_inner_dim_size_1:
3593*da0073e9SAndroid Build Coastguard Worker            example_lists.append(
3594*da0073e9SAndroid Build Coastguard Worker                [
3595*da0073e9SAndroid Build Coastguard Worker                    _make_tensor(2, 1),
3596*da0073e9SAndroid Build Coastguard Worker                    _make_tensor(3, 1, requires_grad=False),
3597*da0073e9SAndroid Build Coastguard Worker                    _make_tensor(4, 1, requires_grad=False),
3598*da0073e9SAndroid Build Coastguard Worker                    _make_tensor(6, 1),
3599*da0073e9SAndroid Build Coastguard Worker                ]  # (B, *, 1)
3600*da0073e9SAndroid Build Coastguard Worker            )
3601*da0073e9SAndroid Build Coastguard Worker            example_lists.append(
3602*da0073e9SAndroid Build Coastguard Worker                [
3603*da0073e9SAndroid Build Coastguard Worker                    _make_tensor(2, 5, 1),
3604*da0073e9SAndroid Build Coastguard Worker                    _make_tensor(3, 5, 1, requires_grad=False),
3605*da0073e9SAndroid Build Coastguard Worker                    _make_tensor(4, 5, 1, requires_grad=False),
3606*da0073e9SAndroid Build Coastguard Worker                    _make_tensor(6, 5, 1),
3607*da0073e9SAndroid Build Coastguard Worker                ]  # (B, *, 5, 1)
3608*da0073e9SAndroid Build Coastguard Worker            )
3609*da0073e9SAndroid Build Coastguard Worker
3610*da0073e9SAndroid Build Coastguard Worker        if include_2d_tensor:
3611*da0073e9SAndroid Build Coastguard Worker            example_lists.append(
3612*da0073e9SAndroid Build Coastguard Worker                [
3613*da0073e9SAndroid Build Coastguard Worker                    _make_tensor(2),
3614*da0073e9SAndroid Build Coastguard Worker                    _make_tensor(3, requires_grad=False),
3615*da0073e9SAndroid Build Coastguard Worker                    _make_tensor(4, requires_grad=False),
3616*da0073e9SAndroid Build Coastguard Worker                    _make_tensor(6),
3617*da0073e9SAndroid Build Coastguard Worker                ]  # (B, *)
3618*da0073e9SAndroid Build Coastguard Worker            )
3619*da0073e9SAndroid Build Coastguard Worker
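        # Each inner list above corresponds to one nested tensor; e.g. the first
        # (B, *, D) entry would yield a jagged NT of shape (4, ragged, 5) if built
        # via torch.nested.nested_tensor(tensor_list, layout=torch.jagged).
        # (Illustrative only; the consuming tests choose the construction themselves.)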
3620*da0073e9SAndroid Build Coastguard Worker        return example_lists
3621*da0073e9SAndroid Build Coastguard Worker
3622*da0073e9SAndroid Build Coastguard Worker    def test_tensor_attributes(self, device):
3623*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
3624*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
3625*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
3626*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
3627*da0073e9SAndroid Build Coastguard Worker        _offsets = nt.offsets()
3628*da0073e9SAndroid Build Coastguard Worker
3629*da0073e9SAndroid Build Coastguard Worker        for op in (
3630*da0073e9SAndroid Build Coastguard Worker            torch.ops.aten.is_non_overlapping_and_dense.default,
3631*da0073e9SAndroid Build Coastguard Worker            torch.ops.aten.sym_size.default,
3632*da0073e9SAndroid Build Coastguard Worker            torch.ops.aten.dim.default,
3633*da0073e9SAndroid Build Coastguard Worker            torch.ops.aten.numel.default,
3634*da0073e9SAndroid Build Coastguard Worker            torch.ops.aten.sym_numel.default,
3635*da0073e9SAndroid Build Coastguard Worker            torch.ops.aten.sym_stride.default,
3636*da0073e9SAndroid Build Coastguard Worker            torch.ops.aten.sym_storage_offset.default,
3637*da0073e9SAndroid Build Coastguard Worker        ):
3638*da0073e9SAndroid Build Coastguard Worker            op(nt)
3639*da0073e9SAndroid Build Coastguard Worker
3640*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
3641*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "directly calling torch.ops.aten.size"
3642*da0073e9SAndroid Build Coastguard Worker        ):
3643*da0073e9SAndroid Build Coastguard Worker            torch.ops.aten.size.default(nt)
3644*da0073e9SAndroid Build Coastguard Worker
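        # The ragged dim is represented by a nested int (a symbolic size) tied to the
        # offsets tensor, so it can be compared against nt.size() / nt.shape below.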
3645*da0073e9SAndroid Build Coastguard Worker        nested_int = torch.nested._internal.nested_tensor.get_tensor_symint(
3646*da0073e9SAndroid Build Coastguard Worker            _offsets, coeff=1
3647*da0073e9SAndroid Build Coastguard Worker        )
3648*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt.size(), (3, nested_int, 3))
3649*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt.shape, (3, nested_int, 3))
3650*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt.dim(), 3)
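        # total elements across components: 2*3 + 3*3 + 4*3 = 27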
3651*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt.numel(), 27)
3652*da0073e9SAndroid Build Coastguard Worker
3653*da0073e9SAndroid Build Coastguard Worker    @parametrize("nt_dim", [3, 4, 5])
3654*da0073e9SAndroid Build Coastguard Worker    def test_linear(self, device, nt_dim):
3655*da0073e9SAndroid Build Coastguard Worker        if nt_dim == 3:
3656*da0073e9SAndroid Build Coastguard Worker            fixed_shape = (3,)
3657*da0073e9SAndroid Build Coastguard Worker        elif nt_dim == 4:
3658*da0073e9SAndroid Build Coastguard Worker            fixed_shape = (4, 3)
3659*da0073e9SAndroid Build Coastguard Worker        elif nt_dim == 5:
3660*da0073e9SAndroid Build Coastguard Worker            fixed_shape = (5, 4, 3)
3661*da0073e9SAndroid Build Coastguard Worker
3662*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(
3663*da0073e9SAndroid Build Coastguard Worker            2, *fixed_shape, requires_grad=True, dtype=torch.float64, device=device
3664*da0073e9SAndroid Build Coastguard Worker        )
3665*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(
3666*da0073e9SAndroid Build Coastguard Worker            3, *fixed_shape, requires_grad=True, dtype=torch.float64, device=device
3667*da0073e9SAndroid Build Coastguard Worker        )
3668*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(
3669*da0073e9SAndroid Build Coastguard Worker            4, *fixed_shape, requires_grad=True, dtype=torch.float64, device=device
3670*da0073e9SAndroid Build Coastguard Worker        )
3671*da0073e9SAndroid Build Coastguard Worker        weight = torch.randn(
3672*da0073e9SAndroid Build Coastguard Worker            4, 3, requires_grad=True, dtype=torch.float64, device=device
3673*da0073e9SAndroid Build Coastguard Worker        )
3674*da0073e9SAndroid Build Coastguard Worker
3675*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(a, b, c, weight):
3676*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
3677*da0073e9SAndroid Build Coastguard Worker            out = torch.nn.functional.linear(nt, weight)
3678*da0073e9SAndroid Build Coastguard Worker            return out.values()
3679*da0073e9SAndroid Build Coastguard Worker
3680*da0073e9SAndroid Build Coastguard Worker        gradcheck(grad_test_func, inputs=(a, b, c, weight), check_batched_grad=False)
3681*da0073e9SAndroid Build Coastguard Worker
3682*da0073e9SAndroid Build Coastguard Worker    def test_unary_pointwise(self, device):
3683*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
3684*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
3685*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
3686*da0073e9SAndroid Build Coastguard Worker
3687*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(a, b, c):
3688*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
3689*da0073e9SAndroid Build Coastguard Worker            out = torch.nn.functional.silu(nt.sin().cos())
3690*da0073e9SAndroid Build Coastguard Worker            return out.values()
3691*da0073e9SAndroid Build Coastguard Worker
3692*da0073e9SAndroid Build Coastguard Worker        gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)
3693*da0073e9SAndroid Build Coastguard Worker
3694*da0073e9SAndroid Build Coastguard Worker    def test_unary_pointwise_transposed_inputs(self, device):
3695*da0073e9SAndroid Build Coastguard Worker        a, b, c = (
3696*da0073e9SAndroid Build Coastguard Worker            torch.randn(
3697*da0073e9SAndroid Build Coastguard Worker                i + 2, 5, requires_grad=True, dtype=torch.float64, device=device
3698*da0073e9SAndroid Build Coastguard Worker            )
3699*da0073e9SAndroid Build Coastguard Worker            for i in range(3)
3700*da0073e9SAndroid Build Coastguard Worker        )
3701*da0073e9SAndroid Build Coastguard Worker
3702*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor(
3703*da0073e9SAndroid Build Coastguard Worker            [a.detach(), b.detach(), c.detach()], layout=torch.jagged
3704*da0073e9SAndroid Build Coastguard Worker        )
3705*da0073e9SAndroid Build Coastguard Worker        nt_t = nt.transpose(1, 2)
3706*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(nt_t.is_contiguous())
3707*da0073e9SAndroid Build Coastguard Worker        out = torch.nn.functional.silu(nt_t.sin().cos())
3708*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3709*da0073e9SAndroid Build Coastguard Worker            out.is_contiguous(),
3710*da0073e9SAndroid Build Coastguard Worker            torch.nn.functional.silu(b.transpose(-1, -2).sin().cos()).is_contiguous(),
3711*da0073e9SAndroid Build Coastguard Worker        )
3712*da0073e9SAndroid Build Coastguard Worker
3713*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt_t.shape, out.shape)
3714*da0073e9SAndroid Build Coastguard Worker
3715*da0073e9SAndroid Build Coastguard Worker        a, b, c = (
3716*da0073e9SAndroid Build Coastguard Worker            torch.randn(
3717*da0073e9SAndroid Build Coastguard Worker                i + 2, 5, requires_grad=True, dtype=torch.float64, device=device
3718*da0073e9SAndroid Build Coastguard Worker            )
3719*da0073e9SAndroid Build Coastguard Worker            for i in range(3)
3720*da0073e9SAndroid Build Coastguard Worker        )
3721*da0073e9SAndroid Build Coastguard Worker
3722*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(a, b, c):
3723*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
3724*da0073e9SAndroid Build Coastguard Worker            nt_t = nt.transpose(1, 2)
3725*da0073e9SAndroid Build Coastguard Worker            out = torch.nn.functional.silu(nt_t.sin().cos())
3726*da0073e9SAndroid Build Coastguard Worker            return out.values()
3727*da0073e9SAndroid Build Coastguard Worker
3728*da0073e9SAndroid Build Coastguard Worker        gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)
3729*da0073e9SAndroid Build Coastguard Worker
3730*da0073e9SAndroid Build Coastguard Worker    def test_binary_pointwise(self, device):
3731*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
3732*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
3733*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
3734*da0073e9SAndroid Build Coastguard Worker
3735*da0073e9SAndroid Build Coastguard Worker        # Incorrect usage: the shape check will fail if the offsets tensors are not
3736*da0073e9SAndroid Build Coastguard Worker        #                  the same exact tensor object
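        # (each as_nested_tensor call creates fresh offsets, so the two NJTs carry
        # distinct ragged-dim symbols and their shapes don't match)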
3737*da0073e9SAndroid Build Coastguard Worker        nt1 = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
3738*da0073e9SAndroid Build Coastguard Worker        nt2 = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
3739*da0073e9SAndroid Build Coastguard Worker
3740*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
3741*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
3742*da0073e9SAndroid Build Coastguard Worker            "cannot call binary pointwise function .* with inputs of shapes",
3743*da0073e9SAndroid Build Coastguard Worker            lambda: nt1 * nt2,
3744*da0073e9SAndroid Build Coastguard Worker        )
3745*da0073e9SAndroid Build Coastguard Worker
3746*da0073e9SAndroid Build Coastguard Worker        # Correct usage: chain the calls using the same offsets tensor object
3747*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(a, b, c):
3748*da0073e9SAndroid Build Coastguard Worker            nt1 = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
3749*da0073e9SAndroid Build Coastguard Worker            # TODO: Switch to public API that takes in (values, offsets) once it exists
3750*da0073e9SAndroid Build Coastguard Worker            nt2, offsets = jagged_from_list([a, b, c], nt1.offsets())
3751*da0073e9SAndroid Build Coastguard Worker            out = nt1 * nt2
3752*da0073e9SAndroid Build Coastguard Worker            return out.values()
3753*da0073e9SAndroid Build Coastguard Worker
3754*da0073e9SAndroid Build Coastguard Worker        gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)
3755*da0073e9SAndroid Build Coastguard Worker
3756*da0073e9SAndroid Build Coastguard Worker    def test_binary_pointwise_transposed(self, device):
3757*da0073e9SAndroid Build Coastguard Worker        a, b, c = (
3758*da0073e9SAndroid Build Coastguard Worker            torch.randn(i + 2, 5, dtype=torch.float64, device=device) for i in range(3)
3759*da0073e9SAndroid Build Coastguard Worker        )
3760*da0073e9SAndroid Build Coastguard Worker
3761*da0073e9SAndroid Build Coastguard Worker        nt1, offsets = jagged_from_list([a, b, c], None)
3762*da0073e9SAndroid Build Coastguard Worker        nt2, offsets = jagged_from_list([a, b, c], offsets)
3763*da0073e9SAndroid Build Coastguard Worker
3764*da0073e9SAndroid Build Coastguard Worker        nt1_t = nt1.transpose(1, 2)
3765*da0073e9SAndroid Build Coastguard Worker        nt2_t = nt2.transpose(1, 2)
3766*da0073e9SAndroid Build Coastguard Worker
3767*da0073e9SAndroid Build Coastguard Worker        # out = nt1_t * nt2_t
3768*da0073e9SAndroid Build Coastguard Worker        # self.assertFalse(nt1_t.is_contiguous())
3769*da0073e9SAndroid Build Coastguard Worker        # self.assertEqual(out.is_contiguous(), (b.transpose(-1, -2) * b.transpose(-1, -2)).is_contiguous())
3770*da0073e9SAndroid Build Coastguard Worker        # self.assertEqual(out.shape, nt1_t.shape)
3771*da0073e9SAndroid Build Coastguard Worker
3772*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
3773*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
3774*da0073e9SAndroid Build Coastguard Worker            "cannot call binary pointwise function mul.Tensor with inputs of shapes",
3775*da0073e9SAndroid Build Coastguard Worker            lambda: nt1 * nt2_t,
3776*da0073e9SAndroid Build Coastguard Worker        )
3777*da0073e9SAndroid Build Coastguard Worker
3778*da0073e9SAndroid Build Coastguard Worker        a, b, c = (
3779*da0073e9SAndroid Build Coastguard Worker            torch.randn(
3780*da0073e9SAndroid Build Coastguard Worker                i + 2, 5, requires_grad=True, dtype=torch.float64, device=device
3781*da0073e9SAndroid Build Coastguard Worker            )
3782*da0073e9SAndroid Build Coastguard Worker            for i in range(3)
3783*da0073e9SAndroid Build Coastguard Worker        )
3784*da0073e9SAndroid Build Coastguard Worker
3785*da0073e9SAndroid Build Coastguard Worker        # Correct usage: chain the calls using the same offsets tensor object
3786*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(a, b, c):
3787*da0073e9SAndroid Build Coastguard Worker            nt1, offsets = jagged_from_list([a, b, c], None)
3788*da0073e9SAndroid Build Coastguard Worker            nt2, offsets = jagged_from_list([a, b, c], offsets)
3789*da0073e9SAndroid Build Coastguard Worker            nt1_t = nt1.transpose(1, 2)
3790*da0073e9SAndroid Build Coastguard Worker            nt2_t = nt2.transpose(1, 2)
3791*da0073e9SAndroid Build Coastguard Worker            out = nt1_t * nt2_t
3792*da0073e9SAndroid Build Coastguard Worker            return out.values()
3793*da0073e9SAndroid Build Coastguard Worker
3794*da0073e9SAndroid Build Coastguard Worker        gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)
3795*da0073e9SAndroid Build Coastguard Worker
3796*da0073e9SAndroid Build Coastguard Worker    def test_split(self, device):
3797*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
3798*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
3799*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
3800*da0073e9SAndroid Build Coastguard Worker
3801*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
3802*da0073e9SAndroid Build Coastguard Worker        out = torch.split(nt, 2, -1)
3803*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(out), 2)
3804*da0073e9SAndroid Build Coastguard Worker        self.assertEqualIgnoringNestedInts(
3805*da0073e9SAndroid Build Coastguard Worker            out[0],
3806*da0073e9SAndroid Build Coastguard Worker            torch.nested.as_nested_tensor(
3807*da0073e9SAndroid Build Coastguard Worker                [a[:, 0:2], b[:, 0:2], c[:, 0:2]], layout=torch.jagged
3808*da0073e9SAndroid Build Coastguard Worker            ),
3809*da0073e9SAndroid Build Coastguard Worker        )
3810*da0073e9SAndroid Build Coastguard Worker        self.assertEqualIgnoringNestedInts(
3811*da0073e9SAndroid Build Coastguard Worker            out[1],
3812*da0073e9SAndroid Build Coastguard Worker            torch.nested.as_nested_tensor(
3813*da0073e9SAndroid Build Coastguard Worker                [a[:, 2:], b[:, 2:], c[:, 2:]], layout=torch.jagged
3814*da0073e9SAndroid Build Coastguard Worker            ),
3815*da0073e9SAndroid Build Coastguard Worker        )
3816*da0073e9SAndroid Build Coastguard Worker
3817*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
3818*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
3819*da0073e9SAndroid Build Coastguard Worker            r"split\(\): not supported for NestedTensor on dim=1",
3820*da0073e9SAndroid Build Coastguard Worker        ):
3821*da0073e9SAndroid Build Coastguard Worker            torch.split(nt, 2, 1)
3822*da0073e9SAndroid Build Coastguard Worker
3823*da0073e9SAndroid Build Coastguard Worker    def test_split_with_sizes(self, device):
3824*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
3825*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
3826*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
3827*da0073e9SAndroid Build Coastguard Worker
3828*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
3829*da0073e9SAndroid Build Coastguard Worker        out = torch.split(nt, [1, 2], -1)
3830*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(out), 2)
3831*da0073e9SAndroid Build Coastguard Worker        self.assertEqualIgnoringNestedInts(
3832*da0073e9SAndroid Build Coastguard Worker            out[0],
3833*da0073e9SAndroid Build Coastguard Worker            torch.nested.as_nested_tensor(
3834*da0073e9SAndroid Build Coastguard Worker                [a[:, 0:1], b[:, 0:1], c[:, 0:1]], layout=torch.jagged
3835*da0073e9SAndroid Build Coastguard Worker            ),
3836*da0073e9SAndroid Build Coastguard Worker        )
3837*da0073e9SAndroid Build Coastguard Worker        self.assertEqualIgnoringNestedInts(
3838*da0073e9SAndroid Build Coastguard Worker            out[1],
3839*da0073e9SAndroid Build Coastguard Worker            torch.nested.as_nested_tensor(
3840*da0073e9SAndroid Build Coastguard Worker                [a[:, 1:], b[:, 1:], c[:, 1:]], layout=torch.jagged
3841*da0073e9SAndroid Build Coastguard Worker            ),
3842*da0073e9SAndroid Build Coastguard Worker        )
3843*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
3844*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
3845*da0073e9SAndroid Build Coastguard Worker            r"split_with_sizes\(\): not supported for NestedTensor on dim=1",
3846*da0073e9SAndroid Build Coastguard Worker        ):
3847*da0073e9SAndroid Build Coastguard Worker            torch.split(nt, [1, 2], 1)
3848*da0073e9SAndroid Build Coastguard Worker
3849*da0073e9SAndroid Build Coastguard Worker    def test_softmax(self, device):
3850*da0073e9SAndroid Build Coastguard Worker        nt = random_nt_from_dims(
3851*da0073e9SAndroid Build Coastguard Worker            [3, None, 5],
3852*da0073e9SAndroid Build Coastguard Worker            device=device,
3853*da0073e9SAndroid Build Coastguard Worker            dtype=torch.float32,
3854*da0073e9SAndroid Build Coastguard Worker            layout=torch.jagged,
3855*da0073e9SAndroid Build Coastguard Worker            requires_grad=True,
3856*da0073e9SAndroid Build Coastguard Worker        )
3857*da0073e9SAndroid Build Coastguard Worker
3858*da0073e9SAndroid Build Coastguard Worker        # operate on dim=2
3859*da0073e9SAndroid Build Coastguard Worker        output = nt.softmax(dim=2)
3860*da0073e9SAndroid Build Coastguard Worker
3861*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.disable
3862*da0073e9SAndroid Build Coastguard Worker        def _compare_to_ref(nt, output, dim):
3863*da0073e9SAndroid Build Coastguard Worker            for in_component, out_component in zip(nt.unbind(), output.unbind()):
3864*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(in_component.softmax(dim=dim), out_component)
3865*da0073e9SAndroid Build Coastguard Worker
3866*da0073e9SAndroid Build Coastguard Worker        # dim=2 on the NT maps to dim=1 on each component after unbind (unbind removes the batch dim)
3867*da0073e9SAndroid Build Coastguard Worker        _compare_to_ref(nt, output, dim=1)
3868*da0073e9SAndroid Build Coastguard Worker
3869*da0073e9SAndroid Build Coastguard Worker        # operate on dim=-1
3870*da0073e9SAndroid Build Coastguard Worker        output2 = nt.softmax(dim=-1)
3871*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.disable(self.assertEqual)(output, output2)
3872*da0073e9SAndroid Build Coastguard Worker        _compare_to_ref(nt, output2, dim=-1)
3873*da0073e9SAndroid Build Coastguard Worker
3874*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(a, b):
3875*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.as_nested_tensor([a, b], layout=torch.jagged)
3876*da0073e9SAndroid Build Coastguard Worker            out = nt.softmax(dim=-1)
3877*da0073e9SAndroid Build Coastguard Worker            return out.values()
3878*da0073e9SAndroid Build Coastguard Worker
3879*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(4, 5, requires_grad=True, dtype=torch.float64, device=device)
3880*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(8, 5, requires_grad=True, dtype=torch.float64, device=device)
3881*da0073e9SAndroid Build Coastguard Worker        gradcheck(grad_test_func, inputs=(a, b), check_batched_grad=False)
3882*da0073e9SAndroid Build Coastguard Worker
3883*da0073e9SAndroid Build Coastguard Worker    def test_views_inherit_ragged_dim(self, device):
3884*da0073e9SAndroid Build Coastguard Worker        # view
3885*da0073e9SAndroid Build Coastguard Worker        nt = random_nt_from_dims(
3886*da0073e9SAndroid Build Coastguard Worker            [4, None, 8, 10], device=device, dtype=torch.float32, layout=torch.jagged
3887*da0073e9SAndroid Build Coastguard Worker        )
3888*da0073e9SAndroid Build Coastguard Worker        # inherit ragged dim via -1
3889*da0073e9SAndroid Build Coastguard Worker        view = nt.view(4, -1, 80)
3890*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt.shape[1], view.shape[1])
3891*da0073e9SAndroid Build Coastguard Worker        # inherit batch and ragged dims via -1
3892*da0073e9SAndroid Build Coastguard Worker        view2 = nt.view(-1, -1, 80)
3893*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt.shape[:2], view2.shape[:2])
3894*da0073e9SAndroid Build Coastguard Worker
3895*da0073e9SAndroid Build Coastguard Worker        # expand
3896*da0073e9SAndroid Build Coastguard Worker        nt = random_nt_from_dims(
3897*da0073e9SAndroid Build Coastguard Worker            [3, None, 1], device=device, dtype=torch.float32, layout=torch.jagged
3898*da0073e9SAndroid Build Coastguard Worker        )
3899*da0073e9SAndroid Build Coastguard Worker        # inherit batch and ragged dims via -1
3900*da0073e9SAndroid Build Coastguard Worker        view = nt.expand(-1, -1, 5)
3901*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt.shape[:2], view.shape[:2])
3902*da0073e9SAndroid Build Coastguard Worker
3903*da0073e9SAndroid Build Coastguard Worker    def test_view_ragged_idx_not_one(self, device):
3904*da0073e9SAndroid Build Coastguard Worker        nt = random_nt_from_dims(
3905*da0073e9SAndroid Build Coastguard Worker            [2, None, 20], device=device, dtype=torch.float32, layout=torch.jagged
3906*da0073e9SAndroid Build Coastguard Worker        )
3907*da0073e9SAndroid Build Coastguard Worker
3908*da0073e9SAndroid Build Coastguard Worker        view_transposed = nt.transpose(1, 2).view(2, 20, nt.size(1))
3909*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((2, 20, nt.size(1)), (view_transposed.size()))
3910*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(view_transposed._base, nt._base)
3911*da0073e9SAndroid Build Coastguard Worker
3912*da0073e9SAndroid Build Coastguard Worker    def test_unsafe_view(self, device):
3913*da0073e9SAndroid Build Coastguard Worker        nt = random_nt_from_dims(
3914*da0073e9SAndroid Build Coastguard Worker            [4, None, 8, 10], device=device, dtype=torch.float32, layout=torch.jagged
3915*da0073e9SAndroid Build Coastguard Worker        )
3916*da0073e9SAndroid Build Coastguard Worker        # basic view
3917*da0073e9SAndroid Build Coastguard Worker        view1 = torch.ops.aten._unsafe_view(nt, (4, -1, 80))
3918*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((4, nt.size(1), 80), tuple(view1.size()))
3919*da0073e9SAndroid Build Coastguard Worker        # _unsafe_view differs from view in that the view information is not tracked
3920*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(view1._base is None)
3921*da0073e9SAndroid Build Coastguard Worker
3922*da0073e9SAndroid Build Coastguard Worker        # test _unsafe_view when ragged_idx != 1; currently only the identity view is supported
3923*da0073e9SAndroid Build Coastguard Worker        nt_t = nt.transpose(1, 2)
3924*da0073e9SAndroid Build Coastguard Worker        view2 = torch.ops.aten._unsafe_view(nt_t, (4, 8, nt.size(1), 10))
3925*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((4, 8, nt.size(1), 10), tuple(view2.size()))
3926*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(view2._base is None)
3927*da0073e9SAndroid Build Coastguard Worker
3928*da0073e9SAndroid Build Coastguard Worker    @xfailIfTorchDynamo
3929*da0073e9SAndroid Build Coastguard Worker    @parametrize("requires_grad", [False, True])
3930*da0073e9SAndroid Build Coastguard Worker    def test_reshape_decomp(self, device, requires_grad):
3931*da0073e9SAndroid Build Coastguard Worker        # contiguous NT should result in view.
3932*da0073e9SAndroid Build Coastguard Worker        nt = (
3933*da0073e9SAndroid Build Coastguard Worker            random_nt_from_dims(
3934*da0073e9SAndroid Build Coastguard Worker                [3, None, 10],
3935*da0073e9SAndroid Build Coastguard Worker                device=device,
3936*da0073e9SAndroid Build Coastguard Worker                dtype=torch.float32,
3937*da0073e9SAndroid Build Coastguard Worker                layout=torch.jagged,
3938*da0073e9SAndroid Build Coastguard Worker            )
3939*da0073e9SAndroid Build Coastguard Worker            .detach()
3940*da0073e9SAndroid Build Coastguard Worker            .requires_grad_(requires_grad)
3941*da0073e9SAndroid Build Coastguard Worker        )
3942*da0073e9SAndroid Build Coastguard Worker        view = nt.reshape(-1, -1, 5, 2)
3943*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(view.shape[:2], nt.shape[:2])
3944*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(view._is_view() and view._base is nt)
3945*da0073e9SAndroid Build Coastguard Worker        # make sure gradients flow back
3946*da0073e9SAndroid Build Coastguard Worker        if requires_grad:
3947*da0073e9SAndroid Build Coastguard Worker            view.backward(torch.ones_like(view))
3948*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt.grad, torch.ones_like(nt))
3949*da0073e9SAndroid Build Coastguard Worker
3950*da0073e9SAndroid Build Coastguard Worker        # non-contiguous NT should result in contiguous copy
3951*da0073e9SAndroid Build Coastguard Worker        nt = random_nt_from_dims(
3952*da0073e9SAndroid Build Coastguard Worker            [3, None, 5, 2],
3953*da0073e9SAndroid Build Coastguard Worker            device=device,
3954*da0073e9SAndroid Build Coastguard Worker            dtype=torch.float32,
3955*da0073e9SAndroid Build Coastguard Worker            layout=torch.jagged,
3956*da0073e9SAndroid Build Coastguard Worker            requires_grad=requires_grad,
3957*da0073e9SAndroid Build Coastguard Worker        )
3958*da0073e9SAndroid Build Coastguard Worker        nt_noncontig = nt.transpose(-1, -2)
3959*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(nt_noncontig.is_contiguous())
3960*da0073e9SAndroid Build Coastguard Worker        copy = nt_noncontig.reshape(-1, -1, 10)
3961*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(copy.is_contiguous())
3962*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(copy.shape[:2], nt.shape[:2])
3963*da0073e9SAndroid Build Coastguard Worker        # make sure gradients flow back
3964*da0073e9SAndroid Build Coastguard Worker        if requires_grad:
3965*da0073e9SAndroid Build Coastguard Worker            copy.backward(torch.ones_like(copy))
3966*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt.grad, torch.ones_like(nt))
3967*da0073e9SAndroid Build Coastguard Worker
3968*da0073e9SAndroid Build Coastguard Worker    def test_flatten_decomp(self, device):
3969*da0073e9SAndroid Build Coastguard Worker        nt = random_nt_from_dims(
3970*da0073e9SAndroid Build Coastguard Worker            [3, None, 5, 2], device=device, dtype=torch.float32, layout=torch.jagged
3971*da0073e9SAndroid Build Coastguard Worker        )
3972*da0073e9SAndroid Build Coastguard Worker        flattened = nt.flatten(-2, -1)
3973*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(flattened.shape, nt.view(3, -1, 10).shape)
3974*da0073e9SAndroid Build Coastguard Worker
3975*da0073e9SAndroid Build Coastguard Worker        nt = random_nt_from_dims(
3976*da0073e9SAndroid Build Coastguard Worker            [3, None, 5, 2, 6], device=device, dtype=torch.float32, layout=torch.jagged
3977*da0073e9SAndroid Build Coastguard Worker        )
3978*da0073e9SAndroid Build Coastguard Worker        flattened = nt.flatten(-3, -2)
3979*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(flattened.shape, nt.view(3, -1, 10, 6).shape)
3980*da0073e9SAndroid Build Coastguard Worker
3981*da0073e9SAndroid Build Coastguard Worker    def test_chunk(self, device):
3982*da0073e9SAndroid Build Coastguard Worker        # non-NJT case
3983*da0073e9SAndroid Build Coastguard Worker        t = torch.randn(10, 4, 5, requires_grad=True)
3984*da0073e9SAndroid Build Coastguard Worker        t_list = t.chunk(3, dim=0)
3985*da0073e9SAndroid Build Coastguard Worker        loss = t_list[0].sum() + t_list[2].sum()
3986*da0073e9SAndroid Build Coastguard Worker        loss.backward()
3987*da0073e9SAndroid Build Coastguard Worker
3988*da0073e9SAndroid Build Coastguard Worker        # normal case
3989*da0073e9SAndroid Build Coastguard Worker        D = 30
3990*da0073e9SAndroid Build Coastguard Worker        B = 8
3991*da0073e9SAndroid Build Coastguard Worker        nt = random_nt_from_dims(
3992*da0073e9SAndroid Build Coastguard Worker            [B, None, D],
3993*da0073e9SAndroid Build Coastguard Worker            device=device,
3994*da0073e9SAndroid Build Coastguard Worker            dtype=torch.float32,
3995*da0073e9SAndroid Build Coastguard Worker            layout=torch.jagged,
3996*da0073e9SAndroid Build Coastguard Worker            requires_grad=True,
3997*da0073e9SAndroid Build Coastguard Worker        )
3998*da0073e9SAndroid Build Coastguard Worker        NUM_CHUNKS = 3
3999*da0073e9SAndroid Build Coastguard Worker        chunks = nt.chunk(NUM_CHUNKS, dim=-1)
4000*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(chunks), NUM_CHUNKS)
4001*da0073e9SAndroid Build Coastguard Worker        for i in range(NUM_CHUNKS):
4002*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(chunks[i].shape[-1], D // NUM_CHUNKS)
4003*da0073e9SAndroid Build Coastguard Worker
4004*da0073e9SAndroid Build Coastguard Worker        # test chunk_backward
4005*da0073e9SAndroid Build Coastguard Worker        values = torch.randn(
4006*da0073e9SAndroid Build Coastguard Worker            5, 11, dtype=torch.float64, device=device, requires_grad=True
4007*da0073e9SAndroid Build Coastguard Worker        )
4008*da0073e9SAndroid Build Coastguard Worker        offsets = torch.tensor([0, 2, 3, 5], device=device)
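        # offsets [0, 2, 3, 5] -> component lengths 2, 1, 2 (5 rows total, 11 features each)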
4009*da0073e9SAndroid Build Coastguard Worker
4010*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(values, offsets):
4011*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
4012*da0073e9SAndroid Build Coastguard Worker            chunks = nt.chunk(3, dim=-1)
4013*da0073e9SAndroid Build Coastguard Worker            return chunks[0].values().sum()
4014*da0073e9SAndroid Build Coastguard Worker
4015*da0073e9SAndroid Build Coastguard Worker        assert gradcheck(
4016*da0073e9SAndroid Build Coastguard Worker            grad_test_func,
4017*da0073e9SAndroid Build Coastguard Worker            inputs=(values, offsets),
4018*da0073e9SAndroid Build Coastguard Worker            check_batched_grad=False,
4019*da0073e9SAndroid Build Coastguard Worker        )
4020*da0073e9SAndroid Build Coastguard Worker
4021*da0073e9SAndroid Build Coastguard Worker        # chunk on batch dim
4022*da0073e9SAndroid Build Coastguard Worker        chunks = nt.chunk(NUM_CHUNKS, dim=0)
4023*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(chunks), NUM_CHUNKS)
4024*da0073e9SAndroid Build Coastguard Worker        chunk_size = math.ceil(B / NUM_CHUNKS)
4025*da0073e9SAndroid Build Coastguard Worker        for i in range(NUM_CHUNKS):
4026*da0073e9SAndroid Build Coastguard Worker            if i < NUM_CHUNKS - 1:
4027*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(chunks[i].shape[0], chunk_size)
4028*da0073e9SAndroid Build Coastguard Worker            else:
4029*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(chunks[i].shape[0], B - chunk_size * (NUM_CHUNKS - 1))
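            # each batch-dim chunk carries a slice of the original offsets, rebased so
            # that the chunk's first sequence starts at offset 0; verify that here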
4030*da0073e9SAndroid Build Coastguard Worker            offsets_expected = (
4031*da0073e9SAndroid Build Coastguard Worker                nt._offsets[i * chunk_size + 1 : (i + 1) * chunk_size + 1]
4032*da0073e9SAndroid Build Coastguard Worker                - nt._offsets[i * chunk_size]
4033*da0073e9SAndroid Build Coastguard Worker            )
4034*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(chunks[i]._offsets[1:], offsets_expected)
4035*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt._values, torch.cat([x._values for x in chunks], dim=0))
4036*da0073e9SAndroid Build Coastguard Worker
4037*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
4038*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
4039*da0073e9SAndroid Build Coastguard Worker            "dim != 0 INTERNAL ASSERT FAILED .* Nested Tensor doesn't support chunk backward on dim=0 yet.",
4040*da0073e9SAndroid Build Coastguard Worker        ):
4041*da0073e9SAndroid Build Coastguard Worker            # doesn't support backward for chunk (dim=0) yet
4042*da0073e9SAndroid Build Coastguard Worker            loss = (
4043*da0073e9SAndroid Build Coastguard Worker                chunks[0].values().sum()
4044*da0073e9SAndroid Build Coastguard Worker                + chunks[1].values().sum()
4045*da0073e9SAndroid Build Coastguard Worker                + chunks[2].values().sum()
4046*da0073e9SAndroid Build Coastguard Worker            )
4047*da0073e9SAndroid Build Coastguard Worker            loss.backward()
4048*da0073e9SAndroid Build Coastguard Worker
4049*da0073e9SAndroid Build Coastguard Worker        # chunk on ragged dim not supported
4050*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
4051*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "chunk.* not supported for NestedTensor on dim=1"
4052*da0073e9SAndroid Build Coastguard Worker        ):
4053*da0073e9SAndroid Build Coastguard Worker            nt.chunk(2, dim=1)
4054*da0073e9SAndroid Build Coastguard Worker
4055*da0073e9SAndroid Build Coastguard Worker    def test_squeeze(self, device):
4056*da0073e9SAndroid Build Coastguard Worker        B = 4
4057*da0073e9SAndroid Build Coastguard Worker        D = 6
4058*da0073e9SAndroid Build Coastguard Worker        # squeeze middle dim
4059*da0073e9SAndroid Build Coastguard Worker        nt = random_nt_from_dims(
4060*da0073e9SAndroid Build Coastguard Worker            [B, None, 1, D], device=device, dtype=torch.float32, layout=torch.jagged
4061*da0073e9SAndroid Build Coastguard Worker        )
4062*da0073e9SAndroid Build Coastguard Worker        j0 = nt.shape[1]
4063*da0073e9SAndroid Build Coastguard Worker
4064*da0073e9SAndroid Build Coastguard Worker        for dim_arg in [-2, 2]:
4065*da0073e9SAndroid Build Coastguard Worker            out = nt.squeeze(dim_arg)
4066*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out.shape, (B, j0, D))
4067*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out.unsqueeze(-2), nt)
4068*da0073e9SAndroid Build Coastguard Worker
4069*da0073e9SAndroid Build Coastguard Worker        # squeeze last dim
4070*da0073e9SAndroid Build Coastguard Worker        nt = random_nt_from_dims(
4071*da0073e9SAndroid Build Coastguard Worker            [B, None, 1], device=device, dtype=torch.float32, layout=torch.jagged
4072*da0073e9SAndroid Build Coastguard Worker        )
4073*da0073e9SAndroid Build Coastguard Worker        j1 = nt.shape[1]
4074*da0073e9SAndroid Build Coastguard Worker
4075*da0073e9SAndroid Build Coastguard Worker        for dim_arg in [-1, 2]:
4076*da0073e9SAndroid Build Coastguard Worker            out = nt.squeeze(dim_arg)
4077*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out.shape, (B, j1))
4078*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out.unsqueeze(-1), nt)
4079*da0073e9SAndroid Build Coastguard Worker
4080*da0073e9SAndroid Build Coastguard Worker        # squeeze on batch dim not supported
4081*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
4082*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "squeeze.* not supported for NestedTensor on dim=0"
4083*da0073e9SAndroid Build Coastguard Worker        ):
4084*da0073e9SAndroid Build Coastguard Worker            nt.squeeze(0)
4085*da0073e9SAndroid Build Coastguard Worker
4086*da0073e9SAndroid Build Coastguard Worker        # squeeze on ragged dim not supported
4087*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
4088*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "squeeze.* not supported for NestedTensor on dim=1"
4089*da0073e9SAndroid Build Coastguard Worker        ):
4090*da0073e9SAndroid Build Coastguard Worker            nt.squeeze(1)
4091*da0073e9SAndroid Build Coastguard Worker
4092*da0073e9SAndroid Build Coastguard Worker    def test_binary_pointwise_broadcasting(self, device):
4093*da0073e9SAndroid Build Coastguard Worker        # (B, j0, 3, 4)
4094*da0073e9SAndroid Build Coastguard Worker        ts = self._get_list_for_jagged_tensor(
4095*da0073e9SAndroid Build Coastguard Worker            ((2, 3, 4), 3, 4), device, requires_grad=True
4096*da0073e9SAndroid Build Coastguard Worker        )
4097*da0073e9SAndroid Build Coastguard Worker        # (B, j0, ?, ?) + (?) -> (B, j0, ?, ?)
4098*da0073e9SAndroid Build Coastguard Worker        # (B, j0, ?, ?) + (?, ?) -> (B, j0, ?, ?)
4099*da0073e9SAndroid Build Coastguard Worker        # (B, j0, ?, ?) + (1, ?, ?) -> (B, j0, ?, ?)
4100*da0073e9SAndroid Build Coastguard Worker        # Unsupported: (B, j0, ?, ?) + (1, 1, 1, ?, ?) -> (1, B, j0, ?, ?)
4101*da0073e9SAndroid Build Coastguard Worker        t_sizes = (
4102*da0073e9SAndroid Build Coastguard Worker            (4,),
4103*da0073e9SAndroid Build Coastguard Worker            (1, 4),
4104*da0073e9SAndroid Build Coastguard Worker            (3, 1),
4105*da0073e9SAndroid Build Coastguard Worker            (1, 3, 1),
4106*da0073e9SAndroid Build Coastguard Worker            (1, 1, 1, 4),
4107*da0073e9SAndroid Build Coastguard Worker            # (1, 1, 1, 1, 4), (unsupported today)
4108*da0073e9SAndroid Build Coastguard Worker        )
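        # For example (shapes per the (B, j0, 3, 4) construction above): adding a (1, 3, 1)
        # tensor broadcasts it across every jagged component of shape (j0_i, 3, 4), so the
        # result keeps the nested shape (B, j0, 3, 4).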
4109*da0073e9SAndroid Build Coastguard Worker
4110*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(t, *ts):
4111*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.as_nested_tensor(list(ts), layout=torch.jagged)
4112*da0073e9SAndroid Build Coastguard Worker            out = nt + t
4113*da0073e9SAndroid Build Coastguard Worker            return out.values()
4114*da0073e9SAndroid Build Coastguard Worker
4115*da0073e9SAndroid Build Coastguard Worker        for t_size in t_sizes:
4116*da0073e9SAndroid Build Coastguard Worker            t = torch.rand(
4117*da0073e9SAndroid Build Coastguard Worker                t_size, requires_grad=True, device=device, dtype=torch.float64
4118*da0073e9SAndroid Build Coastguard Worker            )
4119*da0073e9SAndroid Build Coastguard Worker            gradcheck(grad_test_func, inputs=(t, *ts), check_batched_grad=False)
4120*da0073e9SAndroid Build Coastguard Worker
4121*da0073e9SAndroid Build Coastguard Worker    def test_threshold_backward(self, device):
4122*da0073e9SAndroid Build Coastguard Worker        ts1 = self._get_list_for_jagged_tensor(
4123*da0073e9SAndroid Build Coastguard Worker            ((2, 3, 4), 16), device=device, requires_grad=False
4124*da0073e9SAndroid Build Coastguard Worker        )
4125*da0073e9SAndroid Build Coastguard Worker        ts2 = self._get_list_for_jagged_tensor(
4126*da0073e9SAndroid Build Coastguard Worker            ((2, 3, 4), 16), device=device, requires_grad=False
4127*da0073e9SAndroid Build Coastguard Worker        )
4128*da0073e9SAndroid Build Coastguard Worker
4129*da0073e9SAndroid Build Coastguard Worker        nt1, offsets = jagged_from_list(ts1, None)
4130*da0073e9SAndroid Build Coastguard Worker        nt2, offsets = jagged_from_list(ts2, offsets)
4131*da0073e9SAndroid Build Coastguard Worker        buf1 = nt1.values().detach().clone()
4132*da0073e9SAndroid Build Coastguard Worker        buf2 = nt2.values().detach().clone()
4133*da0073e9SAndroid Build Coastguard Worker
4134*da0073e9SAndroid Build Coastguard Worker        res_nt = torch.ops.aten.threshold_backward(nt1, nt2, 0.0)
4135*da0073e9SAndroid Build Coastguard Worker        res_dense = torch.ops.aten.threshold_backward(buf1, buf2, 0.0)
4136*da0073e9SAndroid Build Coastguard Worker
4137*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res_dense, res_nt.values())
4138*da0073e9SAndroid Build Coastguard Worker
4139*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
4140*da0073e9SAndroid Build Coastguard Worker    @parametrize(
4141*da0073e9SAndroid Build Coastguard Worker        "func",
4142*da0073e9SAndroid Build Coastguard Worker        [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
4143*da0073e9SAndroid Build Coastguard Worker        name_fn=get_op_name,
4144*da0073e9SAndroid Build Coastguard Worker    )
4145*da0073e9SAndroid Build Coastguard Worker    @parametrize("keepdim", [False, True])
4146*da0073e9SAndroid Build Coastguard Worker    @parametrize("requires_grad", [False, True])
4147*da0073e9SAndroid Build Coastguard Worker    @parametrize("components_require_grad", [False, True])
4148*da0073e9SAndroid Build Coastguard Worker    def test_jagged_op_different_output_shape_dim(
4149*da0073e9SAndroid Build Coastguard Worker        self, device, dtype, keepdim, requires_grad, components_require_grad, func
4150*da0073e9SAndroid Build Coastguard Worker    ):
4151*da0073e9SAndroid Build Coastguard Worker        """
4152*da0073e9SAndroid Build Coastguard Worker        Operator passes when reducing on valid reduction dimensions.
4153*da0073e9SAndroid Build Coastguard Worker        This test is for operators which return an output tensor with a shape different from the input tensor.
4154*da0073e9SAndroid Build Coastguard Worker        """
4155*da0073e9SAndroid Build Coastguard Worker        if get_op_name(func) == "mean" and not keepdim:
4156*da0073e9SAndroid Build Coastguard Worker            return
4157*da0073e9SAndroid Build Coastguard Worker
4158*da0073e9SAndroid Build Coastguard Worker        op_name = get_op_name(func)
4159*da0073e9SAndroid Build Coastguard Worker
4160*da0073e9SAndroid Build Coastguard Worker        ts = self._get_list_for_jagged_tensor(
4161*da0073e9SAndroid Build Coastguard Worker            ((2, 3, 4), 3, 4), device=device, requires_grad=True
4162*da0073e9SAndroid Build Coastguard Worker        )  # (B, j0, 3, 4)
4163*da0073e9SAndroid Build Coastguard Worker
4164*da0073e9SAndroid Build Coastguard Worker        # verify correctness of shapes (assuming that ragged_idx == 1)
4165*da0073e9SAndroid Build Coastguard Worker        if op_name == "sum":
4166*da0073e9SAndroid Build Coastguard Worker            reduce_dims = (
4167*da0073e9SAndroid Build Coastguard Worker                ((0, 1), (3, 4), (1, 1, 3, 4), (0,)),  # batch, ragged
4168*da0073e9SAndroid Build Coastguard Worker                ((2, 3), (3, None), (3, None, 1, 1), (1, 2)),  # non-batch, non-batch
4169*da0073e9SAndroid Build Coastguard Worker                ((0, 1, 3), (3,), (1, 1, 3, 1), (0, 2)),  # batch, ragged, non-batch
4170*da0073e9SAndroid Build Coastguard Worker                ((0, 1, 2), (4,), (1, 1, 1, 4), (0, 1)),  # batch, ragged, non-batch
4171*da0073e9SAndroid Build Coastguard Worker                (
4172*da0073e9SAndroid Build Coastguard Worker                    (0, 1, 2, 3),
4173*da0073e9SAndroid Build Coastguard Worker                    (),
4174*da0073e9SAndroid Build Coastguard Worker                    (1, 1, 1, 1),
4175*da0073e9SAndroid Build Coastguard Worker                    (0, 1, 2),
4176*da0073e9SAndroid Build Coastguard Worker                ),  # batch, ragged, non-batch, non-batch
4177*da0073e9SAndroid Build Coastguard Worker                ((2,), (3, None, 4), (3, None, 1, 4), (1,)),  # non-batch
4178*da0073e9SAndroid Build Coastguard Worker            )  # (dims, expected shape, expected keepdim shape, reduce_dim_expected), where j0 is represented as None
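            # e.g. the first entry: reducing the (B, j0, 3, 4) input over dims (0, 1) removes
            # both the batch and ragged dims, leaving a dense (3, 4) result ((1, 1, 3, 4) with
            # keepdim=True); the equivalent reduction on nt.values() is over dim 0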
4179*da0073e9SAndroid Build Coastguard Worker        elif op_name == "mean":
4180*da0073e9SAndroid Build Coastguard Worker            reduce_dims = (
4181*da0073e9SAndroid Build Coastguard Worker                ((2,), (3, None, 4), (3, None, 1, 4), (1,)),
4182*da0073e9SAndroid Build Coastguard Worker                ((3,), (3, None, 3), (3, None, 3, 1), (2,)),
4183*da0073e9SAndroid Build Coastguard Worker            )
4184*da0073e9SAndroid Build Coastguard Worker
4185*da0073e9SAndroid Build Coastguard Worker        for rd, ref_shape_no_keepdim, ref_shape_keepdim, _ in reduce_dims:
4186*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged)
4187*da0073e9SAndroid Build Coastguard Worker            out = func(nt, dim=rd, keepdim=keepdim)
4188*da0073e9SAndroid Build Coastguard Worker            ref_shape = ref_shape_keepdim if keepdim else ref_shape_no_keepdim
4189*da0073e9SAndroid Build Coastguard Worker            if not torch.compiler.is_compiling():  # if not using torch dynamo
4190*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(len(out.shape), len(ref_shape))
4191*da0073e9SAndroid Build Coastguard Worker                for o, r in zip(out.shape, ref_shape):
4192*da0073e9SAndroid Build Coastguard Worker                    if r is not None:
4193*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(o, r)
4194*da0073e9SAndroid Build Coastguard Worker                    else:
4195*da0073e9SAndroid Build Coastguard Worker                        self.assertTrue(isinstance(o, torch.SymInt))
4196*da0073e9SAndroid Build Coastguard Worker
4197*da0073e9SAndroid Build Coastguard Worker        # verify correctness of values
4198*da0073e9SAndroid Build Coastguard Worker        tensor_lists = self._get_example_tensor_lists(
4199*da0073e9SAndroid Build Coastguard Worker            include_list_of_lists=False,
4200*da0073e9SAndroid Build Coastguard Worker            include_requires_grad=components_require_grad,
4201*da0073e9SAndroid Build Coastguard Worker            include_inner_dim_size_1=True,
4202*da0073e9SAndroid Build Coastguard Worker        )
4203*da0073e9SAndroid Build Coastguard Worker        for tensor_list, reduce_dim_tuple in itertools.product(
4204*da0073e9SAndroid Build Coastguard Worker            tensor_lists, reduce_dims
4205*da0073e9SAndroid Build Coastguard Worker        ):
4206*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor(
4207*da0073e9SAndroid Build Coastguard Worker                tensor_list,
4208*da0073e9SAndroid Build Coastguard Worker                device=device,
4209*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
4210*da0073e9SAndroid Build Coastguard Worker                layout=torch.jagged,
4211*da0073e9SAndroid Build Coastguard Worker                requires_grad=requires_grad,
4212*da0073e9SAndroid Build Coastguard Worker            )
4213*da0073e9SAndroid Build Coastguard Worker
4214*da0073e9SAndroid Build Coastguard Worker            reduce_dim, _, _, reduce_dim_expected = reduce_dim_tuple
4215*da0073e9SAndroid Build Coastguard Worker
4216*da0073e9SAndroid Build Coastguard Worker            if nt.dim() > reduce_dim[-1]:
4217*da0073e9SAndroid Build Coastguard Worker                out_actual = func(nt, dim=reduce_dim, keepdim=keepdim)
4218*da0073e9SAndroid Build Coastguard Worker                if nt._ragged_idx in reduce_dim:  # raggedness reduced away
4219*da0073e9SAndroid Build Coastguard Worker                    out_expected = func(
4220*da0073e9SAndroid Build Coastguard Worker                        nt.values(), dim=reduce_dim_expected, keepdim=keepdim
4221*da0073e9SAndroid Build Coastguard Worker                    )
4222*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(torch.allclose(out_actual, out_expected))
4223*da0073e9SAndroid Build Coastguard Worker                else:  # raggedness preserved
4224*da0073e9SAndroid Build Coastguard Worker                    out_expected = func(nt.values(), dim=reduce_dim_expected)
4225*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(
4226*da0073e9SAndroid Build Coastguard Worker                        torch.allclose(
4227*da0073e9SAndroid Build Coastguard Worker                            out_actual.values().view(-1), out_expected.view(-1)
4228*da0073e9SAndroid Build Coastguard Worker                        )
4229*da0073e9SAndroid Build Coastguard Worker                    )
4230*da0073e9SAndroid Build Coastguard Worker
4231*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
4232*da0073e9SAndroid Build Coastguard Worker    @parametrize("requires_grad", [False, True])
4233*da0073e9SAndroid Build Coastguard Worker    @parametrize("components_require_grad", [False, True])
4234*da0073e9SAndroid Build Coastguard Worker    def test_softmax_dim(
4235*da0073e9SAndroid Build Coastguard Worker        self,
4236*da0073e9SAndroid Build Coastguard Worker        device,
4237*da0073e9SAndroid Build Coastguard Worker        dtype,
4238*da0073e9SAndroid Build Coastguard Worker        requires_grad,
4239*da0073e9SAndroid Build Coastguard Worker        components_require_grad,
4240*da0073e9SAndroid Build Coastguard Worker    ):
4241*da0073e9SAndroid Build Coastguard Worker        """
4242*da0073e9SAndroid Build Coastguard Worker        Softmax passes when reducing on valid reduction dimensions.
4243*da0073e9SAndroid Build Coastguard Worker        """
4244*da0073e9SAndroid Build Coastguard Worker        ts = self._get_list_for_jagged_tensor(
4245*da0073e9SAndroid Build Coastguard Worker            ((2, 3, 4), 3, 4), device=device, requires_grad=True
4246*da0073e9SAndroid Build Coastguard Worker        )  # (B, j0, 3, 4)
4247*da0073e9SAndroid Build Coastguard Worker
4248*da0073e9SAndroid Build Coastguard Worker        output_shape = (3, None, 3, 4)
4249*da0073e9SAndroid Build Coastguard Worker
4250*da0073e9SAndroid Build Coastguard Worker        # verify correctness of shapes (assuming that ragged_idx == 1)
4251*da0073e9SAndroid Build Coastguard Worker        reduce_dims = (
4252*da0073e9SAndroid Build Coastguard Worker            (2, 1),
4253*da0073e9SAndroid Build Coastguard Worker            (3, 2),
4254*da0073e9SAndroid Build Coastguard Worker        )  # (reduction dimension, effective reduction dimension for baseline)
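        # e.g. softmax over dim=2 of the (B, j0, 3, 4) nested tensor corresponds to softmax
        # over dim=1 of nt.values(), which has shape (sum of j0, 3, 4)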
4255*da0073e9SAndroid Build Coastguard Worker
4256*da0073e9SAndroid Build Coastguard Worker        for reduce_dim, _ in reduce_dims:
4257*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged)
4258*da0073e9SAndroid Build Coastguard Worker            out_actual = torch.nn.functional.softmax(nt, dim=reduce_dim)
4259*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.disable(self.assertEqual)(
4260*da0073e9SAndroid Build Coastguard Worker                len(out_actual.shape), len(output_shape)
4261*da0073e9SAndroid Build Coastguard Worker            )  # run the check eagerly; don't trace it under dynamo
4262*da0073e9SAndroid Build Coastguard Worker            for dim_actual, dim_expected in zip(out_actual.shape, output_shape):
4263*da0073e9SAndroid Build Coastguard Worker                if dim_expected is not None:
4264*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(dim_actual, dim_expected)
4265*da0073e9SAndroid Build Coastguard Worker                else:
4266*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(isinstance(dim_actual, torch.SymInt))
4267*da0073e9SAndroid Build Coastguard Worker
4268*da0073e9SAndroid Build Coastguard Worker        # verify correctness of values
4269*da0073e9SAndroid Build Coastguard Worker        tensor_lists = self._get_example_tensor_lists(
4270*da0073e9SAndroid Build Coastguard Worker            include_list_of_lists=False,
4271*da0073e9SAndroid Build Coastguard Worker            include_requires_grad=components_require_grad,
4272*da0073e9SAndroid Build Coastguard Worker            include_inner_dim_size_1=True,
4273*da0073e9SAndroid Build Coastguard Worker        )
4274*da0073e9SAndroid Build Coastguard Worker        for tensor_list, reduce_dim_tuple in itertools.product(
4275*da0073e9SAndroid Build Coastguard Worker            tensor_lists, reduce_dims
4276*da0073e9SAndroid Build Coastguard Worker        ):
4277*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor(
4278*da0073e9SAndroid Build Coastguard Worker                tensor_list,
4279*da0073e9SAndroid Build Coastguard Worker                device=device,
4280*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
4281*da0073e9SAndroid Build Coastguard Worker                layout=torch.jagged,
4282*da0073e9SAndroid Build Coastguard Worker                requires_grad=requires_grad,
4283*da0073e9SAndroid Build Coastguard Worker            )
4284*da0073e9SAndroid Build Coastguard Worker
4285*da0073e9SAndroid Build Coastguard Worker            reduce_dim, reduce_dim_expected = reduce_dim_tuple
4286*da0073e9SAndroid Build Coastguard Worker
4287*da0073e9SAndroid Build Coastguard Worker            if nt.dim() > reduce_dim:
4288*da0073e9SAndroid Build Coastguard Worker                out_actual = torch.nn.functional.softmax(
4289*da0073e9SAndroid Build Coastguard Worker                    nt, dim=reduce_dim
4290*da0073e9SAndroid Build Coastguard Worker                )  # nested tensor
4291*da0073e9SAndroid Build Coastguard Worker                out_expected = torch.nn.functional.softmax(
4292*da0073e9SAndroid Build Coastguard Worker                    nt.values(), dim=reduce_dim_expected
4293*da0073e9SAndroid Build Coastguard Worker                )  # dense tensor with one fewer dimension than out_actual
4294*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(
4295*da0073e9SAndroid Build Coastguard Worker                    torch.allclose(out_actual.values().view(-1), out_expected.view(-1))
4296*da0073e9SAndroid Build Coastguard Worker                )
4297*da0073e9SAndroid Build Coastguard Worker
4298*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
4299*da0073e9SAndroid Build Coastguard Worker    @parametrize(
4300*da0073e9SAndroid Build Coastguard Worker        "func",
4301*da0073e9SAndroid Build Coastguard Worker        [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
4302*da0073e9SAndroid Build Coastguard Worker        name_fn=get_op_name,
4303*da0073e9SAndroid Build Coastguard Worker    )
4304*da0073e9SAndroid Build Coastguard Worker    @parametrize("keepdim", [False, True])
4305*da0073e9SAndroid Build Coastguard Worker    @parametrize("requires_grad", [False, True])
4306*da0073e9SAndroid Build Coastguard Worker    @parametrize("components_require_grad", [False, True])
4307*da0073e9SAndroid Build Coastguard Worker    def test_op_dim_reduce_ragged_idx_1_different_output_shape(
4308*da0073e9SAndroid Build Coastguard Worker        self, device, dtype, keepdim, requires_grad, components_require_grad, func
4309*da0073e9SAndroid Build Coastguard Worker    ):
4310*da0073e9SAndroid Build Coastguard Worker        """
4311*da0073e9SAndroid Build Coastguard Worker        Operator on NestedTensor passes when trying to reduce across ragged dimension, where ragged_idx == 1.
4312*da0073e9SAndroid Build Coastguard Worker        This test is for operators which return an output tensor with a shape different from the input tensor.
4313*da0073e9SAndroid Build Coastguard Worker        """
4314*da0073e9SAndroid Build Coastguard Worker        if get_op_name(func) == "mean" and not keepdim:
4315*da0073e9SAndroid Build Coastguard Worker            return
4316*da0073e9SAndroid Build Coastguard Worker
4317*da0073e9SAndroid Build Coastguard Worker        op_name = get_op_name(func)
4318*da0073e9SAndroid Build Coastguard Worker
4319*da0073e9SAndroid Build Coastguard Worker        tensor_lists = self._get_example_tensor_lists(
4320*da0073e9SAndroid Build Coastguard Worker            include_list_of_lists=False,
4321*da0073e9SAndroid Build Coastguard Worker            include_requires_grad=components_require_grad,
4322*da0073e9SAndroid Build Coastguard Worker            include_inner_dim_size_1=True,  # (B, *, 1)
4323*da0073e9SAndroid Build Coastguard Worker        )
4324*da0073e9SAndroid Build Coastguard Worker        reduce_dim = (1,)  # ragged
4325*da0073e9SAndroid Build Coastguard Worker
4326*da0073e9SAndroid Build Coastguard Worker        for tensor_list in tensor_lists:
4327*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor(
4328*da0073e9SAndroid Build Coastguard Worker                tensor_list,
4329*da0073e9SAndroid Build Coastguard Worker                device=device,
4330*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
4331*da0073e9SAndroid Build Coastguard Worker                layout=torch.jagged,
4332*da0073e9SAndroid Build Coastguard Worker                requires_grad=requires_grad,
4333*da0073e9SAndroid Build Coastguard Worker            )
4334*da0073e9SAndroid Build Coastguard Worker
4335*da0073e9SAndroid Build Coastguard Worker            out_actual = func(nt, dim=reduce_dim, keepdim=keepdim)
4336*da0073e9SAndroid Build Coastguard Worker            out_expected = torch.cat(
4337*da0073e9SAndroid Build Coastguard Worker                [func(t, dim=(reduce_dim[0] - 1)).unsqueeze(0) for t in nt.unbind()]
4338*da0073e9SAndroid Build Coastguard Worker            )
4339*da0073e9SAndroid Build Coastguard Worker
4340*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(
4341*da0073e9SAndroid Build Coastguard Worker                out_actual.is_nested,
4342*da0073e9SAndroid Build Coastguard Worker                f"{op_name}(): reducing a nested tensor along the ragged dimension should produce a dense tensor",
4343*da0073e9SAndroid Build Coastguard Worker            )  # output is a dense tensor
4344*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.allclose(out_actual, out_expected))
4345*da0073e9SAndroid Build Coastguard Worker
4346*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
4347*da0073e9SAndroid Build Coastguard Worker    @parametrize("requires_grad", [False, True])
4348*da0073e9SAndroid Build Coastguard Worker    @parametrize("components_require_grad", [False, True])
4349*da0073e9SAndroid Build Coastguard Worker    def test_softmax_dim_reduce_ragged_idx_1(
4350*da0073e9SAndroid Build Coastguard Worker        self, device, dtype, requires_grad, components_require_grad
4351*da0073e9SAndroid Build Coastguard Worker    ):
4352*da0073e9SAndroid Build Coastguard Worker        """
4353*da0073e9SAndroid Build Coastguard Worker        Softmax on NestedTensor passes when trying to reduce across ragged dimension, where ragged_idx == 1.
4354*da0073e9SAndroid Build Coastguard Worker        """
4355*da0073e9SAndroid Build Coastguard Worker        tensor_lists = self._get_example_tensor_lists(
4356*da0073e9SAndroid Build Coastguard Worker            include_list_of_lists=False,
4357*da0073e9SAndroid Build Coastguard Worker            include_requires_grad=components_require_grad,
4358*da0073e9SAndroid Build Coastguard Worker            include_inner_dim_size_1=True,  # (B, *, 1)
4359*da0073e9SAndroid Build Coastguard Worker            include_2d_tensor=True,  # (B, *)
4360*da0073e9SAndroid Build Coastguard Worker        )
4361*da0073e9SAndroid Build Coastguard Worker        reduce_dim = 1  # ragged
4362*da0073e9SAndroid Build Coastguard Worker
4363*da0073e9SAndroid Build Coastguard Worker        for tensor_list in tensor_lists:
4364*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor(
4365*da0073e9SAndroid Build Coastguard Worker                tensor_list,
4366*da0073e9SAndroid Build Coastguard Worker                device=device,
4367*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
4368*da0073e9SAndroid Build Coastguard Worker                layout=torch.jagged,
4369*da0073e9SAndroid Build Coastguard Worker                requires_grad=requires_grad,
4370*da0073e9SAndroid Build Coastguard Worker            )
4371*da0073e9SAndroid Build Coastguard Worker
4372*da0073e9SAndroid Build Coastguard Worker            out_actual = torch.nn.functional.softmax(nt, dim=reduce_dim)
4373*da0073e9SAndroid Build Coastguard Worker            out_expected = torch.cat(
4374*da0073e9SAndroid Build Coastguard Worker                [
4375*da0073e9SAndroid Build Coastguard Worker                    torch.nn.functional.softmax(t, dim=reduce_dim - 1)
4376*da0073e9SAndroid Build Coastguard Worker                    for t in nt.unbind()
4377*da0073e9SAndroid Build Coastguard Worker                ]
4378*da0073e9SAndroid Build Coastguard Worker            )
4379*da0073e9SAndroid Build Coastguard Worker
4380*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(
4381*da0073e9SAndroid Build Coastguard Worker                out_actual.is_nested,
4382*da0073e9SAndroid Build Coastguard Worker                "softmax(): softmax along the ragged dimension should produce a nested tensor",
4383*da0073e9SAndroid Build Coastguard Worker            )  # output is a nested tensor
4384*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.allclose(out_actual.values(), out_expected))
4385*da0073e9SAndroid Build Coastguard Worker
4386*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
4387*da0073e9SAndroid Build Coastguard Worker    @parametrize("requires_grad", [False, True])
4388*da0073e9SAndroid Build Coastguard Worker    @parametrize("components_require_grad", [False, True])
4389*da0073e9SAndroid Build Coastguard Worker    def test_softmax_reduce_batch_dim(
4390*da0073e9SAndroid Build Coastguard Worker        self, device, dtype, requires_grad, components_require_grad
4391*da0073e9SAndroid Build Coastguard Worker    ):
4392*da0073e9SAndroid Build Coastguard Worker        """
4393*da0073e9SAndroid Build Coastguard Worker        Softmax on NestedTensor fails when trying to reduce across batch dimension.
4394*da0073e9SAndroid Build Coastguard Worker        """
4395*da0073e9SAndroid Build Coastguard Worker        tensor_lists = self._get_example_tensor_lists(
4396*da0073e9SAndroid Build Coastguard Worker            include_list_of_lists=False,
4397*da0073e9SAndroid Build Coastguard Worker            include_requires_grad=components_require_grad,
4398*da0073e9SAndroid Build Coastguard Worker            include_inner_dim_size_1=True,  # (B, *, 1)
4399*da0073e9SAndroid Build Coastguard Worker        )
4400*da0073e9SAndroid Build Coastguard Worker        reduce_dim = 0  # batch
4401*da0073e9SAndroid Build Coastguard Worker
4402*da0073e9SAndroid Build Coastguard Worker        for tensor_list in tensor_lists:
4403*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor(
4404*da0073e9SAndroid Build Coastguard Worker                tensor_list,
4405*da0073e9SAndroid Build Coastguard Worker                device=device,
4406*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
4407*da0073e9SAndroid Build Coastguard Worker                layout=torch.jagged,
4408*da0073e9SAndroid Build Coastguard Worker                requires_grad=requires_grad,
4409*da0073e9SAndroid Build Coastguard Worker            )
4410*da0073e9SAndroid Build Coastguard Worker
4411*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
4412*da0073e9SAndroid Build Coastguard Worker                RuntimeError,
4413*da0073e9SAndroid Build Coastguard Worker                "not supported when reducing across the batch dimension for NestedTensor",
4414*da0073e9SAndroid Build Coastguard Worker            ):
4415*da0073e9SAndroid Build Coastguard Worker                out = torch.nn.functional.softmax(nt, dim=reduce_dim)
4416*da0073e9SAndroid Build Coastguard Worker
4417*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
4418*da0073e9SAndroid Build Coastguard Worker    @parametrize("requires_grad", [False, True])
4419*da0073e9SAndroid Build Coastguard Worker    @parametrize("components_require_grad", [False, True])
4420*da0073e9SAndroid Build Coastguard Worker    def test_layer_norm_reduce_ragged_idx_1(
4421*da0073e9SAndroid Build Coastguard Worker        self, device, dtype, requires_grad, components_require_grad
4422*da0073e9SAndroid Build Coastguard Worker    ):
4423*da0073e9SAndroid Build Coastguard Worker        """
4424*da0073e9SAndroid Build Coastguard Worker        Layer normalization on NestedTensor passes when trying to normalize across ragged dimension, where ragged_idx == 1.
4425*da0073e9SAndroid Build Coastguard Worker        """
4426*da0073e9SAndroid Build Coastguard Worker
4427*da0073e9SAndroid Build Coastguard Worker        # requires_grad = False does not currently work with dynamo tests and throws this error:
4428*da0073e9SAndroid Build Coastguard Worker        #   AssertionError: SymInts must use SymNodeVariable.
4429*da0073e9SAndroid Build Coastguard Worker        #   If the underlying value is static, we will create a ConstantVariable and specialize.
4430*da0073e9SAndroid Build Coastguard Worker        if torch._dynamo.is_compiling() and not requires_grad:
4431*da0073e9SAndroid Build Coastguard Worker            return
4432*da0073e9SAndroid Build Coastguard Worker
4433*da0073e9SAndroid Build Coastguard Worker        tensor_lists = self._get_example_tensor_lists(
4434*da0073e9SAndroid Build Coastguard Worker            include_list_of_lists=False,
4435*da0073e9SAndroid Build Coastguard Worker            include_requires_grad=components_require_grad,
4436*da0073e9SAndroid Build Coastguard Worker            include_inner_dim_size_1=True,  # (B, *, 1)
4437*da0073e9SAndroid Build Coastguard Worker        )
4438*da0073e9SAndroid Build Coastguard Worker
4439*da0073e9SAndroid Build Coastguard Worker        for tensor_list in tensor_lists:
4440*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor(
4441*da0073e9SAndroid Build Coastguard Worker                tensor_list,
4442*da0073e9SAndroid Build Coastguard Worker                device=device,
4443*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
4444*da0073e9SAndroid Build Coastguard Worker                layout=torch.jagged,
4445*da0073e9SAndroid Build Coastguard Worker                requires_grad=requires_grad,
4446*da0073e9SAndroid Build Coastguard Worker            )
4447*da0073e9SAndroid Build Coastguard Worker
4448*da0073e9SAndroid Build Coastguard Worker            if (
4449*da0073e9SAndroid Build Coastguard Worker                nt.dim() >= 3
4450*da0073e9SAndroid Build Coastguard Worker            ):  # layer norm on NestedTensor requires 3 or more dimensions
4451*da0073e9SAndroid Build Coastguard Worker                normalized_shape = nt.shape[nt._ragged_idx :]
4452*da0073e9SAndroid Build Coastguard Worker
4453*da0073e9SAndroid Build Coastguard Worker                out_actual = torch.nn.functional.layer_norm(
4454*da0073e9SAndroid Build Coastguard Worker                    nt, normalized_shape=normalized_shape
4455*da0073e9SAndroid Build Coastguard Worker                )
4456*da0073e9SAndroid Build Coastguard Worker                out_expected = torch.cat(
4457*da0073e9SAndroid Build Coastguard Worker                    [
4458*da0073e9SAndroid Build Coastguard Worker                        torch.nn.functional.layer_norm(t, normalized_shape=t.shape)
4459*da0073e9SAndroid Build Coastguard Worker                        for t in nt.unbind()
4460*da0073e9SAndroid Build Coastguard Worker                    ]
4461*da0073e9SAndroid Build Coastguard Worker                )  # e.g. in 3D tensor (B, *, M), performs layer normalization on B 2D tensors (*, M)
4462*da0073e9SAndroid Build Coastguard Worker
4463*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(
4464*da0073e9SAndroid Build Coastguard Worker                    out_actual.is_nested,
4465*da0073e9SAndroid Build Coastguard Worker                    "layer_norm(): normalizing along the ragged dimension should produce a nested tensor",
4466*da0073e9SAndroid Build Coastguard Worker                )  # output is a nested tensor
4467*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out_actual._values.shape, out_expected.shape)
4468*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch.allclose(out_actual.values(), out_expected))
4469*da0073e9SAndroid Build Coastguard Worker
4470*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
4471*da0073e9SAndroid Build Coastguard Worker    @parametrize("requires_grad", [False, True])
4472*da0073e9SAndroid Build Coastguard Worker    @parametrize("components_require_grad", [False, True])
4473*da0073e9SAndroid Build Coastguard Worker    def test_layer_norm_2d_input(
4474*da0073e9SAndroid Build Coastguard Worker        self,
4475*da0073e9SAndroid Build Coastguard Worker        device,
4476*da0073e9SAndroid Build Coastguard Worker        dtype,
4477*da0073e9SAndroid Build Coastguard Worker        requires_grad,
4478*da0073e9SAndroid Build Coastguard Worker        components_require_grad,
4479*da0073e9SAndroid Build Coastguard Worker    ):
4480*da0073e9SAndroid Build Coastguard Worker        """
4481*da0073e9SAndroid Build Coastguard Worker        Layer normalization on NestedTensor fails when trying to operate on a 2-dimensional tensor.
4482*da0073e9SAndroid Build Coastguard Worker        """
4483*da0073e9SAndroid Build Coastguard Worker        tensor_lists = self._get_example_tensor_lists(
4484*da0073e9SAndroid Build Coastguard Worker            include_list_of_lists=False,
4485*da0073e9SAndroid Build Coastguard Worker            include_requires_grad=components_require_grad,
4486*da0073e9SAndroid Build Coastguard Worker            include_inner_dim_size_1=True,  # (B, *, 1)
4487*da0073e9SAndroid Build Coastguard Worker            include_2d_tensor=True,  # (B, *)
4488*da0073e9SAndroid Build Coastguard Worker        )
4489*da0073e9SAndroid Build Coastguard Worker
4490*da0073e9SAndroid Build Coastguard Worker        for tensor_list in tensor_lists:
4491*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor(
4492*da0073e9SAndroid Build Coastguard Worker                tensor_list,
4493*da0073e9SAndroid Build Coastguard Worker                device=device,
4494*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
4495*da0073e9SAndroid Build Coastguard Worker                layout=torch.jagged,
4496*da0073e9SAndroid Build Coastguard Worker                requires_grad=requires_grad,
4497*da0073e9SAndroid Build Coastguard Worker            )
4498*da0073e9SAndroid Build Coastguard Worker
4499*da0073e9SAndroid Build Coastguard Worker            if nt.dim() <= 2:
4500*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(
4501*da0073e9SAndroid Build Coastguard Worker                    RuntimeError,
4502*da0073e9SAndroid Build Coastguard Worker                    "not supported for NestedTensor objects with 2 or fewer dimensions",
4503*da0073e9SAndroid Build Coastguard Worker                ):
4504*da0073e9SAndroid Build Coastguard Worker                    out = torch.nn.functional.layer_norm(
4505*da0073e9SAndroid Build Coastguard Worker                        nt, normalized_shape=(nt.shape[nt._ragged_idx],)
4506*da0073e9SAndroid Build Coastguard Worker                    )
4507*da0073e9SAndroid Build Coastguard Worker
4508*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
4509*da0073e9SAndroid Build Coastguard Worker    @parametrize("requires_grad", [False, True])
4510*da0073e9SAndroid Build Coastguard Worker    @parametrize("components_require_grad", [False, True])
4511*da0073e9SAndroid Build Coastguard Worker    def test_layer_norm_operate_on_batch_dim(
4512*da0073e9SAndroid Build Coastguard Worker        self,
4513*da0073e9SAndroid Build Coastguard Worker        device,
4514*da0073e9SAndroid Build Coastguard Worker        dtype,
4515*da0073e9SAndroid Build Coastguard Worker        requires_grad,
4516*da0073e9SAndroid Build Coastguard Worker        components_require_grad,
4517*da0073e9SAndroid Build Coastguard Worker    ):
4518*da0073e9SAndroid Build Coastguard Worker        """
4519*da0073e9SAndroid Build Coastguard Worker        Layer normalization on NestedTensor fails when trying to operate on the batch dimension.
4520*da0073e9SAndroid Build Coastguard Worker        """
4521*da0073e9SAndroid Build Coastguard Worker        tensor_lists = self._get_example_tensor_lists(
4522*da0073e9SAndroid Build Coastguard Worker            include_list_of_lists=False,
4523*da0073e9SAndroid Build Coastguard Worker            include_requires_grad=components_require_grad,
4524*da0073e9SAndroid Build Coastguard Worker            include_inner_dim_size_1=True,  # (B, *, 1)
4525*da0073e9SAndroid Build Coastguard Worker            include_2d_tensor=True,  # (B, *)
4526*da0073e9SAndroid Build Coastguard Worker        )
4527*da0073e9SAndroid Build Coastguard Worker
4528*da0073e9SAndroid Build Coastguard Worker        for tensor_list in tensor_lists:
4529*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor(
4530*da0073e9SAndroid Build Coastguard Worker                tensor_list,
4531*da0073e9SAndroid Build Coastguard Worker                device=device,
4532*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
4533*da0073e9SAndroid Build Coastguard Worker                layout=torch.jagged,
4534*da0073e9SAndroid Build Coastguard Worker                requires_grad=requires_grad,
4535*da0073e9SAndroid Build Coastguard Worker            )
4536*da0073e9SAndroid Build Coastguard Worker
4537*da0073e9SAndroid Build Coastguard Worker            if nt.dim() > 2:  # layer norm is not supported for 2D NestedTensors, so only test higher-dim inputs
4538*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(
4539*da0073e9SAndroid Build Coastguard Worker                    RuntimeError,
4540*da0073e9SAndroid Build Coastguard Worker                    "not supported when normalizing over the batch dimension for NestedTensor",
4541*da0073e9SAndroid Build Coastguard Worker                ):
4542*da0073e9SAndroid Build Coastguard Worker                    out = torch.nn.functional.layer_norm(nt, normalized_shape=nt.shape)
4543*da0073e9SAndroid Build Coastguard Worker
4544*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
4545*da0073e9SAndroid Build Coastguard Worker    @parametrize(
4546*da0073e9SAndroid Build Coastguard Worker        "func",
4547*da0073e9SAndroid Build Coastguard Worker        [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
4548*da0073e9SAndroid Build Coastguard Worker        name_fn=get_op_name,
4549*da0073e9SAndroid Build Coastguard Worker    )
4550*da0073e9SAndroid Build Coastguard Worker    @parametrize(
4551*da0073e9SAndroid Build Coastguard Worker        "transpose_offset", [1, 2]
4552*da0073e9SAndroid Build Coastguard Worker    )  # [transpose consecutive dimensions, transpose nonconsecutive dimensions]
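    # e.g. with ragged_idx == 1, transpose_offset=1 swaps dims 1 and 2 (consecutive) and
    # transpose_offset=2 swaps dims 1 and 3 (nonconsecutive)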
4553*da0073e9SAndroid Build Coastguard Worker    @parametrize("keepdim", [False, True])
4554*da0073e9SAndroid Build Coastguard Worker    @parametrize("requires_grad", [False, True])
4555*da0073e9SAndroid Build Coastguard Worker    @parametrize("components_require_grad", [False, True])
4556*da0073e9SAndroid Build Coastguard Worker    def test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape(
4557*da0073e9SAndroid Build Coastguard Worker        self,
4558*da0073e9SAndroid Build Coastguard Worker        device,
4559*da0073e9SAndroid Build Coastguard Worker        dtype,
4560*da0073e9SAndroid Build Coastguard Worker        keepdim,
4561*da0073e9SAndroid Build Coastguard Worker        requires_grad,
4562*da0073e9SAndroid Build Coastguard Worker        components_require_grad,
4563*da0073e9SAndroid Build Coastguard Worker        func,
4564*da0073e9SAndroid Build Coastguard Worker        transpose_offset,
4565*da0073e9SAndroid Build Coastguard Worker    ):
4566*da0073e9SAndroid Build Coastguard Worker        """
4567*da0073e9SAndroid Build Coastguard Worker        Operator on NestedTensor passes when trying to reduce across a transposed ragged dimension, i.e. ragged_idx > 1
4568*da0073e9SAndroid Build Coastguard Worker        This test is for operators which return an output tensor with a shape different from the input tensor.
4569*da0073e9SAndroid Build Coastguard Worker        """
4570*da0073e9SAndroid Build Coastguard Worker        if get_op_name(func) == "mean" and not keepdim:
4571*da0073e9SAndroid Build Coastguard Worker            return
4572*da0073e9SAndroid Build Coastguard Worker
4573*da0073e9SAndroid Build Coastguard Worker        op_name = get_op_name(func)
4574*da0073e9SAndroid Build Coastguard Worker
4575*da0073e9SAndroid Build Coastguard Worker        tensor_lists = self._get_example_tensor_lists(
4576*da0073e9SAndroid Build Coastguard Worker            include_list_of_lists=False,
4577*da0073e9SAndroid Build Coastguard Worker            include_requires_grad=components_require_grad,
4578*da0073e9SAndroid Build Coastguard Worker            include_inner_dim_size_1=True,  # (B, *, 1)
4579*da0073e9SAndroid Build Coastguard Worker            include_2d_tensor=True,  # (B, *)
4580*da0073e9SAndroid Build Coastguard Worker        )
4581*da0073e9SAndroid Build Coastguard Worker
4582*da0073e9SAndroid Build Coastguard Worker        for tensor_list in tensor_lists:
4583*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor(
4584*da0073e9SAndroid Build Coastguard Worker                tensor_list,
4585*da0073e9SAndroid Build Coastguard Worker                device=device,
4586*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
4587*da0073e9SAndroid Build Coastguard Worker                layout=torch.jagged,
4588*da0073e9SAndroid Build Coastguard Worker                requires_grad=requires_grad,
4589*da0073e9SAndroid Build Coastguard Worker            )
4590*da0073e9SAndroid Build Coastguard Worker
4591*da0073e9SAndroid Build Coastguard Worker            if nt.dim() > nt._ragged_idx + transpose_offset:
4592*da0073e9SAndroid Build Coastguard Worker                nt_transposed = nt.transpose(
4593*da0073e9SAndroid Build Coastguard Worker                    nt._ragged_idx, nt._ragged_idx + transpose_offset
4594*da0073e9SAndroid Build Coastguard Worker                )
4595*da0073e9SAndroid Build Coastguard Worker                reduce_dim = (nt_transposed._ragged_idx,)  # ragged
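                # after the transpose, _ragged_idx points at the ragged dim's new position,
                # so this reduction is still taken across the ragged dimension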
4596*da0073e9SAndroid Build Coastguard Worker
4597*da0073e9SAndroid Build Coastguard Worker                out_actual = func(nt_transposed, dim=reduce_dim, keepdim=keepdim)
4598*da0073e9SAndroid Build Coastguard Worker                out_expected = torch.cat(
4599*da0073e9SAndroid Build Coastguard Worker                    [
4600*da0073e9SAndroid Build Coastguard Worker                        func(t, dim=(reduce_dim[0] - 1)).unsqueeze(0)
4601*da0073e9SAndroid Build Coastguard Worker                        for t in nt_transposed.unbind()
4602*da0073e9SAndroid Build Coastguard Worker                    ]
4603*da0073e9SAndroid Build Coastguard Worker                )
4604*da0073e9SAndroid Build Coastguard Worker
4605*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(
4606*da0073e9SAndroid Build Coastguard Worker                    out_actual.is_nested,
4607*da0073e9SAndroid Build Coastguard Worker                    f"{op_name}(): reducing a nested tensor along the ragged dimension should produce a dense tensor",
4608*da0073e9SAndroid Build Coastguard Worker                )  # output is a dense tensor
4609*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch.allclose(out_actual, out_expected, rtol=1e-4))
4610*da0073e9SAndroid Build Coastguard Worker
4611*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
4612*da0073e9SAndroid Build Coastguard Worker    @parametrize(
4613*da0073e9SAndroid Build Coastguard Worker        "transpose_offset", [1, 2]
4614*da0073e9SAndroid Build Coastguard Worker    )  # [transpose consecutive dimensions, transpose nonconsecutive dimensions]
4615*da0073e9SAndroid Build Coastguard Worker    @parametrize("requires_grad", [False, True])
4616*da0073e9SAndroid Build Coastguard Worker    @parametrize("components_require_grad", [False, True])
4617*da0073e9SAndroid Build Coastguard Worker    def test_softmax_dim_reduce_ragged_idx_greater_than_1_same_output_shape(
4618*da0073e9SAndroid Build Coastguard Worker        self,
4619*da0073e9SAndroid Build Coastguard Worker        device,
4620*da0073e9SAndroid Build Coastguard Worker        dtype,
4621*da0073e9SAndroid Build Coastguard Worker        requires_grad,
4622*da0073e9SAndroid Build Coastguard Worker        components_require_grad,
4623*da0073e9SAndroid Build Coastguard Worker        transpose_offset,
4624*da0073e9SAndroid Build Coastguard Worker    ):
4625*da0073e9SAndroid Build Coastguard Worker        """
4626*da0073e9SAndroid Build Coastguard Worker        Softmax on NestedTensor fails when trying to reduce across a transposed ragged dimension, i.e. ragged_idx > 1
4627*da0073e9SAndroid Build Coastguard Worker        This test is for operators which return an output tensor with the same shape as the input tensor.
4628*da0073e9SAndroid Build Coastguard Worker        """
4629*da0073e9SAndroid Build Coastguard Worker        tensor_lists = self._get_example_tensor_lists(
4630*da0073e9SAndroid Build Coastguard Worker            include_list_of_lists=False,
4631*da0073e9SAndroid Build Coastguard Worker            include_requires_grad=components_require_grad,
4632*da0073e9SAndroid Build Coastguard Worker            include_inner_dim_size_1=True,  # (B, *, 1)
4633*da0073e9SAndroid Build Coastguard Worker        )
4634*da0073e9SAndroid Build Coastguard Worker
4635*da0073e9SAndroid Build Coastguard Worker        for tensor_list in tensor_lists:
4636*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor(
4637*da0073e9SAndroid Build Coastguard Worker                tensor_list,
4638*da0073e9SAndroid Build Coastguard Worker                device=device,
4639*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
4640*da0073e9SAndroid Build Coastguard Worker                layout=torch.jagged,
4641*da0073e9SAndroid Build Coastguard Worker                requires_grad=requires_grad,
4642*da0073e9SAndroid Build Coastguard Worker            )
4643*da0073e9SAndroid Build Coastguard Worker
4644*da0073e9SAndroid Build Coastguard Worker            if nt.dim() > nt._ragged_idx + transpose_offset:
4645*da0073e9SAndroid Build Coastguard Worker                nt_transposed = nt.transpose(
4646*da0073e9SAndroid Build Coastguard Worker                    nt._ragged_idx, nt._ragged_idx + transpose_offset
4647*da0073e9SAndroid Build Coastguard Worker                )
4648*da0073e9SAndroid Build Coastguard Worker                reduce_dim = nt_transposed._ragged_idx  # ragged
4649*da0073e9SAndroid Build Coastguard Worker
4650*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(
4651*da0073e9SAndroid Build Coastguard Worker                    RuntimeError,
4652*da0073e9SAndroid Build Coastguard Worker                    "not supported when reducing along the ragged dimension for ragged_idx > 1 for NestedTensor",
4653*da0073e9SAndroid Build Coastguard Worker                ):
4654*da0073e9SAndroid Build Coastguard Worker                    out = torch.nn.functional.softmax(nt_transposed, dim=reduce_dim)
4655*da0073e9SAndroid Build Coastguard Worker
4656*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
4657*da0073e9SAndroid Build Coastguard Worker    @parametrize(
4658*da0073e9SAndroid Build Coastguard Worker        "func",
4659*da0073e9SAndroid Build Coastguard Worker        [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
4660*da0073e9SAndroid Build Coastguard Worker        name_fn=get_op_name,
4661*da0073e9SAndroid Build Coastguard Worker    )
4662*da0073e9SAndroid Build Coastguard Worker    @parametrize("keepdim", [False, True])
4663*da0073e9SAndroid Build Coastguard Worker    @parametrize("requires_grad", [False, True])
4664*da0073e9SAndroid Build Coastguard Worker    @parametrize("components_require_grad", [False, True])
4665*da0073e9SAndroid Build Coastguard Worker    def test_op_dim_transpose_non_ragged_dim_different_output_shape(
4666*da0073e9SAndroid Build Coastguard Worker        self, device, dtype, keepdim, requires_grad, components_require_grad, func
4667*da0073e9SAndroid Build Coastguard Worker    ):
4668*da0073e9SAndroid Build Coastguard Worker        """
4669*da0073e9SAndroid Build Coastguard Worker        Operator passes when reducing transposed nested tensors on valid reduction dimensions.
4670*da0073e9SAndroid Build Coastguard Worker        This test is for operators which return an output tensor with a shape different from the input tensor.
4671*da0073e9SAndroid Build Coastguard Worker        """
4672*da0073e9SAndroid Build Coastguard Worker        if get_op_name(func) == "mean" and not keepdim:
4673*da0073e9SAndroid Build Coastguard Worker            return
4674*da0073e9SAndroid Build Coastguard Worker
4675*da0073e9SAndroid Build Coastguard Worker        # verify correctness of shapes (assuming that ragged_idx == 1)
4676*da0073e9SAndroid Build Coastguard Worker        if get_op_name(func) == "sum":
4677*da0073e9SAndroid Build Coastguard Worker            reduce_dims = (
4678*da0073e9SAndroid Build Coastguard Worker                ((0, 1), (3, 4), (1, 1, 3, 4), (0,)),  # batch, ragged
4679*da0073e9SAndroid Build Coastguard Worker                ((2, 3), (3, None), (3, None, 1, 1), (1, 2)),  # non-batch, non-batch
4680*da0073e9SAndroid Build Coastguard Worker                ((0, 1, 3), (3,), (1, 1, 3, 1), (0, 2)),  # batch, ragged, non-batch
4681*da0073e9SAndroid Build Coastguard Worker                ((0, 1, 2), (4,), (1, 1, 1, 4), (0, 1)),  # batch, ragged, non-batch
4682*da0073e9SAndroid Build Coastguard Worker                (
4683*da0073e9SAndroid Build Coastguard Worker                    (0, 1, 2, 3),
4684*da0073e9SAndroid Build Coastguard Worker                    (),
4685*da0073e9SAndroid Build Coastguard Worker                    (1, 1, 1, 1),
4686*da0073e9SAndroid Build Coastguard Worker                    (0, 1, 2),
4687*da0073e9SAndroid Build Coastguard Worker                ),  # batch, ragged, non-batch, non-batch
4688*da0073e9SAndroid Build Coastguard Worker                ((2,), (3, None, 4), (3, None, 1, 4), (1,)),  # non-batch
4689*da0073e9SAndroid Build Coastguard Worker            )  # (dims, expected shape, expected keepdim shape, reduce_dim_expected), where j0 is represented as None
4690*da0073e9SAndroid Build Coastguard Worker        elif get_op_name(func) == "mean":
4691*da0073e9SAndroid Build Coastguard Worker            reduce_dims = (
4692*da0073e9SAndroid Build Coastguard Worker                ((2,), (3, None, 4), (3, None, 1, 4), (1,)),
4693*da0073e9SAndroid Build Coastguard Worker                ((3,), (3, None, 3), (3, None, 3, 1), (2,)),
4694*da0073e9SAndroid Build Coastguard Worker            )
4695*da0073e9SAndroid Build Coastguard Worker
4696*da0073e9SAndroid Build Coastguard Worker        # verify correctness of values
4697*da0073e9SAndroid Build Coastguard Worker        tensor_lists = self._get_example_tensor_lists(
4698*da0073e9SAndroid Build Coastguard Worker            include_list_of_lists=False,
4699*da0073e9SAndroid Build Coastguard Worker            include_requires_grad=components_require_grad,
4700*da0073e9SAndroid Build Coastguard Worker        )
4701*da0073e9SAndroid Build Coastguard Worker        for tensor_list, reduce_dim_tuple in itertools.product(
4702*da0073e9SAndroid Build Coastguard Worker            tensor_lists, reduce_dims
4703*da0073e9SAndroid Build Coastguard Worker        ):
4704*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor(
4705*da0073e9SAndroid Build Coastguard Worker                tensor_list,
4706*da0073e9SAndroid Build Coastguard Worker                device=device,
4707*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
4708*da0073e9SAndroid Build Coastguard Worker                layout=torch.jagged,
4709*da0073e9SAndroid Build Coastguard Worker                requires_grad=requires_grad,
4710*da0073e9SAndroid Build Coastguard Worker            ).transpose(-1, -2)
4711*da0073e9SAndroid Build Coastguard Worker
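            # only the reduce dims and their equivalents on nt.values() are used below;
            # the expected-shape entries above document the intended output shapes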
4712*da0073e9SAndroid Build Coastguard Worker            reduce_dim, _, _, reduce_dim_expected = reduce_dim_tuple
4713*da0073e9SAndroid Build Coastguard Worker
4714*da0073e9SAndroid Build Coastguard Worker            if nt.dim() > max(
4715*da0073e9SAndroid Build Coastguard Worker                reduce_dim[-1], nt._ragged_idx + 2
4716*da0073e9SAndroid Build Coastguard Worker            ):  # ensure that transposed dimensions are non-batch, non-ragged dimensions
4717*da0073e9SAndroid Build Coastguard Worker                out_actual = func(nt, dim=reduce_dim, keepdim=keepdim)
4718*da0073e9SAndroid Build Coastguard Worker                if nt._ragged_idx in reduce_dim:  # raggedness reduced away
4719*da0073e9SAndroid Build Coastguard Worker                    out_expected = func(
4720*da0073e9SAndroid Build Coastguard Worker                        nt.values(), dim=reduce_dim_expected, keepdim=keepdim
4721*da0073e9SAndroid Build Coastguard Worker                    )
4722*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(torch.allclose(out_actual, out_expected))
4723*da0073e9SAndroid Build Coastguard Worker                else:  # raggedness preserved
4724*da0073e9SAndroid Build Coastguard Worker                    out_expected = func(nt.values(), dim=reduce_dim_expected)
4725*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(
4726*da0073e9SAndroid Build Coastguard Worker                        torch.allclose(
4727*da0073e9SAndroid Build Coastguard Worker                            out_actual.values().view(-1), out_expected.view(-1)
4728*da0073e9SAndroid Build Coastguard Worker                        )
4729*da0073e9SAndroid Build Coastguard Worker                    )
4730*da0073e9SAndroid Build Coastguard Worker
4731*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
4732*da0073e9SAndroid Build Coastguard Worker    @parametrize("requires_grad", [False, True])
4733*da0073e9SAndroid Build Coastguard Worker    @parametrize("components_require_grad", [False, True])
4734*da0073e9SAndroid Build Coastguard Worker    def test_softmax_dim_transpose_non_ragged_dim(
4735*da0073e9SAndroid Build Coastguard Worker        self,
4736*da0073e9SAndroid Build Coastguard Worker        device,
4737*da0073e9SAndroid Build Coastguard Worker        dtype,
4738*da0073e9SAndroid Build Coastguard Worker        requires_grad,
4739*da0073e9SAndroid Build Coastguard Worker        components_require_grad,
4740*da0073e9SAndroid Build Coastguard Worker    ):
4741*da0073e9SAndroid Build Coastguard Worker        """
4742*da0073e9SAndroid Build Coastguard Worker        Softmax passes when reducing transposed nested tensors on valid reduction dimensions.
4743*da0073e9SAndroid Build Coastguard Worker        This test is for operators which return an output tensor with the same shape as the input tensor.
4744*da0073e9SAndroid Build Coastguard Worker        """
4745*da0073e9SAndroid Build Coastguard Worker        # verify correctness of shapes (assuming that ragged_idx == 1)
4746*da0073e9SAndroid Build Coastguard Worker        reduce_dims = (
4747*da0073e9SAndroid Build Coastguard Worker            (2, 1),
4748*da0073e9SAndroid Build Coastguard Worker            (3, 2),
4749*da0073e9SAndroid Build Coastguard Worker        )  # (reduction dimension, effective reduction dimension for baseline)
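        # e.g. dim=2 on the nested tensor maps to dim=1 on nt.values(), since values()
        # packs the batch and ragged dims into a single leading dimension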
4750*da0073e9SAndroid Build Coastguard Worker
4751*da0073e9SAndroid Build Coastguard Worker        # verify correctness of values
4752*da0073e9SAndroid Build Coastguard Worker        tensor_lists = self._get_example_tensor_lists(
4753*da0073e9SAndroid Build Coastguard Worker            include_list_of_lists=False,
4754*da0073e9SAndroid Build Coastguard Worker            include_requires_grad=components_require_grad,
4755*da0073e9SAndroid Build Coastguard Worker            include_inner_dim_size_1=True,  # (B, *, 1)
4756*da0073e9SAndroid Build Coastguard Worker        )
4757*da0073e9SAndroid Build Coastguard Worker        for tensor_list, reduce_dim_tuple in itertools.product(
4758*da0073e9SAndroid Build Coastguard Worker            tensor_lists, reduce_dims
4759*da0073e9SAndroid Build Coastguard Worker        ):
4760*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor(
4761*da0073e9SAndroid Build Coastguard Worker                tensor_list,
4762*da0073e9SAndroid Build Coastguard Worker                device=device,
4763*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
4764*da0073e9SAndroid Build Coastguard Worker                layout=torch.jagged,
4765*da0073e9SAndroid Build Coastguard Worker                requires_grad=requires_grad,
4766*da0073e9SAndroid Build Coastguard Worker            ).transpose(-1, -2)
4767*da0073e9SAndroid Build Coastguard Worker
4768*da0073e9SAndroid Build Coastguard Worker            reduce_dim, reduce_dim_expected = reduce_dim_tuple
4769*da0073e9SAndroid Build Coastguard Worker
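            # ensure that the transposed dimensions are non-batch, non-ragged dimensions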
4770*da0073e9SAndroid Build Coastguard Worker            if nt.dim() > max(reduce_dim, nt._ragged_idx + 2):
4771*da0073e9SAndroid Build Coastguard Worker                out_actual = torch.nn.functional.softmax(
4772*da0073e9SAndroid Build Coastguard Worker                    nt, dim=reduce_dim
4773*da0073e9SAndroid Build Coastguard Worker                )  # nested tensor
4774*da0073e9SAndroid Build Coastguard Worker                out_expected = torch.nn.functional.softmax(
4775*da0073e9SAndroid Build Coastguard Worker                    nt.values(), dim=reduce_dim_expected
4776*da0073e9SAndroid Build Coastguard Worker                )  # dense tensor with one fewer dimension than out_actual
4777*da0073e9SAndroid Build Coastguard Worker
4778*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(
4779*da0073e9SAndroid Build Coastguard Worker                    torch.allclose(out_actual.values().view(-1), out_expected.view(-1))
4780*da0073e9SAndroid Build Coastguard Worker                )
4781*da0073e9SAndroid Build Coastguard Worker
4782*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
4783*da0073e9SAndroid Build Coastguard Worker    @parametrize("keepdim", [False, True])
4784*da0073e9SAndroid Build Coastguard Worker    @parametrize("requires_grad", [False, True])
4785*da0073e9SAndroid Build Coastguard Worker    @parametrize("components_require_grad", [False, True])
4786*da0073e9SAndroid Build Coastguard Worker    def test_sum_dim_reduce_ragged_and_non_batch(
4787*da0073e9SAndroid Build Coastguard Worker        self,
4788*da0073e9SAndroid Build Coastguard Worker        device,
4789*da0073e9SAndroid Build Coastguard Worker        dtype,
4790*da0073e9SAndroid Build Coastguard Worker        keepdim,
4791*da0073e9SAndroid Build Coastguard Worker        requires_grad,
4792*da0073e9SAndroid Build Coastguard Worker        components_require_grad,
4793*da0073e9SAndroid Build Coastguard Worker    ):
4794*da0073e9SAndroid Build Coastguard Worker        """
4795*da0073e9SAndroid Build Coastguard Worker        Sum on NestedTensor fails when trying to reduce across ragged and non-batch dimensions
4796*da0073e9SAndroid Build Coastguard Worker        """
4797*da0073e9SAndroid Build Coastguard Worker        tensor_lists = self._get_example_tensor_lists(
4798*da0073e9SAndroid Build Coastguard Worker            include_list_of_lists=False, include_requires_grad=components_require_grad
4799*da0073e9SAndroid Build Coastguard Worker        )
4800*da0073e9SAndroid Build Coastguard Worker        reduce_dims = (
4801*da0073e9SAndroid Build Coastguard Worker            (1, 2),  # ragged, non-batch
4802*da0073e9SAndroid Build Coastguard Worker            (1, 3),  # ragged, non-batch
4803*da0073e9SAndroid Build Coastguard Worker        )
4804*da0073e9SAndroid Build Coastguard Worker
4805*da0073e9SAndroid Build Coastguard Worker        for tensor_list, reduce_dim in itertools.product(tensor_lists, reduce_dims):
4806*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor(
4807*da0073e9SAndroid Build Coastguard Worker                tensor_list,
4808*da0073e9SAndroid Build Coastguard Worker                device=device,
4809*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
4810*da0073e9SAndroid Build Coastguard Worker                layout=torch.jagged,
4811*da0073e9SAndroid Build Coastguard Worker                requires_grad=requires_grad,
4812*da0073e9SAndroid Build Coastguard Worker            )
4813*da0073e9SAndroid Build Coastguard Worker
4814*da0073e9SAndroid Build Coastguard Worker            if nt.dim() > reduce_dim[-1]:
4815*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(
4816*da0073e9SAndroid Build Coastguard Worker                    RuntimeError,
4817*da0073e9SAndroid Build Coastguard Worker                    "not supported along a ragged and non-batch dimension for NestedTensor",
4818*da0073e9SAndroid Build Coastguard Worker                ):
4819*da0073e9SAndroid Build Coastguard Worker                    out = torch.sum(nt, dim=reduce_dim, keepdim=keepdim)
4820*da0073e9SAndroid Build Coastguard Worker
4821*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
4822*da0073e9SAndroid Build Coastguard Worker    @parametrize("keepdim", [False, True])
4823*da0073e9SAndroid Build Coastguard Worker    @parametrize("requires_grad", [False, True])
4824*da0073e9SAndroid Build Coastguard Worker    @parametrize("components_require_grad", [False, True])
4825*da0073e9SAndroid Build Coastguard Worker    def test_sum_dim_reduce_batch_and_non_batch(
4826*da0073e9SAndroid Build Coastguard Worker        self,
4827*da0073e9SAndroid Build Coastguard Worker        device,
4828*da0073e9SAndroid Build Coastguard Worker        dtype,
4829*da0073e9SAndroid Build Coastguard Worker        keepdim,
4830*da0073e9SAndroid Build Coastguard Worker        requires_grad,
4831*da0073e9SAndroid Build Coastguard Worker        components_require_grad,
4832*da0073e9SAndroid Build Coastguard Worker    ):
4833*da0073e9SAndroid Build Coastguard Worker        """
4834*da0073e9SAndroid Build Coastguard Worker        Sum on NestedTensor fails when trying to reduce across batch and non-batch dimensions
4835*da0073e9SAndroid Build Coastguard Worker        """
4836*da0073e9SAndroid Build Coastguard Worker        tensor_lists = self._get_example_tensor_lists(
4837*da0073e9SAndroid Build Coastguard Worker            include_list_of_lists=False, include_requires_grad=components_require_grad
4838*da0073e9SAndroid Build Coastguard Worker        )
4839*da0073e9SAndroid Build Coastguard Worker        reduce_dims = (
4840*da0073e9SAndroid Build Coastguard Worker            (0, 2),  # batch, non-batch
4841*da0073e9SAndroid Build Coastguard Worker            (0, 3),  # batch, non-batch
4842*da0073e9SAndroid Build Coastguard Worker        )
4843*da0073e9SAndroid Build Coastguard Worker
4844*da0073e9SAndroid Build Coastguard Worker        for tensor_list, reduce_dim in itertools.product(tensor_lists, reduce_dims):
4845*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor(
4846*da0073e9SAndroid Build Coastguard Worker                tensor_list,
4847*da0073e9SAndroid Build Coastguard Worker                device=device,
4848*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
4849*da0073e9SAndroid Build Coastguard Worker                layout=torch.jagged,
4850*da0073e9SAndroid Build Coastguard Worker                requires_grad=requires_grad,
4851*da0073e9SAndroid Build Coastguard Worker            )
4852*da0073e9SAndroid Build Coastguard Worker
4853*da0073e9SAndroid Build Coastguard Worker            if nt.dim() > reduce_dim[-1]:
4854*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(
4855*da0073e9SAndroid Build Coastguard Worker                    RuntimeError,
4856*da0073e9SAndroid Build Coastguard Worker                    "not supported along the batch dimension but not the ragged dimension for NestedTensor",
4857*da0073e9SAndroid Build Coastguard Worker                ):
4858*da0073e9SAndroid Build Coastguard Worker                    out = torch.sum(nt, dim=reduce_dim, keepdim=keepdim)
4859*da0073e9SAndroid Build Coastguard Worker
4860*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
4861*da0073e9SAndroid Build Coastguard Worker    @parametrize(
4862*da0073e9SAndroid Build Coastguard Worker        "func",
4863*da0073e9SAndroid Build Coastguard Worker        [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
4864*da0073e9SAndroid Build Coastguard Worker        name_fn=get_op_name,
4865*da0073e9SAndroid Build Coastguard Worker    )
4866*da0073e9SAndroid Build Coastguard Worker    @parametrize("keepdim", [False, True])
4867*da0073e9SAndroid Build Coastguard Worker    @parametrize("requires_grad", [False, True])
4868*da0073e9SAndroid Build Coastguard Worker    @parametrize("components_require_grad", [False, True])
4869*da0073e9SAndroid Build Coastguard Worker    def test_op_dim_reduce_batch_only_different_output_shape(
4870*da0073e9SAndroid Build Coastguard Worker        self, device, dtype, keepdim, requires_grad, components_require_grad, func
4871*da0073e9SAndroid Build Coastguard Worker    ):
4872*da0073e9SAndroid Build Coastguard Worker        """
4873*da0073e9SAndroid Build Coastguard Worker        Operator on NestedTensor fails when trying to reduce across only the batch dimension.
4874*da0073e9SAndroid Build Coastguard Worker        """
4875*da0073e9SAndroid Build Coastguard Worker        if get_op_name(func) == "mean" and not keepdim:
4876*da0073e9SAndroid Build Coastguard Worker            return
4877*da0073e9SAndroid Build Coastguard Worker
4878*da0073e9SAndroid Build Coastguard Worker        tensor_lists = self._get_example_tensor_lists(
4879*da0073e9SAndroid Build Coastguard Worker            include_list_of_lists=False, include_requires_grad=components_require_grad
4880*da0073e9SAndroid Build Coastguard Worker        )
4881*da0073e9SAndroid Build Coastguard Worker        reduce_dim = (0,)  # batch
4882*da0073e9SAndroid Build Coastguard Worker
4883*da0073e9SAndroid Build Coastguard Worker        for tensor_list in tensor_lists:
4884*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor(
4885*da0073e9SAndroid Build Coastguard Worker                tensor_list,
4886*da0073e9SAndroid Build Coastguard Worker                device=device,
4887*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
4888*da0073e9SAndroid Build Coastguard Worker                layout=torch.jagged,
4889*da0073e9SAndroid Build Coastguard Worker                requires_grad=requires_grad,
4890*da0073e9SAndroid Build Coastguard Worker            )
4891*da0073e9SAndroid Build Coastguard Worker
4892*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
4893*da0073e9SAndroid Build Coastguard Worker                RuntimeError,
4894*da0073e9SAndroid Build Coastguard Worker                "not supported along the batch dimension but not the ragged dimension for NestedTensor",
4895*da0073e9SAndroid Build Coastguard Worker            ):
4896*da0073e9SAndroid Build Coastguard Worker                out = func(nt, dim=reduce_dim, keepdim=keepdim)
4897*da0073e9SAndroid Build Coastguard Worker
4898*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
4899*da0073e9SAndroid Build Coastguard Worker    @parametrize(
4900*da0073e9SAndroid Build Coastguard Worker        "func",
4901*da0073e9SAndroid Build Coastguard Worker        [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
4902*da0073e9SAndroid Build Coastguard Worker        name_fn=get_op_name,
4903*da0073e9SAndroid Build Coastguard Worker    )
4904*da0073e9SAndroid Build Coastguard Worker    @parametrize("keepdim", [False, True])
4905*da0073e9SAndroid Build Coastguard Worker    @parametrize("requires_grad", [False, True])
4906*da0073e9SAndroid Build Coastguard Worker    @parametrize("components_require_grad", [False, True])
4907*da0073e9SAndroid Build Coastguard Worker    def test_op_dim_with_lengths_different_output_shape(
4908*da0073e9SAndroid Build Coastguard Worker        self,
4909*da0073e9SAndroid Build Coastguard Worker        device,
4910*da0073e9SAndroid Build Coastguard Worker        dtype,
4911*da0073e9SAndroid Build Coastguard Worker        keepdim,
4912*da0073e9SAndroid Build Coastguard Worker        requires_grad,
4913*da0073e9SAndroid Build Coastguard Worker        components_require_grad,
4914*da0073e9SAndroid Build Coastguard Worker        func,
4915*da0073e9SAndroid Build Coastguard Worker    ):
4916*da0073e9SAndroid Build Coastguard Worker        """
4917*da0073e9SAndroid Build Coastguard Worker        Operator on NestedTensor fails when trying to reduce a nested tensor with lengths,
4918*da0073e9SAndroid Build Coastguard Worker        i.e. a nested tensor with holes, if reducing on the ragged dimension.
4919*da0073e9SAndroid Build Coastguard Worker        This test is for operators which return an output tensor with a shape different from that of the input tensor.
4920*da0073e9SAndroid Build Coastguard Worker        """
4921*da0073e9SAndroid Build Coastguard Worker        if get_op_name(func) == "mean" and not keepdim:
4922*da0073e9SAndroid Build Coastguard Worker            return
4923*da0073e9SAndroid Build Coastguard Worker
4924*da0073e9SAndroid Build Coastguard Worker        reduce_dims = ((1,), (2,), (2, 3))
4925*da0073e9SAndroid Build Coastguard Worker
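        # build jagged metadata: offsets is the running sum of lengths with a leading 0,
        # e.g. lengths = [5, 7, 6] -> offsets = [0, 5, 12, 18]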
4926*da0073e9SAndroid Build Coastguard Worker        lengths = torch.randint(5, 10, (20,), device=device)
4927*da0073e9SAndroid Build Coastguard Worker        offsets = torch.zeros((21,), device=device, dtype=torch.int)
4928*da0073e9SAndroid Build Coastguard Worker        torch.cumsum(lengths, dim=0, out=offsets[1:])
4929*da0073e9SAndroid Build Coastguard Worker
4930*da0073e9SAndroid Build Coastguard Worker        values = torch.randn(
4931*da0073e9SAndroid Build Coastguard Worker            (offsets[-1].item(), 20),
4932*da0073e9SAndroid Build Coastguard Worker            device=device,
4933*da0073e9SAndroid Build Coastguard Worker            dtype=dtype,
4934*da0073e9SAndroid Build Coastguard Worker            requires_grad=requires_grad,
4935*da0073e9SAndroid Build Coastguard Worker        )
4936*da0073e9SAndroid Build Coastguard Worker
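        # lengths shorter than offsets.diff() produce a nested tensor with holes: each
        # batch element uses only the first `length` rows of its segment in values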
4937*da0073e9SAndroid Build Coastguard Worker        nt_with_holes = torch.nested.nested_tensor_from_jagged(
4938*da0073e9SAndroid Build Coastguard Worker            values,
4939*da0073e9SAndroid Build Coastguard Worker            offsets,
4940*da0073e9SAndroid Build Coastguard Worker            lengths=offsets.diff() - 2,  # arbitrary subtraction to create holes
4941*da0073e9SAndroid Build Coastguard Worker        )
4942*da0073e9SAndroid Build Coastguard Worker
4943*da0073e9SAndroid Build Coastguard Worker        for reduce_dim in reduce_dims:
4944*da0073e9SAndroid Build Coastguard Worker            if nt_with_holes.dim() > reduce_dim[-1]:
4945*da0073e9SAndroid Build Coastguard Worker                if nt_with_holes._ragged_idx in reduce_dim:
4946*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(
4947*da0073e9SAndroid Build Coastguard Worker                        RuntimeError,
4948*da0073e9SAndroid Build Coastguard Worker                        "not supported where lengths is not None "
4949*da0073e9SAndroid Build Coastguard Worker                        + "if reducing across the ragged dimension for NestedTensor",
4950*da0073e9SAndroid Build Coastguard Worker                    ):
4951*da0073e9SAndroid Build Coastguard Worker                        out = func(nt_with_holes, dim=reduce_dim, keepdim=keepdim)
4952*da0073e9SAndroid Build Coastguard Worker                else:
4953*da0073e9SAndroid Build Coastguard Worker                    out = func(nt_with_holes, dim=reduce_dim, keepdim=keepdim)
4954*da0073e9SAndroid Build Coastguard Worker
4955*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
4956*da0073e9SAndroid Build Coastguard Worker    @parametrize("requires_grad", [False, True])
4957*da0073e9SAndroid Build Coastguard Worker    @parametrize("components_require_grad", [False, True])
4958*da0073e9SAndroid Build Coastguard Worker    def test_softmax_dim_with_lengths(
4959*da0073e9SAndroid Build Coastguard Worker        self,
4960*da0073e9SAndroid Build Coastguard Worker        device,
4961*da0073e9SAndroid Build Coastguard Worker        dtype,
4962*da0073e9SAndroid Build Coastguard Worker        requires_grad,
4963*da0073e9SAndroid Build Coastguard Worker        components_require_grad,
4964*da0073e9SAndroid Build Coastguard Worker    ):
4965*da0073e9SAndroid Build Coastguard Worker        """
4966*da0073e9SAndroid Build Coastguard Worker        Softmax on NestedTensor fails when trying to reduce a nested tensor with lengths,
4967*da0073e9SAndroid Build Coastguard Worker        i.e. a nested tensor with holes, if reducing on the ragged dimension.
4968*da0073e9SAndroid Build Coastguard Worker        """
4969*da0073e9SAndroid Build Coastguard Worker        reduce_dims = (1, 2, 3)
4970*da0073e9SAndroid Build Coastguard Worker
4971*da0073e9SAndroid Build Coastguard Worker        lengths = torch.randint(5, 10, (20,), device=device)
4972*da0073e9SAndroid Build Coastguard Worker        offsets = torch.zeros((21,), device=device, dtype=torch.int)
4973*da0073e9SAndroid Build Coastguard Worker        torch.cumsum(lengths, dim=0, out=offsets[1:])
4974*da0073e9SAndroid Build Coastguard Worker
4975*da0073e9SAndroid Build Coastguard Worker        values = torch.randn(
4976*da0073e9SAndroid Build Coastguard Worker            (offsets[-1].item(), 20),
4977*da0073e9SAndroid Build Coastguard Worker            device=device,
4978*da0073e9SAndroid Build Coastguard Worker            dtype=dtype,
4979*da0073e9SAndroid Build Coastguard Worker            requires_grad=requires_grad,
4980*da0073e9SAndroid Build Coastguard Worker        )
4981*da0073e9SAndroid Build Coastguard Worker
4982*da0073e9SAndroid Build Coastguard Worker        nt_with_holes = torch.nested.nested_tensor_from_jagged(
4983*da0073e9SAndroid Build Coastguard Worker            values,
4984*da0073e9SAndroid Build Coastguard Worker            offsets,
4985*da0073e9SAndroid Build Coastguard Worker            lengths=offsets.diff() - 2,  # arbitrary subtraction to create holes
4986*da0073e9SAndroid Build Coastguard Worker        )
4987*da0073e9SAndroid Build Coastguard Worker
4988*da0073e9SAndroid Build Coastguard Worker        for reduce_dim in reduce_dims:
4989*da0073e9SAndroid Build Coastguard Worker            if nt_with_holes.dim() > reduce_dim:
4990*da0073e9SAndroid Build Coastguard Worker                if nt_with_holes._ragged_idx == reduce_dim:
4991*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(
4992*da0073e9SAndroid Build Coastguard Worker                        RuntimeError,
4993*da0073e9SAndroid Build Coastguard Worker                        "not supported where lengths is not None "
4994*da0073e9SAndroid Build Coastguard Worker                        + "if reducing across the ragged dimension for NestedTensor",
4995*da0073e9SAndroid Build Coastguard Worker                    ):
4996*da0073e9SAndroid Build Coastguard Worker                        out = torch.nn.functional.softmax(nt_with_holes, dim=reduce_dim)
4997*da0073e9SAndroid Build Coastguard Worker                else:
4998*da0073e9SAndroid Build Coastguard Worker                    out = torch.nn.functional.softmax(nt_with_holes, dim=reduce_dim)
4999*da0073e9SAndroid Build Coastguard Worker
5000*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo(
5001*da0073e9SAndroid Build Coastguard Worker        "ragged_size = nt_with_holes.shape[nt_with_holes._ragged_idx] does not currently work "
5002*da0073e9SAndroid Build Coastguard Worker        + "with dynamo tests and throws this error: `AssertionError: SymInts must use SymNodeVariable. "
5003*da0073e9SAndroid Build Coastguard Worker        + "If the underlying value is static, we will create a ConstantVariable and specialize.`"
5004*da0073e9SAndroid Build Coastguard Worker    )
5005*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
5006*da0073e9SAndroid Build Coastguard Worker    @parametrize("requires_grad", [False, True])
5007*da0073e9SAndroid Build Coastguard Worker    @parametrize("components_require_grad", [False, True])
5008*da0073e9SAndroid Build Coastguard Worker    def test_layer_norm_with_lengths(
5009*da0073e9SAndroid Build Coastguard Worker        self,
5010*da0073e9SAndroid Build Coastguard Worker        device,
5011*da0073e9SAndroid Build Coastguard Worker        dtype,
5012*da0073e9SAndroid Build Coastguard Worker        requires_grad,
5013*da0073e9SAndroid Build Coastguard Worker        components_require_grad,
5014*da0073e9SAndroid Build Coastguard Worker    ):
5015*da0073e9SAndroid Build Coastguard Worker        """
5016*da0073e9SAndroid Build Coastguard Worker        Layer normalization on NestedTensor fails when trying to operate on a nested tensor with lengths,
5017*da0073e9SAndroid Build Coastguard Worker        i.e. a nested tensor with holes, if operating on the ragged dimension.
5018*da0073e9SAndroid Build Coastguard Worker        """
5019*da0073e9SAndroid Build Coastguard Worker
5020*da0073e9SAndroid Build Coastguard Worker        # create components for nested tensor
5021*da0073e9SAndroid Build Coastguard Worker        lengths = torch.randint(5, 10, (20,), device=device)
5022*da0073e9SAndroid Build Coastguard Worker        offsets = torch.zeros((21,), device=device, dtype=torch.int)
5023*da0073e9SAndroid Build Coastguard Worker        torch.cumsum(lengths, dim=0, out=offsets[1:])
5024*da0073e9SAndroid Build Coastguard Worker        values = torch.randn(
5025*da0073e9SAndroid Build Coastguard Worker            (offsets[-1].item(), 10, 30),
5026*da0073e9SAndroid Build Coastguard Worker            device=device,
5027*da0073e9SAndroid Build Coastguard Worker            dtype=dtype,
5028*da0073e9SAndroid Build Coastguard Worker            requires_grad=requires_grad,
5029*da0073e9SAndroid Build Coastguard Worker        )
5030*da0073e9SAndroid Build Coastguard Worker
5031*da0073e9SAndroid Build Coastguard Worker        nt_with_holes = torch.nested.nested_tensor_from_jagged(
5032*da0073e9SAndroid Build Coastguard Worker            values,
5033*da0073e9SAndroid Build Coastguard Worker            offsets,
5034*da0073e9SAndroid Build Coastguard Worker            lengths=offsets.diff() - 2,  # arbitrary subtraction to create holes
5035*da0073e9SAndroid Build Coastguard Worker        )
5036*da0073e9SAndroid Build Coastguard Worker
5037*da0073e9SAndroid Build Coastguard Worker        ragged_size = nt_with_holes.shape[nt_with_holes._ragged_idx]
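        # including the symbolic ragged size in normalized_shape makes layer norm
        # operate over the ragged dimension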
5038*da0073e9SAndroid Build Coastguard Worker
5039*da0073e9SAndroid Build Coastguard Worker        normalized_shapes = (
5040*da0073e9SAndroid Build Coastguard Worker            (10, 30),  # normalization on non-ragged dimension passes
5041*da0073e9SAndroid Build Coastguard Worker            (ragged_size, 10, 30),  # normalization on ragged dimension fails
5042*da0073e9SAndroid Build Coastguard Worker        )
5043*da0073e9SAndroid Build Coastguard Worker
5044*da0073e9SAndroid Build Coastguard Worker        for normalized_shape in normalized_shapes:
5045*da0073e9SAndroid Build Coastguard Worker            if ragged_size in normalized_shape:
5046*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(
5047*da0073e9SAndroid Build Coastguard Worker                    RuntimeError,
5048*da0073e9SAndroid Build Coastguard Worker                    "not supported where lengths is not None if operating on the ragged dimension for NestedTensor",
5049*da0073e9SAndroid Build Coastguard Worker                ):
5050*da0073e9SAndroid Build Coastguard Worker                    out = torch.nn.functional.layer_norm(
5051*da0073e9SAndroid Build Coastguard Worker                        nt_with_holes, normalized_shape=normalized_shape
5052*da0073e9SAndroid Build Coastguard Worker                    )
5053*da0073e9SAndroid Build Coastguard Worker            else:
5054*da0073e9SAndroid Build Coastguard Worker                out = torch.nn.functional.layer_norm(
5055*da0073e9SAndroid Build Coastguard Worker                    nt_with_holes, normalized_shape=normalized_shape
5056*da0073e9SAndroid Build Coastguard Worker                )
5057*da0073e9SAndroid Build Coastguard Worker
5058*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
5059*da0073e9SAndroid Build Coastguard Worker    @parametrize("keepdim", [True])
5060*da0073e9SAndroid Build Coastguard Worker    @parametrize("requires_grad", [False, True])
5061*da0073e9SAndroid Build Coastguard Worker    @parametrize("components_require_grad", [False, True])
5062*da0073e9SAndroid Build Coastguard Worker    def test_mean_dim_reduce_multiple_dims(
5063*da0073e9SAndroid Build Coastguard Worker        self,
5064*da0073e9SAndroid Build Coastguard Worker        device,
5065*da0073e9SAndroid Build Coastguard Worker        dtype,
5066*da0073e9SAndroid Build Coastguard Worker        keepdim,
5067*da0073e9SAndroid Build Coastguard Worker        requires_grad,
5068*da0073e9SAndroid Build Coastguard Worker        components_require_grad,
5069*da0073e9SAndroid Build Coastguard Worker    ):
5070*da0073e9SAndroid Build Coastguard Worker        """
5071*da0073e9SAndroid Build Coastguard Worker        Mean on NestedTensor fails when trying to reduce across multiple dimensions
5072*da0073e9SAndroid Build Coastguard Worker        """
5073*da0073e9SAndroid Build Coastguard Worker        tensor_lists = self._get_example_tensor_lists(
5074*da0073e9SAndroid Build Coastguard Worker            include_list_of_lists=False, include_requires_grad=components_require_grad
5075*da0073e9SAndroid Build Coastguard Worker        )
5076*da0073e9SAndroid Build Coastguard Worker        reduce_dims = ((0, 1), (2, 3), (2, 3, 4))
5077*da0073e9SAndroid Build Coastguard Worker
5078*da0073e9SAndroid Build Coastguard Worker        for tensor_list, reduce_dim in itertools.product(tensor_lists, reduce_dims):
5079*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor(
5080*da0073e9SAndroid Build Coastguard Worker                tensor_list,
5081*da0073e9SAndroid Build Coastguard Worker                device=device,
5082*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
5083*da0073e9SAndroid Build Coastguard Worker                layout=torch.jagged,
5084*da0073e9SAndroid Build Coastguard Worker                requires_grad=requires_grad,
5085*da0073e9SAndroid Build Coastguard Worker            )
5086*da0073e9SAndroid Build Coastguard Worker
5087*da0073e9SAndroid Build Coastguard Worker            if nt.dim() > reduce_dim[-1]:
5088*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(
5089*da0073e9SAndroid Build Coastguard Worker                    RuntimeError,
5090*da0073e9SAndroid Build Coastguard Worker                    "not supported across multiple dimensions for NestedTensor",
5091*da0073e9SAndroid Build Coastguard Worker                ):
5092*da0073e9SAndroid Build Coastguard Worker                    out = torch.mean(nt, dim=reduce_dim, keepdim=keepdim)
5093*da0073e9SAndroid Build Coastguard Worker
5094*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
5095*da0073e9SAndroid Build Coastguard Worker    @parametrize("keepdim", [False, True])
5096*da0073e9SAndroid Build Coastguard Worker    @parametrize("requires_grad", [False, True])
5097*da0073e9SAndroid Build Coastguard Worker    @parametrize("components_require_grad", [False, True])
5098*da0073e9SAndroid Build Coastguard Worker    def test_mean_dim_keepdim_False(
5099*da0073e9SAndroid Build Coastguard Worker        self,
5100*da0073e9SAndroid Build Coastguard Worker        device,
5101*da0073e9SAndroid Build Coastguard Worker        dtype,
5102*da0073e9SAndroid Build Coastguard Worker        keepdim,
5103*da0073e9SAndroid Build Coastguard Worker        requires_grad,
5104*da0073e9SAndroid Build Coastguard Worker        components_require_grad,
5105*da0073e9SAndroid Build Coastguard Worker    ):
5106*da0073e9SAndroid Build Coastguard Worker        """
5107*da0073e9SAndroid Build Coastguard Worker        Mean on NestedTensor fails when keepdim=False
5108*da0073e9SAndroid Build Coastguard Worker        """
5109*da0073e9SAndroid Build Coastguard Worker        tensor_lists = self._get_example_tensor_lists(
5110*da0073e9SAndroid Build Coastguard Worker            include_list_of_lists=False, include_requires_grad=components_require_grad
5111*da0073e9SAndroid Build Coastguard Worker        )
5112*da0073e9SAndroid Build Coastguard Worker        reduce_dims = ((1,), (2,), (3,))
5113*da0073e9SAndroid Build Coastguard Worker
5114*da0073e9SAndroid Build Coastguard Worker        for tensor_list, reduce_dim in itertools.product(tensor_lists, reduce_dims):
5115*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor(
5116*da0073e9SAndroid Build Coastguard Worker                tensor_list,
5117*da0073e9SAndroid Build Coastguard Worker                device=device,
5118*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
5119*da0073e9SAndroid Build Coastguard Worker                layout=torch.jagged,
5120*da0073e9SAndroid Build Coastguard Worker                requires_grad=requires_grad,
5121*da0073e9SAndroid Build Coastguard Worker            )
5122*da0073e9SAndroid Build Coastguard Worker
5123*da0073e9SAndroid Build Coastguard Worker            if nt.dim() > reduce_dim[-1]:
5124*da0073e9SAndroid Build Coastguard Worker                if not keepdim:
5125*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(
5126*da0073e9SAndroid Build Coastguard Worker                        RuntimeError,
5127*da0073e9SAndroid Build Coastguard Worker                        "not supported when keepdim=False for NestedTensor",
5128*da0073e9SAndroid Build Coastguard Worker                    ):
5129*da0073e9SAndroid Build Coastguard Worker                        out = torch.mean(nt, dim=reduce_dim, keepdim=keepdim)
5130*da0073e9SAndroid Build Coastguard Worker                else:
5131*da0073e9SAndroid Build Coastguard Worker                    out = torch.mean(nt, dim=reduce_dim, keepdim=keepdim)
5132*da0073e9SAndroid Build Coastguard Worker
5133*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double, torch.half)
5134*da0073e9SAndroid Build Coastguard Worker    @parametrize("requires_grad", [False, True])
5135*da0073e9SAndroid Build Coastguard Worker    @parametrize("weights_only", [False, True])
5136*da0073e9SAndroid Build Coastguard Worker    def test_serialization(self, device, dtype, requires_grad, weights_only):
5137*da0073e9SAndroid Build Coastguard Worker        def compare_metadata(nt1, nt2):
5138*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt1._nested_tensor_size(), nt2._nested_tensor_size())
5139*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt1._nested_tensor_strides(), nt2._nested_tensor_strides())
5140*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
5141*da0073e9SAndroid Build Coastguard Worker                nt1._nested_tensor_storage_offsets(),
5142*da0073e9SAndroid Build Coastguard Worker                nt2._nested_tensor_storage_offsets(),
5143*da0073e9SAndroid Build Coastguard Worker            )
5144*da0073e9SAndroid Build Coastguard Worker
5145*da0073e9SAndroid Build Coastguard Worker        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7))
5146*da0073e9SAndroid Build Coastguard Worker        for a in [nt_contiguous, nt_noncontiguous]:
5147*da0073e9SAndroid Build Coastguard Worker            buffer = io.BytesIO()
5148*da0073e9SAndroid Build Coastguard Worker            torch.save(a, buffer)
5149*da0073e9SAndroid Build Coastguard Worker            buffer.seek(0)
5150*da0073e9SAndroid Build Coastguard Worker            b = torch.load(buffer, weights_only=weights_only)
5151*da0073e9SAndroid Build Coastguard Worker            # should be both conceptually equal and metadata equivalent
5152*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a, b)
5153*da0073e9SAndroid Build Coastguard Worker            compare_metadata(a, b)
5154*da0073e9SAndroid Build Coastguard Worker            # should be conceptually equal but not necessarily metadata equivalent
5155*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(b, nt_contiguous)
5156*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(b, nt_noncontiguous)
5157*da0073e9SAndroid Build Coastguard Worker
5158*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
5159*da0073e9SAndroid Build Coastguard Worker        PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property"
5160*da0073e9SAndroid Build Coastguard Worker    )
5161*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
5162*da0073e9SAndroid Build Coastguard Worker    def test_pin_memory(self, device):
5163*da0073e9SAndroid Build Coastguard Worker        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7))
5164*da0073e9SAndroid Build Coastguard Worker        for nt in [nt_contiguous, nt_noncontiguous]:
5165*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(nt.is_pinned())
5166*da0073e9SAndroid Build Coastguard Worker            pinned = nt.pin_memory(device)
5167*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(pinned.is_pinned())
5168*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt, pinned)
5169*da0073e9SAndroid Build Coastguard Worker            self.assertNotEqual(nt.data_ptr(), pinned.data_ptr())
5170*da0073e9SAndroid Build Coastguard Worker            # test that pin_memory on already pinned tensor has no effect
5171*da0073e9SAndroid Build Coastguard Worker            self.assertIs(pinned, pinned.pin_memory())
5172*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(pinned.data_ptr(), pinned.pin_memory().data_ptr())
5173*da0073e9SAndroid Build Coastguard Worker
5174*da0073e9SAndroid Build Coastguard Worker    @torch.compiler.disable
5175*da0073e9SAndroid Build Coastguard Worker    def _validate_nt(
5176*da0073e9SAndroid Build Coastguard Worker        self,
5177*da0073e9SAndroid Build Coastguard Worker        nt,
5178*da0073e9SAndroid Build Coastguard Worker        device,
5179*da0073e9SAndroid Build Coastguard Worker        dtype,
5180*da0073e9SAndroid Build Coastguard Worker        layout,
5181*da0073e9SAndroid Build Coastguard Worker        requires_grad,
5182*da0073e9SAndroid Build Coastguard Worker        dim,
5183*da0073e9SAndroid Build Coastguard Worker        batch_size,
5184*da0073e9SAndroid Build Coastguard Worker        contiguous,
5185*da0073e9SAndroid Build Coastguard Worker        cached_min_seqlen=None,
5186*da0073e9SAndroid Build Coastguard Worker        cached_max_seqlen=None,
5187*da0073e9SAndroid Build Coastguard Worker        base=None,
5188*da0073e9SAndroid Build Coastguard Worker        ref_nt=None,
5189*da0073e9SAndroid Build Coastguard Worker    ):
5190*da0073e9SAndroid Build Coastguard Worker        # Validate a bunch of properties after NT construction.
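        # cached_min_seqlen / cached_max_seqlen of None indicate that the corresponding
        # entry is expected to be absent from the metadata cache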
5191*da0073e9SAndroid Build Coastguard Worker        device = torch.device(device)
5192*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt.dim(), dim)
5193*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt.device, device)
5194*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt.dtype, dtype)
5195*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt.layout, layout)
5196*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt.requires_grad, requires_grad)
5197*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt.is_contiguous(), contiguous)
5198*da0073e9SAndroid Build Coastguard Worker
5199*da0073e9SAndroid Build Coastguard Worker        if layout == torch.jagged:
5200*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt._values.device, device)
5201*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt._offsets.device, device)
5202*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt.shape[0], batch_size)
5203*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(isinstance(nt.shape[1], torch.SymInt))
5204*da0073e9SAndroid Build Coastguard Worker
5205*da0073e9SAndroid Build Coastguard Worker            if base is not None:
5206*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(nt._is_view() and nt._base is base)
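                # replaying the view on a fresh base should yield a metadata cache with
                # the same seqlen entries present (present iff expected values were given)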
5207*da0073e9SAndroid Build Coastguard Worker                replay_cache = nt._view_func(torch.randn_like(nt._base))._metadata_cache
5208*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
5209*da0073e9SAndroid Build Coastguard Worker                    "min_seqlen" in replay_cache, cached_min_seqlen is not None
5210*da0073e9SAndroid Build Coastguard Worker                )
5211*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
5212*da0073e9SAndroid Build Coastguard Worker                    "max_seqlen" in replay_cache, cached_max_seqlen is not None
5213*da0073e9SAndroid Build Coastguard Worker                )
5214*da0073e9SAndroid Build Coastguard Worker
5215*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
5216*da0073e9SAndroid Build Coastguard Worker                "min_seqlen" in nt._metadata_cache, cached_min_seqlen is not None
5217*da0073e9SAndroid Build Coastguard Worker            )
5218*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
5219*da0073e9SAndroid Build Coastguard Worker                "max_seqlen" in nt._metadata_cache, cached_max_seqlen is not None
5220*da0073e9SAndroid Build Coastguard Worker            )
5221*da0073e9SAndroid Build Coastguard Worker
5222*da0073e9SAndroid Build Coastguard Worker            if cached_min_seqlen is not None:
5223*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(nt._min_seqlen, cached_min_seqlen)
5224*da0073e9SAndroid Build Coastguard Worker
5225*da0073e9SAndroid Build Coastguard Worker            if cached_max_seqlen is not None:
5226*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(nt._max_seqlen, cached_max_seqlen)
5227*da0073e9SAndroid Build Coastguard Worker
5228*da0073e9SAndroid Build Coastguard Worker        if ref_nt is not None:
5229*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt.size(0), ref_nt.size(0))
5230*da0073e9SAndroid Build Coastguard Worker            for n1, n2 in zip(nt.unbind(), ref_nt.unbind()):
5231*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(n1, n2)
5232*da0073e9SAndroid Build Coastguard Worker
5233*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double, torch.half)
5234*da0073e9SAndroid Build Coastguard Worker    @parametrize("requires_grad", [False, True])
5235*da0073e9SAndroid Build Coastguard Worker    @parametrize("components_require_grad", [False, True])
5236*da0073e9SAndroid Build Coastguard Worker    def test_jagged_layout_construction_nested_tensor(
5237*da0073e9SAndroid Build Coastguard Worker        self, device, dtype, requires_grad, components_require_grad
5238*da0073e9SAndroid Build Coastguard Worker    ):
5239*da0073e9SAndroid Build Coastguard Worker        for tensor_list in self._get_example_tensor_lists(
5240*da0073e9SAndroid Build Coastguard Worker            include_list_of_lists=True, include_requires_grad=components_require_grad
5241*da0073e9SAndroid Build Coastguard Worker        ):
5242*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor(
5243*da0073e9SAndroid Build Coastguard Worker                tensor_list,
5244*da0073e9SAndroid Build Coastguard Worker                device=device,
5245*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
5246*da0073e9SAndroid Build Coastguard Worker                layout=torch.jagged,
5247*da0073e9SAndroid Build Coastguard Worker                requires_grad=requires_grad,
5248*da0073e9SAndroid Build Coastguard Worker            )
5249*da0073e9SAndroid Build Coastguard Worker
5250*da0073e9SAndroid Build Coastguard Worker            expected_dim = torch.as_tensor(tensor_list[0]).dim() + 1
5251*da0073e9SAndroid Build Coastguard Worker            expected_batch_size = len(tensor_list)
5252*da0073e9SAndroid Build Coastguard Worker            expected_contiguous = True
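            # min / max seqlen come from the smallest / largest ragged (first) dim
            # across the components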
5253*da0073e9SAndroid Build Coastguard Worker            expected_min_seqlen = min(
5254*da0073e9SAndroid Build Coastguard Worker                (torch.tensor(t) if isinstance(t, list) else t).shape[0]
5255*da0073e9SAndroid Build Coastguard Worker                for t in tensor_list
5256*da0073e9SAndroid Build Coastguard Worker            )
5257*da0073e9SAndroid Build Coastguard Worker            expected_max_seqlen = max(
5258*da0073e9SAndroid Build Coastguard Worker                (torch.tensor(t) if isinstance(t, list) else t).shape[0]
5259*da0073e9SAndroid Build Coastguard Worker                for t in tensor_list
5260*da0073e9SAndroid Build Coastguard Worker            )
5261*da0073e9SAndroid Build Coastguard Worker            self._validate_nt(
5262*da0073e9SAndroid Build Coastguard Worker                nt,
5263*da0073e9SAndroid Build Coastguard Worker                device,
5264*da0073e9SAndroid Build Coastguard Worker                dtype,
5265*da0073e9SAndroid Build Coastguard Worker                torch.jagged,
5266*da0073e9SAndroid Build Coastguard Worker                requires_grad,
5267*da0073e9SAndroid Build Coastguard Worker                expected_dim,
5268*da0073e9SAndroid Build Coastguard Worker                expected_batch_size,
5269*da0073e9SAndroid Build Coastguard Worker                expected_contiguous,
5270*da0073e9SAndroid Build Coastguard Worker                expected_min_seqlen,
5271*da0073e9SAndroid Build Coastguard Worker                expected_max_seqlen,
5272*da0073e9SAndroid Build Coastguard Worker            )
5273*da0073e9SAndroid Build Coastguard Worker
5274*da0073e9SAndroid Build Coastguard Worker            # Make sure grads -don't- flow back into original tensors for nested_tensor()
5275*da0073e9SAndroid Build Coastguard Worker            if requires_grad:
5276*da0073e9SAndroid Build Coastguard Worker                (nt * 2).backward(torch.ones_like(nt))
5277*da0073e9SAndroid Build Coastguard Worker            for t in tensor_list:
5278*da0073e9SAndroid Build Coastguard Worker                t = t if isinstance(t, torch.Tensor) else torch.as_tensor(t)
5279*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(t.grad is None)
5280*da0073e9SAndroid Build Coastguard Worker
5281*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double, torch.half)
5282*da0073e9SAndroid Build Coastguard Worker    @parametrize("components_require_grad", [False, True])
5283*da0073e9SAndroid Build Coastguard Worker    def test_jagged_layout_construction_as_nested_tensor(
5284*da0073e9SAndroid Build Coastguard Worker        self, device, dtype, components_require_grad
5285*da0073e9SAndroid Build Coastguard Worker    ):
5286*da0073e9SAndroid Build Coastguard Worker        # NB: as_nested_tensor(tensor_list) doesn't support lists of lists for tensor_list
5287*da0073e9SAndroid Build Coastguard Worker        for tensor_list in self._get_example_tensor_lists(
5288*da0073e9SAndroid Build Coastguard Worker            include_list_of_lists=False, include_requires_grad=components_require_grad
5289*da0073e9SAndroid Build Coastguard Worker        ):
5290*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.as_nested_tensor(
5291*da0073e9SAndroid Build Coastguard Worker                tensor_list, device=device, dtype=dtype, layout=torch.jagged
5292*da0073e9SAndroid Build Coastguard Worker            )
5293*da0073e9SAndroid Build Coastguard Worker
5294*da0073e9SAndroid Build Coastguard Worker            # nt.requires_grad=True should be set if at least one component requires grad
5295*da0073e9SAndroid Build Coastguard Worker            expected_dim = tensor_list[0].dim() + 1
5296*da0073e9SAndroid Build Coastguard Worker            expected_batch_size = len(tensor_list)
5297*da0073e9SAndroid Build Coastguard Worker            expected_contiguous = True
5298*da0073e9SAndroid Build Coastguard Worker            expected_min_seqlen = min(
5299*da0073e9SAndroid Build Coastguard Worker                (torch.tensor(t) if isinstance(t, list) else t).shape[0]
5300*da0073e9SAndroid Build Coastguard Worker                for t in tensor_list
5301*da0073e9SAndroid Build Coastguard Worker            )
5302*da0073e9SAndroid Build Coastguard Worker            expected_max_seqlen = max(
5303*da0073e9SAndroid Build Coastguard Worker                (torch.tensor(t) if isinstance(t, list) else t).shape[0]
5304*da0073e9SAndroid Build Coastguard Worker                for t in tensor_list
5305*da0073e9SAndroid Build Coastguard Worker            )
5306*da0073e9SAndroid Build Coastguard Worker            self._validate_nt(
5307*da0073e9SAndroid Build Coastguard Worker                nt,
5308*da0073e9SAndroid Build Coastguard Worker                device,
5309*da0073e9SAndroid Build Coastguard Worker                dtype,
5310*da0073e9SAndroid Build Coastguard Worker                torch.jagged,
5311*da0073e9SAndroid Build Coastguard Worker                components_require_grad,
5312*da0073e9SAndroid Build Coastguard Worker                expected_dim,
5313*da0073e9SAndroid Build Coastguard Worker                expected_batch_size,
5314*da0073e9SAndroid Build Coastguard Worker                expected_contiguous,
5315*da0073e9SAndroid Build Coastguard Worker                expected_min_seqlen,
5316*da0073e9SAndroid Build Coastguard Worker                expected_max_seqlen,
5317*da0073e9SAndroid Build Coastguard Worker            )
5318*da0073e9SAndroid Build Coastguard Worker
5319*da0073e9SAndroid Build Coastguard Worker            # Make sure grads flow back into original tensors for as_nested_tensor()
5320*da0073e9SAndroid Build Coastguard Worker            if components_require_grad:
5321*da0073e9SAndroid Build Coastguard Worker                (nt * 2).backward(torch.ones_like(nt))
5322*da0073e9SAndroid Build Coastguard Worker                for t in tensor_list:
5323*da0073e9SAndroid Build Coastguard Worker                    if t.requires_grad:
5324*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(t.grad, torch.ones_like(t) * 2)
5325*da0073e9SAndroid Build Coastguard Worker                    else:
5326*da0073e9SAndroid Build Coastguard Worker                        self.assertTrue(t.grad is None)
5327*da0073e9SAndroid Build Coastguard Worker
5328*da0073e9SAndroid Build Coastguard Worker    @xfailIfTorchDynamo
5329*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
5330*da0073e9SAndroid Build Coastguard Worker        PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property"
5331*da0073e9SAndroid Build Coastguard Worker    )
5332*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
5333*da0073e9SAndroid Build Coastguard Worker    def test_jagged_layout_construction_with_pinned_memory(self, device):
5334*da0073e9SAndroid Build Coastguard Worker        for tensor_list in self._get_example_tensor_lists():
5335*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor(
5336*da0073e9SAndroid Build Coastguard Worker                tensor_list, layout=torch.jagged, device="cpu", pin_memory=True
5337*da0073e9SAndroid Build Coastguard Worker            )
5338*da0073e9SAndroid Build Coastguard Worker
5339*da0073e9SAndroid Build Coastguard Worker            expected_dim = torch.as_tensor(tensor_list[0]).dim() + 1
5340*da0073e9SAndroid Build Coastguard Worker            expected_batch_size = len(tensor_list)
5341*da0073e9SAndroid Build Coastguard Worker            expected_min_seqlen = min(
5342*da0073e9SAndroid Build Coastguard Worker                (torch.tensor(t) if isinstance(t, list) else t).shape[0]
5343*da0073e9SAndroid Build Coastguard Worker                for t in tensor_list
5344*da0073e9SAndroid Build Coastguard Worker            )
5345*da0073e9SAndroid Build Coastguard Worker            expected_max_seqlen = max(
5346*da0073e9SAndroid Build Coastguard Worker                (torch.tensor(t) if isinstance(t, list) else t).shape[0]
5347*da0073e9SAndroid Build Coastguard Worker                for t in tensor_list
5348*da0073e9SAndroid Build Coastguard Worker            )
5349*da0073e9SAndroid Build Coastguard Worker            self._validate_nt(
5350*da0073e9SAndroid Build Coastguard Worker                nt,
5351*da0073e9SAndroid Build Coastguard Worker                device="cpu",
5352*da0073e9SAndroid Build Coastguard Worker                dtype=torch.float32,
5353*da0073e9SAndroid Build Coastguard Worker                layout=torch.jagged,
5354*da0073e9SAndroid Build Coastguard Worker                requires_grad=False,
5355*da0073e9SAndroid Build Coastguard Worker                dim=expected_dim,
5356*da0073e9SAndroid Build Coastguard Worker                batch_size=expected_batch_size,
5357*da0073e9SAndroid Build Coastguard Worker                contiguous=True,
5358*da0073e9SAndroid Build Coastguard Worker                cached_min_seqlen=expected_min_seqlen,
5359*da0073e9SAndroid Build Coastguard Worker                cached_max_seqlen=expected_max_seqlen,
5360*da0073e9SAndroid Build Coastguard Worker            )
5361*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(nt.is_pinned())
5362*da0073e9SAndroid Build Coastguard Worker
5363*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double, torch.half)
5364*da0073e9SAndroid Build Coastguard Worker    @parametrize("requires_grad", [False, True])
5365*da0073e9SAndroid Build Coastguard Worker    @parametrize("values_is_view", [False, True])
5366*da0073e9SAndroid Build Coastguard Worker    def test_jagged_view_from_values_offsets(
5367*da0073e9SAndroid Build Coastguard Worker        self, device, dtype, requires_grad, values_is_view
5368*da0073e9SAndroid Build Coastguard Worker    ):
5369*da0073e9SAndroid Build Coastguard Worker        if values_is_view:
5370*da0073e9SAndroid Build Coastguard Worker            # make values a view of base
5371*da0073e9SAndroid Build Coastguard Worker            base = torch.randn(
5372*da0073e9SAndroid Build Coastguard Worker                2, 3, 4, 5, 6, device=device, dtype=dtype, requires_grad=requires_grad
5373*da0073e9SAndroid Build Coastguard Worker            )
5374*da0073e9SAndroid Build Coastguard Worker            values = base.flatten(0, -2)
5375*da0073e9SAndroid Build Coastguard Worker        else:
5376*da0073e9SAndroid Build Coastguard Worker            values = torch.randn(
5377*da0073e9SAndroid Build Coastguard Worker                10, 5, device=device, dtype=dtype, requires_grad=requires_grad
5378*da0073e9SAndroid Build Coastguard Worker            )
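        # offsets define 4 batch items with lengths [2, 2, 2, 4]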
5379*da0073e9SAndroid Build Coastguard Worker        offsets = torch.tensor([0, 2, 4, 6, 10], device=device, dtype=torch.int64)
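        # offsets delimit batch_size = 4 jagged components with lengths (2, 2, 2, 4)
        # along dim 0 of values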
5380*da0073e9SAndroid Build Coastguard Worker
5381*da0073e9SAndroid Build Coastguard Worker        nt = nested_view_from_values_offsets(values, offsets)
5382*da0073e9SAndroid Build Coastguard Worker
5383*da0073e9SAndroid Build Coastguard Worker        expected_dim = values.dim() + 1
5384*da0073e9SAndroid Build Coastguard Worker        expected_batch_size = offsets.shape[0] - 1
5385*da0073e9SAndroid Build Coastguard Worker        expected_base = base if values_is_view else values
5386*da0073e9SAndroid Build Coastguard Worker        lengths = offsets.diff()
5387*da0073e9SAndroid Build Coastguard Worker        self._validate_nt(
5388*da0073e9SAndroid Build Coastguard Worker            nt,
5389*da0073e9SAndroid Build Coastguard Worker            device,
5390*da0073e9SAndroid Build Coastguard Worker            dtype,
5391*da0073e9SAndroid Build Coastguard Worker            torch.jagged,
5392*da0073e9SAndroid Build Coastguard Worker            requires_grad,
5393*da0073e9SAndroid Build Coastguard Worker            expected_dim,
5394*da0073e9SAndroid Build Coastguard Worker            expected_batch_size,
5395*da0073e9SAndroid Build Coastguard Worker            # ensure NT is a proper view
5396*da0073e9SAndroid Build Coastguard Worker            base=expected_base,
5397*da0073e9SAndroid Build Coastguard Worker            contiguous=True,
5398*da0073e9SAndroid Build Coastguard Worker            # if no min / max are passed, expect the metadata cache to be empty
5399*da0073e9SAndroid Build Coastguard Worker            cached_min_seqlen=None,
5400*da0073e9SAndroid Build Coastguard Worker            cached_max_seqlen=None,
5401*da0073e9SAndroid Build Coastguard Worker        )
5402*da0073e9SAndroid Build Coastguard Worker
5403*da0073e9SAndroid Build Coastguard Worker        if requires_grad:
5404*da0073e9SAndroid Build Coastguard Worker            # Make sure grads flow back
5405*da0073e9SAndroid Build Coastguard Worker            (nt * 2).backward(torch.ones_like(nt))
5406*da0073e9SAndroid Build Coastguard Worker
5407*da0073e9SAndroid Build Coastguard Worker            @torch.compiler.disable
5408*da0073e9SAndroid Build Coastguard Worker            def _check_grad(t):
5409*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(t.grad is not None)
5410*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(t.grad, torch.ones_like(t) * 2)
5411*da0073e9SAndroid Build Coastguard Worker
5412*da0073e9SAndroid Build Coastguard Worker            _check_grad(base if values_is_view else values)
5413*da0073e9SAndroid Build Coastguard Worker
5414*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
5415*da0073e9SAndroid Build Coastguard Worker    @parametrize("pass_min_max", [False, True])
5416*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_from_jagged(self, device, dtype, pass_min_max):
5417*da0073e9SAndroid Build Coastguard Worker        # === construct from (values, offsets) ===
5418*da0073e9SAndroid Build Coastguard Worker        values = torch.randn(10, 5, device=device, dtype=dtype)
5419*da0073e9SAndroid Build Coastguard Worker        offsets = torch.tensor([0, 2, 4, 6, 10], device=device, dtype=torch.int64)
5420*da0073e9SAndroid Build Coastguard Worker
5421*da0073e9SAndroid Build Coastguard Worker        # compute min / max seqlen
5422*da0073e9SAndroid Build Coastguard Worker        lengths = offsets.diff()
5423*da0073e9SAndroid Build Coastguard Worker        min_seqlen = lengths.min().item()
5424*da0073e9SAndroid Build Coastguard Worker        max_seqlen = lengths.max().item()
5425*da0073e9SAndroid Build Coastguard Worker
5426*da0073e9SAndroid Build Coastguard Worker        if pass_min_max:
5427*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor_from_jagged(
5428*da0073e9SAndroid Build Coastguard Worker                values, offsets=offsets, min_seqlen=min_seqlen, max_seqlen=max_seqlen
5429*da0073e9SAndroid Build Coastguard Worker            )
5430*da0073e9SAndroid Build Coastguard Worker        else:
5431*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor_from_jagged(values, offsets=offsets)
5432*da0073e9SAndroid Build Coastguard Worker        self._validate_nt(
5433*da0073e9SAndroid Build Coastguard Worker            nt,
5434*da0073e9SAndroid Build Coastguard Worker            device,
5435*da0073e9SAndroid Build Coastguard Worker            dtype,
5436*da0073e9SAndroid Build Coastguard Worker            torch.jagged,
5437*da0073e9SAndroid Build Coastguard Worker            requires_grad=False,
5438*da0073e9SAndroid Build Coastguard Worker            dim=3,
5439*da0073e9SAndroid Build Coastguard Worker            batch_size=4,
5440*da0073e9SAndroid Build Coastguard Worker            contiguous=True,
5441*da0073e9SAndroid Build Coastguard Worker            cached_min_seqlen=(min_seqlen if pass_min_max else None),
5442*da0073e9SAndroid Build Coastguard Worker            cached_max_seqlen=(max_seqlen if pass_min_max else None),
5443*da0073e9SAndroid Build Coastguard Worker            base=values,
5444*da0073e9SAndroid Build Coastguard Worker        )
5445*da0073e9SAndroid Build Coastguard Worker
5446*da0073e9SAndroid Build Coastguard Worker        # === construct from (values, offsets, lengths) ===
5447*da0073e9SAndroid Build Coastguard Worker        lengths = torch.tensor([2, 1, 1, 2], device=device)
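        # lengths[i] <= offsets.diff()[i] here; each component covers only
        # values[offsets[i] : offsets[i] + lengths[i]], which can leave unused rows
        # in the buffer (hence the non-contiguous expectation below)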
5448*da0073e9SAndroid Build Coastguard Worker
5449*da0073e9SAndroid Build Coastguard Worker        # compute min / max seqlen
5450*da0073e9SAndroid Build Coastguard Worker        min_seqlen = lengths.min().item()
5451*da0073e9SAndroid Build Coastguard Worker        max_seqlen = lengths.max().item()
5452*da0073e9SAndroid Build Coastguard Worker
5453*da0073e9SAndroid Build Coastguard Worker        if pass_min_max:
5454*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor_from_jagged(
5455*da0073e9SAndroid Build Coastguard Worker                values,
5456*da0073e9SAndroid Build Coastguard Worker                offsets=offsets,
5457*da0073e9SAndroid Build Coastguard Worker                lengths=lengths,
5458*da0073e9SAndroid Build Coastguard Worker                min_seqlen=min_seqlen,
5459*da0073e9SAndroid Build Coastguard Worker                max_seqlen=max_seqlen,
5460*da0073e9SAndroid Build Coastguard Worker            )
5461*da0073e9SAndroid Build Coastguard Worker        else:
5462*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor_from_jagged(
5463*da0073e9SAndroid Build Coastguard Worker                values, offsets=offsets, lengths=lengths
5464*da0073e9SAndroid Build Coastguard Worker            )
5465*da0073e9SAndroid Build Coastguard Worker
5466*da0073e9SAndroid Build Coastguard Worker        # when both offsets / lengths are specified, expect non-contiguous
5467*da0073e9SAndroid Build Coastguard Worker        self._validate_nt(
5468*da0073e9SAndroid Build Coastguard Worker            nt,
5469*da0073e9SAndroid Build Coastguard Worker            device,
5470*da0073e9SAndroid Build Coastguard Worker            dtype,
5471*da0073e9SAndroid Build Coastguard Worker            torch.jagged,
5472*da0073e9SAndroid Build Coastguard Worker            requires_grad=False,
5473*da0073e9SAndroid Build Coastguard Worker            dim=3,
5474*da0073e9SAndroid Build Coastguard Worker            batch_size=4,
5475*da0073e9SAndroid Build Coastguard Worker            contiguous=False,
5476*da0073e9SAndroid Build Coastguard Worker            cached_min_seqlen=(min_seqlen if pass_min_max else None),
5477*da0073e9SAndroid Build Coastguard Worker            cached_max_seqlen=(max_seqlen if pass_min_max else None),
5478*da0073e9SAndroid Build Coastguard Worker            base=values,
5479*da0073e9SAndroid Build Coastguard Worker        )
5480*da0073e9SAndroid Build Coastguard Worker        self.assertIs(nt.lengths(), lengths)
5481*da0073e9SAndroid Build Coastguard Worker
5482*da0073e9SAndroid Build Coastguard Worker        # === construct from (values, lengths) ===
5483*da0073e9SAndroid Build Coastguard Worker        values = torch.randn(14, 5, device=device, dtype=dtype)
5484*da0073e9SAndroid Build Coastguard Worker        lengths = torch.tensor([2, 3, 4, 5], device=device)
5485*da0073e9SAndroid Build Coastguard Worker
5486*da0073e9SAndroid Build Coastguard Worker        # compute min / max seqlen
5487*da0073e9SAndroid Build Coastguard Worker        min_seqlen = lengths.min().item()
5488*da0073e9SAndroid Build Coastguard Worker        max_seqlen = lengths.max().item()
5489*da0073e9SAndroid Build Coastguard Worker
5490*da0073e9SAndroid Build Coastguard Worker        if pass_min_max:
5491*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor_from_jagged(
5492*da0073e9SAndroid Build Coastguard Worker                values, lengths=lengths, min_seqlen=min_seqlen, max_seqlen=max_seqlen
5493*da0073e9SAndroid Build Coastguard Worker            )
5494*da0073e9SAndroid Build Coastguard Worker        else:
5495*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor_from_jagged(values, lengths=lengths)
5496*da0073e9SAndroid Build Coastguard Worker
5497*da0073e9SAndroid Build Coastguard Worker        # for now, when only lengths is specified, it is converted to offsets to
5498*da0073e9SAndroid Build Coastguard Worker        # integrate best with the existing kernels
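        # e.g. lengths [2, 3, 4, 5] -> offsets [0, 2, 5, 9, 14] (exclusive cumulative sum)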
5499*da0073e9SAndroid Build Coastguard Worker        expected_offsets = torch.tensor([0, 2, 5, 9, 14], device=device)
5500*da0073e9SAndroid Build Coastguard Worker        expected_nt = torch.nested.nested_tensor_from_jagged(
5501*da0073e9SAndroid Build Coastguard Worker            values, offsets=expected_offsets
5502*da0073e9SAndroid Build Coastguard Worker        )
5503*da0073e9SAndroid Build Coastguard Worker        self._validate_nt(
5504*da0073e9SAndroid Build Coastguard Worker            nt,
5505*da0073e9SAndroid Build Coastguard Worker            device,
5506*da0073e9SAndroid Build Coastguard Worker            dtype,
5507*da0073e9SAndroid Build Coastguard Worker            torch.jagged,
5508*da0073e9SAndroid Build Coastguard Worker            requires_grad=False,
5509*da0073e9SAndroid Build Coastguard Worker            dim=3,
5510*da0073e9SAndroid Build Coastguard Worker            batch_size=4,
5511*da0073e9SAndroid Build Coastguard Worker            contiguous=True,
5512*da0073e9SAndroid Build Coastguard Worker            cached_min_seqlen=(min_seqlen if pass_min_max else None),
5513*da0073e9SAndroid Build Coastguard Worker            cached_max_seqlen=(max_seqlen if pass_min_max else None),
5514*da0073e9SAndroid Build Coastguard Worker            base=values,
5515*da0073e9SAndroid Build Coastguard Worker            ref_nt=expected_nt,
5516*da0073e9SAndroid Build Coastguard Worker        )
5517*da0073e9SAndroid Build Coastguard Worker
5518*da0073e9SAndroid Build Coastguard Worker        # error case: no offsets or lengths
5519*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
5520*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "At least one of offsets or lengths is required"
5521*da0073e9SAndroid Build Coastguard Worker        ):
5522*da0073e9SAndroid Build Coastguard Worker            torch.nested.nested_tensor_from_jagged(values, offsets=None, lengths=None)
5523*da0073e9SAndroid Build Coastguard Worker
5524*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
5525*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_from_jagged_fx_trace(self, device):
5526*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
5527*da0073e9SAndroid Build Coastguard Worker            return torch.nested.nested_tensor_from_jagged(x, y)
5528*da0073e9SAndroid Build Coastguard Worker
5529*da0073e9SAndroid Build Coastguard Worker        def user_unwrapped(x, y):
5530*da0073e9SAndroid Build Coastguard Worker            return fn(x, y)
5531*da0073e9SAndroid Build Coastguard Worker
5532*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
5533*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
5534*da0073e9SAndroid Build Coastguard Worker            "torch.nested.nested_tensor_from_jagged does not support tracing with fx.symbolic_trace",
5535*da0073e9SAndroid Build Coastguard Worker        ):
5536*da0073e9SAndroid Build Coastguard Worker            torch.fx.symbolic_trace(user_unwrapped)
5537*da0073e9SAndroid Build Coastguard Worker
5538*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double, torch.half)
5539*da0073e9SAndroid Build Coastguard Worker    @parametrize("dim", range(5))
5540*da0073e9SAndroid Build Coastguard Worker    @parametrize(
5541*da0073e9SAndroid Build Coastguard Worker        "layout",
5542*da0073e9SAndroid Build Coastguard Worker        [torch.strided, torch.jagged],
5543*da0073e9SAndroid Build Coastguard Worker        name_fn=lambda l: f"layout_{str(l).split('.')[1]}",
5544*da0073e9SAndroid Build Coastguard Worker    )
5545*da0073e9SAndroid Build Coastguard Worker    @parametrize("requires_grad", [False, True])
5546*da0073e9SAndroid Build Coastguard Worker    @parametrize("contiguous", [False, True])
5547*da0073e9SAndroid Build Coastguard Worker    def test_as_nested_tensor_from_tensor(
5548*da0073e9SAndroid Build Coastguard Worker        self, device, dtype, dim, layout, requires_grad, contiguous
5549*da0073e9SAndroid Build Coastguard Worker    ):
5550*da0073e9SAndroid Build Coastguard Worker        if dim == 0:
5551*da0073e9SAndroid Build Coastguard Worker            t = torch.tensor(3.0, requires_grad=requires_grad)
5552*da0073e9SAndroid Build Coastguard Worker        else:
5553*da0073e9SAndroid Build Coastguard Worker            t = torch.randn(*(3 for _ in range(dim)), requires_grad=requires_grad)
5554*da0073e9SAndroid Build Coastguard Worker        assert t.dim() == dim
5555*da0073e9SAndroid Build Coastguard Worker
5556*da0073e9SAndroid Build Coastguard Worker        if dim < 2:
5557*da0073e9SAndroid Build Coastguard Worker            # 0-1 dim tensors can't be converted to NTs
5558*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
5559*da0073e9SAndroid Build Coastguard Worker                RuntimeError, "Expected tensor argument to have dim"
5560*da0073e9SAndroid Build Coastguard Worker            ):
5561*da0073e9SAndroid Build Coastguard Worker                nt = torch.nested.as_nested_tensor(
5562*da0073e9SAndroid Build Coastguard Worker                    t, device=device, dtype=dtype, layout=layout
5563*da0073e9SAndroid Build Coastguard Worker                )
5564*da0073e9SAndroid Build Coastguard Worker            return
5565*da0073e9SAndroid Build Coastguard Worker
5566*da0073e9SAndroid Build Coastguard Worker        orig_t = t
5567*da0073e9SAndroid Build Coastguard Worker        if not contiguous:
5568*da0073e9SAndroid Build Coastguard Worker            t = t.transpose(0, 1)
5569*da0073e9SAndroid Build Coastguard Worker
5570*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.as_nested_tensor(t, device=device, dtype=dtype, layout=layout)
5571*da0073e9SAndroid Build Coastguard Worker        expected_dim = t.dim()
5572*da0073e9SAndroid Build Coastguard Worker        expected_batch_size = t.size(0)
5573*da0073e9SAndroid Build Coastguard Worker        expected_seqlen = t.size(1) if layout == torch.jagged else None
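        # converting a dense tensor gives equal-length components, so for the jagged
        # layout the cached min / max seqlens should both equal t.size(1)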
5574*da0073e9SAndroid Build Coastguard Worker        self._validate_nt(
5575*da0073e9SAndroid Build Coastguard Worker            nt,
5576*da0073e9SAndroid Build Coastguard Worker            device,
5577*da0073e9SAndroid Build Coastguard Worker            dtype,
5578*da0073e9SAndroid Build Coastguard Worker            layout,
5579*da0073e9SAndroid Build Coastguard Worker            requires_grad=requires_grad,
5580*da0073e9SAndroid Build Coastguard Worker            dim=dim,
5581*da0073e9SAndroid Build Coastguard Worker            batch_size=expected_batch_size,
5582*da0073e9SAndroid Build Coastguard Worker            contiguous=True,
5583*da0073e9SAndroid Build Coastguard Worker            cached_min_seqlen=expected_seqlen,
5584*da0073e9SAndroid Build Coastguard Worker            cached_max_seqlen=expected_seqlen,
5585*da0073e9SAndroid Build Coastguard Worker        )
5586*da0073e9SAndroid Build Coastguard Worker
5587*da0073e9SAndroid Build Coastguard Worker        if torch.device(device) == t.device and dtype == t.dtype and contiguous:
5588*da0073e9SAndroid Build Coastguard Worker            # should be the non-copying (view) case
5589*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(nt._is_view() and nt._base is t)
5590*da0073e9SAndroid Build Coastguard Worker
5591*da0073e9SAndroid Build Coastguard Worker        # should have equivalent components to construction from unbound tensor list
5592*da0073e9SAndroid Build Coastguard Worker        nt_from_unbind = torch.nested.as_nested_tensor(
5593*da0073e9SAndroid Build Coastguard Worker            list(t.unbind(0)), device=device, dtype=dtype, layout=layout
5594*da0073e9SAndroid Build Coastguard Worker        )
5595*da0073e9SAndroid Build Coastguard Worker        self.assertEqualIgnoringNestedInts(nt, nt_from_unbind)
5596*da0073e9SAndroid Build Coastguard Worker
5597*da0073e9SAndroid Build Coastguard Worker        # ensure call on a NT with the same properties returns the NT directly
5598*da0073e9SAndroid Build Coastguard Worker        nt2 = torch.nested.as_nested_tensor(
5599*da0073e9SAndroid Build Coastguard Worker            nt, device=device, dtype=dtype, layout=layout
5600*da0073e9SAndroid Build Coastguard Worker        )
5601*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(nt is nt2)
5602*da0073e9SAndroid Build Coastguard Worker
5603*da0073e9SAndroid Build Coastguard Worker        # ensure call with device=None uses input tensor device
5604*da0073e9SAndroid Build Coastguard Worker        nt3 = torch.nested.as_nested_tensor(
5605*da0073e9SAndroid Build Coastguard Worker            t.to(device=device, dtype=dtype),
5606*da0073e9SAndroid Build Coastguard Worker            device=None,
5607*da0073e9SAndroid Build Coastguard Worker            dtype=None,
5608*da0073e9SAndroid Build Coastguard Worker            layout=layout,
5609*da0073e9SAndroid Build Coastguard Worker        )
5610*da0073e9SAndroid Build Coastguard Worker        self._validate_nt(
5611*da0073e9SAndroid Build Coastguard Worker            nt3,
5612*da0073e9SAndroid Build Coastguard Worker            device,
5613*da0073e9SAndroid Build Coastguard Worker            dtype,
5614*da0073e9SAndroid Build Coastguard Worker            layout,
5615*da0073e9SAndroid Build Coastguard Worker            requires_grad=requires_grad,
5616*da0073e9SAndroid Build Coastguard Worker            dim=dim,
5617*da0073e9SAndroid Build Coastguard Worker            batch_size=expected_batch_size,
5618*da0073e9SAndroid Build Coastguard Worker            contiguous=True,
5619*da0073e9SAndroid Build Coastguard Worker            cached_min_seqlen=expected_seqlen,
5620*da0073e9SAndroid Build Coastguard Worker            cached_max_seqlen=expected_seqlen,
5621*da0073e9SAndroid Build Coastguard Worker        )
5622*da0073e9SAndroid Build Coastguard Worker
5623*da0073e9SAndroid Build Coastguard Worker        # we don't support conversion between layouts this way at the moment
5624*da0073e9SAndroid Build Coastguard Worker        other_layout = torch.strided if layout == torch.jagged else torch.jagged
5625*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
5626*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Converting between nested tensor layouts is not supported"
5627*da0073e9SAndroid Build Coastguard Worker        ):
5628*da0073e9SAndroid Build Coastguard Worker            torch.nested.as_nested_tensor(
5629*da0073e9SAndroid Build Coastguard Worker                nt, device=device, dtype=dtype, layout=other_layout
5630*da0073e9SAndroid Build Coastguard Worker            )
5631*da0073e9SAndroid Build Coastguard Worker
5632*da0073e9SAndroid Build Coastguard Worker        if requires_grad:
5633*da0073e9SAndroid Build Coastguard Worker            # make sure gradients flow back into inputs
5634*da0073e9SAndroid Build Coastguard Worker            (nt * 2).backward(torch.ones_like(nt))
5635*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(orig_t.grad, torch.ones_like(orig_t) * 2)
5636*da0073e9SAndroid Build Coastguard Worker
5637*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.half)
5638*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
5639*da0073e9SAndroid Build Coastguard Worker    def test_device_dtype_transfer_updates_offsets(self, device, dtype):
5640*da0073e9SAndroid Build Coastguard Worker        for tensor_list in self._get_example_tensor_lists():
5641*da0073e9SAndroid Build Coastguard Worker            orig_device = torch.device("cpu")
5642*da0073e9SAndroid Build Coastguard Worker            orig_dtype = torch.float32
5643*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor(
5644*da0073e9SAndroid Build Coastguard Worker                tensor_list, layout=torch.jagged, device=orig_device, dtype=orig_dtype
5645*da0073e9SAndroid Build Coastguard Worker            )
5646*da0073e9SAndroid Build Coastguard Worker
5647*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.int64, nt.offsets().dtype)
5648*da0073e9SAndroid Build Coastguard Worker            nt = nt.to(device=device).to(dtype=dtype)
5649*da0073e9SAndroid Build Coastguard Worker
5650*da0073e9SAndroid Build Coastguard Worker            # offsets should still be int64 on the new device
5651*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt.values().device, nt.offsets().device)
5652*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.int64, nt.offsets().dtype)
5653*da0073e9SAndroid Build Coastguard Worker
5654*da0073e9SAndroid Build Coastguard Worker    def test_unbind(self, device):
5655*da0073e9SAndroid Build Coastguard Worker        for tensor_list in self._get_example_tensor_lists():
5656*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor(
5657*da0073e9SAndroid Build Coastguard Worker                tensor_list, layout=torch.jagged, device=device
5658*da0073e9SAndroid Build Coastguard Worker            )  # ragged_idx = 1
5659*da0073e9SAndroid Build Coastguard Worker            out = nt.unbind()
5660*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(out), len(tensor_list))
5661*da0073e9SAndroid Build Coastguard Worker            for i, t in enumerate(out):
5662*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(t, tensor_list[i])
5663*da0073e9SAndroid Build Coastguard Worker
5664*da0073e9SAndroid Build Coastguard Worker    @parametrize("ragged_idx", [2, 3])
5665*da0073e9SAndroid Build Coastguard Worker    def test_unbind_transpose(self, device, ragged_idx):
5666*da0073e9SAndroid Build Coastguard Worker        for tensor_list in self._get_example_tensor_lists():
5667*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor(
5668*da0073e9SAndroid Build Coastguard Worker                tensor_list, layout=torch.jagged, device=device
5669*da0073e9SAndroid Build Coastguard Worker            )
5670*da0073e9SAndroid Build Coastguard Worker            if ragged_idx < nt.dim():
5671*da0073e9SAndroid Build Coastguard Worker                nt = nt.transpose(1, ragged_idx)  # set ragged_idx
5672*da0073e9SAndroid Build Coastguard Worker                out = nt.unbind()
5673*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(len(out), len(tensor_list))
5674*da0073e9SAndroid Build Coastguard Worker                for i, t in enumerate(out):
5675*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(
5676*da0073e9SAndroid Build Coastguard Worker                        t.transpose(0, ragged_idx - 1), tensor_list[i]
5677*da0073e9SAndroid Build Coastguard Worker                    )  # transpose back each element of result
5678*da0073e9SAndroid Build Coastguard Worker
5679*da0073e9SAndroid Build Coastguard Worker    def test_unbind_transpose_ragged_idx_last_dim(self, device):
5680*da0073e9SAndroid Build Coastguard Worker        for tensor_list in self._get_example_tensor_lists():
5681*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor(
5682*da0073e9SAndroid Build Coastguard Worker                tensor_list, layout=torch.jagged, device=device
5683*da0073e9SAndroid Build Coastguard Worker            ).transpose(1, -1)  # set ragged_idx = last dimension
5684*da0073e9SAndroid Build Coastguard Worker            out = nt.unbind()
5685*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(out), len(tensor_list))
5686*da0073e9SAndroid Build Coastguard Worker            for i, t in enumerate(out):
5687*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
5688*da0073e9SAndroid Build Coastguard Worker                    t.transpose(0, -1), tensor_list[i]
5689*da0073e9SAndroid Build Coastguard Worker                )  # transpose back each element of result
5690*da0073e9SAndroid Build Coastguard Worker
5691*da0073e9SAndroid Build Coastguard Worker    def test_unbind_lengths(self, device):
5692*da0073e9SAndroid Build Coastguard Worker        values = torch.randn(16, 128, device=device)
5693*da0073e9SAndroid Build Coastguard Worker        offsets = torch.tensor([0, 8, 12, 13, 16], device=device)
5694*da0073e9SAndroid Build Coastguard Worker        lengths = torch.tensor([6, 2, 1, 2], device=device)
5695*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor_from_jagged(
5696*da0073e9SAndroid Build Coastguard Worker            values, offsets=offsets, lengths=lengths
5697*da0073e9SAndroid Build Coastguard Worker        )  # 3D nested tensor
5698*da0073e9SAndroid Build Coastguard Worker
5699*da0073e9SAndroid Build Coastguard Worker        tensor_list = []
5700*da0073e9SAndroid Build Coastguard Worker        for i in range(offsets.shape[0] - 1):
5701*da0073e9SAndroid Build Coastguard Worker            tensor_list.append(values[offsets[i] : (offsets[i] + lengths[i])])
5702*da0073e9SAndroid Build Coastguard Worker
5703*da0073e9SAndroid Build Coastguard Worker        out = nt.unbind()
5704*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(out), len(tensor_list))
5705*da0073e9SAndroid Build Coastguard Worker        for i, t in enumerate(out):
5706*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t, tensor_list[i])
5707*da0073e9SAndroid Build Coastguard Worker
5708*da0073e9SAndroid Build Coastguard Worker    def test_unbind_lengths_ragged_idx_1(self, device):
5709*da0073e9SAndroid Build Coastguard Worker        values = torch.randn(16, 8, 128, device=device)
5710*da0073e9SAndroid Build Coastguard Worker        offsets = torch.tensor([0, 8, 12, 13, 16], device=device)
5711*da0073e9SAndroid Build Coastguard Worker        lengths = torch.tensor([6, 2, 1, 2], device=device)
5712*da0073e9SAndroid Build Coastguard Worker        ragged_idx = 1
5713*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested._internal.nested_tensor.NestedTensor(
5714*da0073e9SAndroid Build Coastguard Worker            values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx
5715*da0073e9SAndroid Build Coastguard Worker        )  # 4D nested tensor
5716*da0073e9SAndroid Build Coastguard Worker
5717*da0073e9SAndroid Build Coastguard Worker        tensor_list = []
5718*da0073e9SAndroid Build Coastguard Worker        for i in range(offsets.shape[0] - 1):
5719*da0073e9SAndroid Build Coastguard Worker            tensor_list.append(values[offsets[i] : (offsets[i] + lengths[i]), :, :])
5720*da0073e9SAndroid Build Coastguard Worker
5721*da0073e9SAndroid Build Coastguard Worker        out = nt.unbind()
5722*da0073e9SAndroid Build Coastguard Worker
5723*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(out), len(tensor_list))
5724*da0073e9SAndroid Build Coastguard Worker        for i, t in enumerate(out):
5725*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t, tensor_list[i])
5726*da0073e9SAndroid Build Coastguard Worker
5727*da0073e9SAndroid Build Coastguard Worker    def test_unbind_lengths_ragged_idx_equals_2_bad_dim(self, device):
5728*da0073e9SAndroid Build Coastguard Worker        values = torch.randn(16, 8, 128, device=device)
5729*da0073e9SAndroid Build Coastguard Worker        offsets = torch.tensor([0, 8, 12, 13, 16], device=device)
5730*da0073e9SAndroid Build Coastguard Worker        lengths = torch.tensor([6, 2, 1, 2], device=device)
5731*da0073e9SAndroid Build Coastguard Worker        ragged_idx = 2
5732*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested._internal.nested_tensor.NestedTensor(
5733*da0073e9SAndroid Build Coastguard Worker            values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx
5734*da0073e9SAndroid Build Coastguard Worker        )  # 4D nested tensor
5735*da0073e9SAndroid Build Coastguard Worker
5736*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
5737*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
5738*da0073e9SAndroid Build Coastguard Worker            r"unbind\(\): nested tensor offsets and lengths.*",
5739*da0073e9SAndroid Build Coastguard Worker            lambda: nt.unbind(),
5740*da0073e9SAndroid Build Coastguard Worker        )
5741*da0073e9SAndroid Build Coastguard Worker
5742*da0073e9SAndroid Build Coastguard Worker    def test_unbind_lengths_ragged_idx_2(self, device):
5743*da0073e9SAndroid Build Coastguard Worker        values = torch.randn(16, 8, 128, device=device)
5744*da0073e9SAndroid Build Coastguard Worker        offsets = torch.tensor([0, 2, 4, 8], device=device)
5745*da0073e9SAndroid Build Coastguard Worker        lengths = torch.tensor([2, 1, 3], device=device)
5746*da0073e9SAndroid Build Coastguard Worker        ragged_idx = 2
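        # with _ragged_idx=2, offsets / lengths index into dim 1 of values
        # (see the reference slicing below)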
5747*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested._internal.nested_tensor.NestedTensor(
5748*da0073e9SAndroid Build Coastguard Worker            values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx
5749*da0073e9SAndroid Build Coastguard Worker        )  # 4D nested tensor
5750*da0073e9SAndroid Build Coastguard Worker
5751*da0073e9SAndroid Build Coastguard Worker        tensor_list = []
5752*da0073e9SAndroid Build Coastguard Worker        for i in range(offsets.shape[0] - 1):
5753*da0073e9SAndroid Build Coastguard Worker            tensor_list.append(values[:, offsets[i] : (offsets[i] + lengths[i]), :])
5754*da0073e9SAndroid Build Coastguard Worker
5755*da0073e9SAndroid Build Coastguard Worker        out = nt.unbind()
5756*da0073e9SAndroid Build Coastguard Worker
5757*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(out), len(tensor_list))
5758*da0073e9SAndroid Build Coastguard Worker        for i, t in enumerate(out):
5759*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t, tensor_list[i])
5760*da0073e9SAndroid Build Coastguard Worker
5761*da0073e9SAndroid Build Coastguard Worker    def test_unbind_lengths_ragged_idx_3(self, device):
5762*da0073e9SAndroid Build Coastguard Worker        values = torch.randn(16, 8, 128, device=device)
5763*da0073e9SAndroid Build Coastguard Worker        offsets = torch.tensor([0, 100, 128], device=device)
5764*da0073e9SAndroid Build Coastguard Worker        lengths = torch.tensor([50, 28], device=device)
5765*da0073e9SAndroid Build Coastguard Worker        ragged_idx = 3
5766*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested._internal.nested_tensor.NestedTensor(
5767*da0073e9SAndroid Build Coastguard Worker            values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx
5768*da0073e9SAndroid Build Coastguard Worker        )  # 4D nested tensor
5769*da0073e9SAndroid Build Coastguard Worker
5770*da0073e9SAndroid Build Coastguard Worker        tensor_list = []
5771*da0073e9SAndroid Build Coastguard Worker        for i in range(offsets.shape[0] - 1):
5772*da0073e9SAndroid Build Coastguard Worker            tensor_list.append(values[:, :, offsets[i] : (offsets[i] + lengths[i])])
5773*da0073e9SAndroid Build Coastguard Worker
5774*da0073e9SAndroid Build Coastguard Worker        out = nt.unbind()
5775*da0073e9SAndroid Build Coastguard Worker
5776*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(out), len(tensor_list))
5777*da0073e9SAndroid Build Coastguard Worker        for i, t in enumerate(out):
5778*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t, tensor_list[i])
5779*da0073e9SAndroid Build Coastguard Worker
5780*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo(
5781*da0073e9SAndroid Build Coastguard Worker        "TorchDynamo raises an error for ragged_idx == 0 earlier than Torch"
5782*da0073e9SAndroid Build Coastguard Worker    )
5783*da0073e9SAndroid Build Coastguard Worker    def test_unbind_lengths_ragged_idx_0(self, device):
5784*da0073e9SAndroid Build Coastguard Worker        values = torch.randn(16, 8, 128, device=device)
5785*da0073e9SAndroid Build Coastguard Worker        offsets = torch.tensor([0, 100, 128], device=device)
5786*da0073e9SAndroid Build Coastguard Worker        lengths = torch.tensor([50, 28], device=device)
5787*da0073e9SAndroid Build Coastguard Worker        ragged_idx = 0
5788*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested._internal.nested_tensor.NestedTensor(
5789*da0073e9SAndroid Build Coastguard Worker            values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx
5790*da0073e9SAndroid Build Coastguard Worker        )  # 4D nested tensor
5791*da0073e9SAndroid Build Coastguard Worker
5792*da0073e9SAndroid Build Coastguard Worker        tensor_list = []
5793*da0073e9SAndroid Build Coastguard Worker        for i in range(offsets.shape[0] - 1):
5794*da0073e9SAndroid Build Coastguard Worker            tensor_list.append(values[:, :, offsets[i] : (offsets[i] + lengths[i])])
5795*da0073e9SAndroid Build Coastguard Worker
5796*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
5797*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
5798*da0073e9SAndroid Build Coastguard Worker            r"unbind\(\): nested tensor.*out of bounds",
5799*da0073e9SAndroid Build Coastguard Worker            lambda: nt.unbind(),
5800*da0073e9SAndroid Build Coastguard Worker        )
5801*da0073e9SAndroid Build Coastguard Worker
5802*da0073e9SAndroid Build Coastguard Worker    def test_narrow(self, device):
5803*da0073e9SAndroid Build Coastguard Worker        starts = torch.tensor([0, 1, 2, 3, 4], device=device, dtype=torch.int64)
5804*da0073e9SAndroid Build Coastguard Worker        lengths = torch.tensor([3, 2, 2, 1, 5], device=device, dtype=torch.int64)
5805*da0073e9SAndroid Build Coastguard Worker        buffer = (
5806*da0073e9SAndroid Build Coastguard Worker            torch.arange(0, 10, device=device, dtype=torch.int64)
5807*da0073e9SAndroid Build Coastguard Worker            .unsqueeze(0)
5808*da0073e9SAndroid Build Coastguard Worker            .expand(5, -1)
5809*da0073e9SAndroid Build Coastguard Worker            .clone()
5810*da0073e9SAndroid Build Coastguard Worker            .detach()
5811*da0073e9SAndroid Build Coastguard Worker        )
5812*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.narrow(buffer, 1, starts, lengths, layout=torch.jagged)
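        # each component i of the result should view buffer[i, starts[i] : starts[i] + lengths[i]]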
5813*da0073e9SAndroid Build Coastguard Worker
5814*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(nt._is_view() and nt._base is buffer)
5815*da0073e9SAndroid Build Coastguard Worker
5816*da0073e9SAndroid Build Coastguard Worker        # TODO: Use this approach when unbind is functional
5817*da0073e9SAndroid Build Coastguard Worker        # unbinded_nt = nt.unbind()
5818*da0073e9SAndroid Build Coastguard Worker        # for i in range(starts.shape[0]):
5819*da0073e9SAndroid Build Coastguard Worker        #     self.assertEqual(torch.arange(starts[i], starts[i] + lengths[i], device=device, dtype=torch.int64), unbinded_nt[i])
5820*da0073e9SAndroid Build Coastguard Worker        for i in range(starts.shape[0]):
5821*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
5822*da0073e9SAndroid Build Coastguard Worker                torch.arange(
5823*da0073e9SAndroid Build Coastguard Worker                    starts[i], starts[i] + lengths[i], device=device, dtype=torch.int64
5824*da0073e9SAndroid Build Coastguard Worker                ),
5825*da0073e9SAndroid Build Coastguard Worker                nt.values()[nt.offsets()[i] : (nt.offsets()[i] + nt.lengths()[i])],
5826*da0073e9SAndroid Build Coastguard Worker            )
5827*da0073e9SAndroid Build Coastguard Worker
5828*da0073e9SAndroid Build Coastguard Worker    def test_njt_cat(self, device):
5829*da0073e9SAndroid Build Coastguard Worker        offsets = torch.tensor([0, 2, 3], device=device, dtype=torch.int64)
5830*da0073e9SAndroid Build Coastguard Worker        values_1 = torch.randn(
5831*da0073e9SAndroid Build Coastguard Worker            3, 2, dtype=torch.float64, device=device, requires_grad=True
5832*da0073e9SAndroid Build Coastguard Worker        )
5833*da0073e9SAndroid Build Coastguard Worker        values_2 = torch.randn(
5834*da0073e9SAndroid Build Coastguard Worker            3, 4, dtype=torch.float64, device=device, requires_grad=True
5835*da0073e9SAndroid Build Coastguard Worker        )
5836*da0073e9SAndroid Build Coastguard Worker
5837*da0073e9SAndroid Build Coastguard Worker        def grad_test_func(values_1, values_2, offsets):
5838*da0073e9SAndroid Build Coastguard Worker            nt_1 = torch.nested.nested_tensor_from_jagged(values_1, offsets)
5839*da0073e9SAndroid Build Coastguard Worker            nt_2 = torch.nested.nested_tensor_from_jagged(values_2, offsets)
5840*da0073e9SAndroid Build Coastguard Worker            nt_3 = torch.cat([nt_1, nt_2], dim=-1)
5841*da0073e9SAndroid Build Coastguard Worker            return nt_3.values()
5842*da0073e9SAndroid Build Coastguard Worker
5843*da0073e9SAndroid Build Coastguard Worker        assert gradcheck(
5844*da0073e9SAndroid Build Coastguard Worker            grad_test_func,
5845*da0073e9SAndroid Build Coastguard Worker            inputs=(values_1, values_2, offsets),
5846*da0073e9SAndroid Build Coastguard Worker            check_batched_grad=False,
5847*da0073e9SAndroid Build Coastguard Worker        )
5848*da0073e9SAndroid Build Coastguard Worker
5849*da0073e9SAndroid Build Coastguard Worker    def test_is_contiguous(self, device):
5850*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
5851*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
5852*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
5853*da0073e9SAndroid Build Coastguard Worker        nt_contiguous = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
5854*da0073e9SAndroid Build Coastguard Worker
5855*da0073e9SAndroid Build Coastguard Worker        starts_nc = torch.tensor([0, 1, 2, 3, 4], device=device, dtype=torch.int64)
5856*da0073e9SAndroid Build Coastguard Worker        lengths_nc = torch.tensor([3, 2, 2, 1, 5], device=device, dtype=torch.int64)
5857*da0073e9SAndroid Build Coastguard Worker        narrow_base = (
5858*da0073e9SAndroid Build Coastguard Worker            torch.arange(0, 10, device=device, dtype=torch.int64)
5859*da0073e9SAndroid Build Coastguard Worker            .unsqueeze(0)
5860*da0073e9SAndroid Build Coastguard Worker            .expand(5, -1)
5861*da0073e9SAndroid Build Coastguard Worker            .clone()
5862*da0073e9SAndroid Build Coastguard Worker        )
5863*da0073e9SAndroid Build Coastguard Worker        nt_noncontiguous = torch.nested.narrow(
5864*da0073e9SAndroid Build Coastguard Worker            narrow_base, 1, starts_nc, lengths_nc, layout=torch.jagged
5865*da0073e9SAndroid Build Coastguard Worker        )
5866*da0073e9SAndroid Build Coastguard Worker
5867*da0073e9SAndroid Build Coastguard Worker        starts_c = torch.tensor([1, 0, 0, 0, 0], device=device, dtype=torch.int64)
5868*da0073e9SAndroid Build Coastguard Worker        lengths_c = torch.tensor([9, 10, 10, 10, 8], device=device, dtype=torch.int64)
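        # here each component ends exactly where the next one begins in the underlying
        # buffer, so the narrowed NT should still report as contiguous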
5869*da0073e9SAndroid Build Coastguard Worker        nt_contiguous_narrow = torch.nested.narrow(
5870*da0073e9SAndroid Build Coastguard Worker            narrow_base, 1, starts_c, lengths_c, layout=torch.jagged
5871*da0073e9SAndroid Build Coastguard Worker        )
5872*da0073e9SAndroid Build Coastguard Worker
5873*da0073e9SAndroid Build Coastguard Worker        # Test contiguous case
5874*da0073e9SAndroid Build Coastguard Worker        assert nt_contiguous.is_contiguous()
5875*da0073e9SAndroid Build Coastguard Worker
5876*da0073e9SAndroid Build Coastguard Worker        # Test narrow case
5877*da0073e9SAndroid Build Coastguard Worker        assert not nt_noncontiguous.is_contiguous()
5878*da0073e9SAndroid Build Coastguard Worker        assert nt_contiguous_narrow.is_contiguous()
5879*da0073e9SAndroid Build Coastguard Worker
5880*da0073e9SAndroid Build Coastguard Worker        # Test querying by memory_format
5881*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
5882*da0073e9SAndroid Build Coastguard Worker            nt_contiguous.is_contiguous(memory_format=torch.contiguous_format)
5883*da0073e9SAndroid Build Coastguard Worker        )
5884*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
5885*da0073e9SAndroid Build Coastguard Worker            not nt_noncontiguous.is_contiguous(memory_format=torch.contiguous_format)
5886*da0073e9SAndroid Build Coastguard Worker        )
5887*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
5888*da0073e9SAndroid Build Coastguard Worker            nt_contiguous_narrow.is_contiguous(memory_format=torch.contiguous_format)
5889*da0073e9SAndroid Build Coastguard Worker        )
5890*da0073e9SAndroid Build Coastguard Worker
5891*da0073e9SAndroid Build Coastguard Worker    def test_layout_under_torch_dispatch_mode(self):
5892*da0073e9SAndroid Build Coastguard Worker        from torch.testing._internal.logging_tensor import (
5893*da0073e9SAndroid Build Coastguard Worker            capture_logs_with_logging_tensor_mode,
5894*da0073e9SAndroid Build Coastguard Worker        )
5895*da0073e9SAndroid Build Coastguard Worker
5896*da0073e9SAndroid Build Coastguard Worker        nt = random_nt_from_dims(
5897*da0073e9SAndroid Build Coastguard Worker            [2, None, 3], torch.device("cpu"), torch.float32, layout=torch.jagged
5898*da0073e9SAndroid Build Coastguard Worker        )
5899*da0073e9SAndroid Build Coastguard Worker
5900*da0073e9SAndroid Build Coastguard Worker        with capture_logs_with_logging_tensor_mode():
5901*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt.layout, torch.jagged)
5902*da0073e9SAndroid Build Coastguard Worker
5903*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
5904*da0073e9SAndroid Build Coastguard Worker    @parametrize(
5905*da0073e9SAndroid Build Coastguard Worker        "func", [torch.empty_like, torch.randn_like], name_fn=lambda f: f.__name__
5906*da0073e9SAndroid Build Coastguard Worker    )
5907*da0073e9SAndroid Build Coastguard Worker    def test_like_shape(self, func):
5908*da0073e9SAndroid Build Coastguard Worker        nt = random_nt_from_dims(
5909*da0073e9SAndroid Build Coastguard Worker            [2, None, 3], torch.device("cpu"), torch.float32, layout=torch.jagged
5910*da0073e9SAndroid Build Coastguard Worker        )
5911*da0073e9SAndroid Build Coastguard Worker        nt_like = func(nt)
5912*da0073e9SAndroid Build Coastguard Worker
5913*da0073e9SAndroid Build Coastguard Worker        for nt_ub in nt_like.unbind():
5914*da0073e9SAndroid Build Coastguard Worker            t_like = func(nt_ub)
5915*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt_ub.shape, t_like.shape)
5916*da0073e9SAndroid Build Coastguard Worker
5917*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
5918*da0073e9SAndroid Build Coastguard Worker    @parametrize(
5919*da0073e9SAndroid Build Coastguard Worker        "func", [torch.ones_like, torch.zeros_like], name_fn=lambda f: f.__name__
5920*da0073e9SAndroid Build Coastguard Worker    )
5921*da0073e9SAndroid Build Coastguard Worker    def test_like_value(self, func):
5922*da0073e9SAndroid Build Coastguard Worker        nt = random_nt_from_dims(
5923*da0073e9SAndroid Build Coastguard Worker            [2, None, 3], torch.device("cpu"), torch.float32, layout=torch.jagged
5924*da0073e9SAndroid Build Coastguard Worker        )
5925*da0073e9SAndroid Build Coastguard Worker        nt_like = func(nt)
5926*da0073e9SAndroid Build Coastguard Worker
5927*da0073e9SAndroid Build Coastguard Worker        for nt_ub in nt_like.unbind():
5928*da0073e9SAndroid Build Coastguard Worker            t_like = func(nt_ub)
5929*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt_ub, t_like)
5930*da0073e9SAndroid Build Coastguard Worker
5931*da0073e9SAndroid Build Coastguard Worker    def test_noncontiguous_pointwise(self, device):
5932*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 3, 4, requires_grad=True, dtype=torch.float64, device=device)
5933*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(3, 3, 4, requires_grad=True, dtype=torch.float64, device=device)
5934*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(4, 3, 4, requires_grad=True, dtype=torch.float64, device=device)
5935*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor([a, b, c], layout=torch.jagged)
5936*da0073e9SAndroid Build Coastguard Worker        # transpose ragged dim
5937*da0073e9SAndroid Build Coastguard Worker        transposed = nt.transpose(1, 2)
5938*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(transposed.is_contiguous())
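        # clone() with the default preserve_format memory format keeps the
        # non-contiguous layout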
5939*da0073e9SAndroid Build Coastguard Worker        clone = transposed.clone()
5940*da0073e9SAndroid Build Coastguard Worker
5941*da0073e9SAndroid Build Coastguard Worker        def check_nt_equality(x, y):
5942*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.values(), y.values())
5943*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.offsets(), y.offsets())
5944*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x._ragged_idx, y._ragged_idx)
5945*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.shape, y.shape)
5946*da0073e9SAndroid Build Coastguard Worker
5947*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(clone.is_contiguous())
5948*da0073e9SAndroid Build Coastguard Worker        check_nt_equality(clone, transposed)
5949*da0073e9SAndroid Build Coastguard Worker
5950*da0073e9SAndroid Build Coastguard Worker        clone_contig = transposed.clone(memory_format=torch.contiguous_format)
5951*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(clone_contig.is_contiguous())
5952*da0073e9SAndroid Build Coastguard Worker        check_nt_equality(clone_contig, transposed)
5953*da0073e9SAndroid Build Coastguard Worker
5954*da0073e9SAndroid Build Coastguard Worker        detached = transposed.detach()
5955*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(detached.is_contiguous())
5956*da0073e9SAndroid Build Coastguard Worker        check_nt_equality(detached, transposed)
5957*da0073e9SAndroid Build Coastguard Worker
5958*da0073e9SAndroid Build Coastguard Worker    def test_permute(self, device):
5959*da0073e9SAndroid Build Coastguard Worker        nt = random_nt_from_dims(
5960*da0073e9SAndroid Build Coastguard Worker            [2, None, 3, 5], device, torch.float32, layout=torch.jagged
5961*da0073e9SAndroid Build Coastguard Worker        )
5962*da0073e9SAndroid Build Coastguard Worker        nt_shape = nt.shape
5963*da0073e9SAndroid Build Coastguard Worker        nt_inner_shape = nt.values().shape
5964*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
5965*da0073e9SAndroid Build Coastguard Worker            ValueError,
5966*da0073e9SAndroid Build Coastguard Worker            r"permute\(\): number of dimensions in the tensor input \(4\) "
5967*da0073e9SAndroid Build Coastguard Worker            + r"does not match the length of the desired ordering of dimensions \(3\).",
5968*da0073e9SAndroid Build Coastguard Worker        ):
5969*da0073e9SAndroid Build Coastguard Worker            nt.permute(0, 2, 1)
5970*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
5971*da0073e9SAndroid Build Coastguard Worker            ValueError, r"permute\(\): duplicate dims are not allowed."
5972*da0073e9SAndroid Build Coastguard Worker        ):
5973*da0073e9SAndroid Build Coastguard Worker            nt.permute(0, 2, -2, 3)
5974*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
5975*da0073e9SAndroid Build Coastguard Worker            ValueError, "Permute is not supported on the batch dimension for jagged NT"
5976*da0073e9SAndroid Build Coastguard Worker        ):
5977*da0073e9SAndroid Build Coastguard Worker            nt.permute(1, 0, 2, 3)
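        # moving the ragged dim (1) to position 2 is allowed and should update
        # _ragged_idx accordingly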
5978*da0073e9SAndroid Build Coastguard Worker        nt_permute = nt.permute(0, 2, 1, -1)
5979*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5980*da0073e9SAndroid Build Coastguard Worker            nt_permute.shape, (nt_shape[0], nt_shape[2], nt_shape[1], nt_shape[3])
5981*da0073e9SAndroid Build Coastguard Worker        )
5982*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5983*da0073e9SAndroid Build Coastguard Worker            nt_permute.values().shape,
5984*da0073e9SAndroid Build Coastguard Worker            (nt_inner_shape[1], nt_inner_shape[0], nt_inner_shape[2]),
5985*da0073e9SAndroid Build Coastguard Worker        )
5986*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt_permute._ragged_idx, 2)
5987*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt_permute.permute(0, 2, 1, 3), nt)
5988*da0073e9SAndroid Build Coastguard Worker
5989*da0073e9SAndroid Build Coastguard Worker    def test_to_dtype(self, device):
5990*da0073e9SAndroid Build Coastguard Worker        nt = random_nt_from_dims(
5991*da0073e9SAndroid Build Coastguard Worker            [2, None, 3], device, torch.float32, layout=torch.jagged
5992*da0073e9SAndroid Build Coastguard Worker        )
5993*da0073e9SAndroid Build Coastguard Worker        nt_after = nt.to(torch.float64)
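        # dtype conversion should affect values only; offsets stay int64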
5994*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.float32, nt.dtype)
5995*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.float64, nt_after.dtype)
5996*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.float64, nt_after.values().dtype)
5997*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.int64, nt_after.offsets().dtype)
5998*da0073e9SAndroid Build Coastguard Worker
5999*da0073e9SAndroid Build Coastguard Worker        noncontiguous_nt = nt.transpose(1, 2)
6000*da0073e9SAndroid Build Coastguard Worker        noncontiguous_nt_after = noncontiguous_nt.to(torch.bfloat16)
6001*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.bfloat16, noncontiguous_nt_after.dtype)
6002*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.bfloat16, noncontiguous_nt_after.values().dtype)
6003*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.int64, noncontiguous_nt_after.offsets().dtype)
6004*da0073e9SAndroid Build Coastguard Worker
6005*da0073e9SAndroid Build Coastguard Worker    def test_to_copy(self, device):
6006*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor(
6007*da0073e9SAndroid Build Coastguard Worker            [
6008*da0073e9SAndroid Build Coastguard Worker                torch.randn(
6009*da0073e9SAndroid Build Coastguard Worker                    i + 2, 3, 4, requires_grad=True, dtype=torch.float64, device=device
6010*da0073e9SAndroid Build Coastguard Worker                )
6011*da0073e9SAndroid Build Coastguard Worker                for i in range(3)
6012*da0073e9SAndroid Build Coastguard Worker            ],
6013*da0073e9SAndroid Build Coastguard Worker            layout=torch.jagged,
6014*da0073e9SAndroid Build Coastguard Worker        )
6015*da0073e9SAndroid Build Coastguard Worker
6016*da0073e9SAndroid Build Coastguard Worker        nt_copy_dtype = torch.ops.aten._to_copy(nt, dtype=torch.float16)
6017*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.float16, nt_copy_dtype.dtype)
6018*da0073e9SAndroid Build Coastguard Worker
6019*da0073e9SAndroid Build Coastguard Worker        nt_t = nt.transpose(1, 2)
6020*da0073e9SAndroid Build Coastguard Worker        nt_t_copy_dtype = torch.ops.aten._to_copy(nt_t, dtype=torch.float16)
6021*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.float16, nt_t_copy_dtype.dtype)
6022*da0073e9SAndroid Build Coastguard Worker
6023*da0073e9SAndroid Build Coastguard Worker    def test_copy_(self, device):
6024*da0073e9SAndroid Build Coastguard Worker        offsets = torch.tensor([0, 2, 4], device=device)
6025*da0073e9SAndroid Build Coastguard Worker        a = torch.nested.nested_tensor_from_jagged(
6026*da0073e9SAndroid Build Coastguard Worker            torch.zeros(4, 3, device=device), offsets
6027*da0073e9SAndroid Build Coastguard Worker        )
6028*da0073e9SAndroid Build Coastguard Worker        b = torch.nested.nested_tensor_from_jagged(
6029*da0073e9SAndroid Build Coastguard Worker            torch.ones(4, 3, device=device), offsets
6030*da0073e9SAndroid Build Coastguard Worker        )
6031*da0073e9SAndroid Build Coastguard Worker        a.copy_(b)
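        # run the equality check with Dynamo disabled so the comparison itself isn't traced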
6032*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.disable(self.assertEqual)(a, b)
6033*da0073e9SAndroid Build Coastguard Worker
6034*da0073e9SAndroid Build Coastguard Worker        offsets_2 = torch.tensor([0, 2, 4], device=device)
6035*da0073e9SAndroid Build Coastguard Worker        c = torch.nested.nested_tensor_from_jagged(
6036*da0073e9SAndroid Build Coastguard Worker            torch.ones(4, 3, device=device), offsets_2
6037*da0073e9SAndroid Build Coastguard Worker        )
6038*da0073e9SAndroid Build Coastguard Worker        # fail when tensors have the same size but not the exact same offset tensor.
6039*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
6040*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
6041*da0073e9SAndroid Build Coastguard Worker            "copy_ only supports Nested Tensors that have same size and the exact same offset tensor.",
6042*da0073e9SAndroid Build Coastguard Worker        ):
6043*da0073e9SAndroid Build Coastguard Worker            a.copy_(c)
6044*da0073e9SAndroid Build Coastguard Worker
6045*da0073e9SAndroid Build Coastguard Worker        # fail when tensors have different sizes
6046*da0073e9SAndroid Build Coastguard Worker        a = a.transpose(1, 2)
6047*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
6048*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
6049*da0073e9SAndroid Build Coastguard Worker            "copy_ only supports Nested Tensors that have same size and the exact same offset tensor.",
6050*da0073e9SAndroid Build Coastguard Worker        ):
6051*da0073e9SAndroid Build Coastguard Worker            a.copy_(b)
6052*da0073e9SAndroid Build Coastguard Worker
6053*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Dynamo doesn't know how to trace prof.events()")
6054*da0073e9SAndroid Build Coastguard Worker    def test_profiler_sequence_nr(self):
6055*da0073e9SAndroid Build Coastguard Worker        with torch.profiler.profile() as prof:
6056*da0073e9SAndroid Build Coastguard Worker            values = torch.randn(4, 6, requires_grad=True)
6057*da0073e9SAndroid Build Coastguard Worker            offsets = torch.tensor([0, 2, 4])
6058*da0073e9SAndroid Build Coastguard Worker            values = values * 2
6059*da0073e9SAndroid Build Coastguard Worker            l = torch.nn.Linear(6, 8)
6060*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
6061*da0073e9SAndroid Build Coastguard Worker
6062*da0073e9SAndroid Build Coastguard Worker            nt = l(nt)
6063*da0073e9SAndroid Build Coastguard Worker            val = nt.values()
6064*da0073e9SAndroid Build Coastguard Worker
6065*da0073e9SAndroid Build Coastguard Worker            loss = val.sum()
6066*da0073e9SAndroid Build Coastguard Worker            loss.backward()
6067*da0073e9SAndroid Build Coastguard Worker
6068*da0073e9SAndroid Build Coastguard Worker        fwd_seq_nrs = []
6069*da0073e9SAndroid Build Coastguard Worker        for evt in prof.events():
6070*da0073e9SAndroid Build Coastguard Worker            if (
6071*da0073e9SAndroid Build Coastguard Worker                "linear" in evt.name.lower()
6072*da0073e9SAndroid Build Coastguard Worker                and "backward" not in evt.name.lower()
6073*da0073e9SAndroid Build Coastguard Worker                and evt.sequence_nr != -1
6074*da0073e9SAndroid Build Coastguard Worker            ):
6075*da0073e9SAndroid Build Coastguard Worker                fwd_seq_nrs.append(evt.sequence_nr)
6076*da0073e9SAndroid Build Coastguard Worker
6077*da0073e9SAndroid Build Coastguard Worker        bwd_seq_nrs = []
6078*da0073e9SAndroid Build Coastguard Worker        for evt in prof.events():
6079*da0073e9SAndroid Build Coastguard Worker            if (
6080*da0073e9SAndroid Build Coastguard Worker                "linear" in evt.name.lower()
6081*da0073e9SAndroid Build Coastguard Worker                and "backward" in evt.name.lower()
6082*da0073e9SAndroid Build Coastguard Worker                and "evaluate_function" not in evt.name.lower()
6083*da0073e9SAndroid Build Coastguard Worker                and evt.sequence_nr != -1
6084*da0073e9SAndroid Build Coastguard Worker            ):
6085*da0073e9SAndroid Build Coastguard Worker                bwd_seq_nrs.append(evt.sequence_nr)
6086*da0073e9SAndroid Build Coastguard Worker
6087*da0073e9SAndroid Build Coastguard Worker        # There should only be one such event with a sequence number:
6088*da0073e9SAndroid Build Coastguard Worker        # the PythonTLSSnapshot event - but, note that it's not terrible if
6089*da0073e9SAndroid Build Coastguard Worker        # we end up with multiple events with the same sequence number - so we
6090*da0073e9SAndroid Build Coastguard Worker        # could relax this check if it becomes inconvenient to maintain this
6091*da0073e9SAndroid Build Coastguard Worker        # property.
6092*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(fwd_seq_nrs), 1)
6093*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(bwd_seq_nrs), 1)
6094*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fwd_seq_nrs[0], bwd_seq_nrs[0])
6095*da0073e9SAndroid Build Coastguard Worker
6096*da0073e9SAndroid Build Coastguard Worker    def test_is_same_size(self, device):
6097*da0073e9SAndroid Build Coastguard Worker        def get_3_tensors():
6098*da0073e9SAndroid Build Coastguard Worker            return [
6099*da0073e9SAndroid Build Coastguard Worker                torch.randn(
6100*da0073e9SAndroid Build Coastguard Worker                    i + 2, 3, 4, requires_grad=True, dtype=torch.float64, device=device
6101*da0073e9SAndroid Build Coastguard Worker                )
6102*da0073e9SAndroid Build Coastguard Worker                for i in range(3)
6103*da0073e9SAndroid Build Coastguard Worker            ]
6104*da0073e9SAndroid Build Coastguard Worker
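        # nt1/nt2 share one offsets tensor and nt3/nt4 share another, so is_same_size
        # should hold within each pair but not across the pairs.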
6105*da0073e9SAndroid Build Coastguard Worker        nt1, offsets1 = jagged_from_list(get_3_tensors(), None)
6106*da0073e9SAndroid Build Coastguard Worker        nt2, offsets1 = jagged_from_list(get_3_tensors(), offsets1)
6107*da0073e9SAndroid Build Coastguard Worker
6108*da0073e9SAndroid Build Coastguard Worker        nt3, offsets2 = jagged_from_list(get_3_tensors(), None)
6109*da0073e9SAndroid Build Coastguard Worker        nt4, offsets2 = jagged_from_list(get_3_tensors(), offsets2)
6110*da0073e9SAndroid Build Coastguard Worker
6111*da0073e9SAndroid Build Coastguard Worker        def check_size(nt1, nt2, nt3, nt4):
6112*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.ops.aten.is_same_size(nt1, nt2))
6113*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.ops.aten.is_same_size(nt3, nt4))
6114*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.ops.aten.is_same_size(nt1, nt3))
6115*da0073e9SAndroid Build Coastguard Worker
6116*da0073e9SAndroid Build Coastguard Worker        check_size(nt1, nt2, nt3, nt4)
6117*da0073e9SAndroid Build Coastguard Worker
6118*da0073e9SAndroid Build Coastguard Worker        nt1_t, nt2_t, nt3_t, nt4_t = (x.transpose(1, 2) for x in (nt1, nt2, nt3, nt4))
6119*da0073e9SAndroid Build Coastguard Worker        check_size(nt1_t, nt2_t, nt3_t, nt4_t)
6120*da0073e9SAndroid Build Coastguard Worker
6121*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("compiles internally")
6122*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
6123*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
6124*da0073e9SAndroid Build Coastguard Worker    def test_specialize_dynamic_shape(self, device):
6125*da0073e9SAndroid Build Coastguard Worker        values = torch.randn((18, 16), device=device)
6126*da0073e9SAndroid Build Coastguard Worker        offsets = torch.tensor([0, 2, 3, 6, 15, 18], device=device)
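        # offsets [0, 2, 3, 6, 15, 18] describe 5 jagged sequences of lengths 2, 1, 3, 9, 3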
6127*da0073e9SAndroid Build Coastguard Worker        like_values = torch.randn_like(values)
6128*da0073e9SAndroid Build Coastguard Worker
6129*da0073e9SAndroid Build Coastguard Worker        # this marks values as dynamic
6130*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor_from_jagged(values, offsets)
6131*da0073e9SAndroid Build Coastguard Worker
6132*da0073e9SAndroid Build Coastguard Worker        def fn(values, same_size):
6133*da0073e9SAndroid Build Coastguard Worker            # here, the dynamic shape is specialized by same_size's shape
6134*da0073e9SAndroid Build Coastguard Worker            # https://github.com/pytorch/pytorch/issues/127097
6135*da0073e9SAndroid Build Coastguard Worker            # make sure this doesn't error out in torch.compile
6136*da0073e9SAndroid Build Coastguard Worker            return values + same_size
6137*da0073e9SAndroid Build Coastguard Worker
6138*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
6139*da0073e9SAndroid Build Coastguard Worker            fn(values, like_values),
6140*da0073e9SAndroid Build Coastguard Worker            torch.compile(fn)(values, like_values),
6141*da0073e9SAndroid Build Coastguard Worker        )
6142*da0073e9SAndroid Build Coastguard Worker
6143*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("compiles internally")
6144*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
6145*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
6146*da0073e9SAndroid Build Coastguard Worker    def test_specialize_dynamic_shape_recompile(self, device):
6147*da0073e9SAndroid Build Coastguard Worker        def generate_inp(total_len):
6148*da0073e9SAndroid Build Coastguard Worker            values = torch.randn((total_len, 16), device=device)
6149*da0073e9SAndroid Build Coastguard Worker            offsets = torch.tensor([0, 2, 3, 6, 15, total_len], device=device)
6150*da0073e9SAndroid Build Coastguard Worker            like_values = torch.randn_like(values)
6151*da0073e9SAndroid Build Coastguard Worker            return values, offsets, like_values
6152*da0073e9SAndroid Build Coastguard Worker
6153*da0073e9SAndroid Build Coastguard Worker        def check_results(ref_fn, res_fn, args):
6154*da0073e9SAndroid Build Coastguard Worker            values, offsets, like_values = args
6155*da0073e9SAndroid Build Coastguard Worker            # constructing the NT may mark these tensors with dynamic shapes;
6156*da0073e9SAndroid Build Coastguard Worker            # the goal of this test is to make sure that, whatever markings are there,
6157*da0073e9SAndroid Build Coastguard Worker            # we eventually stop recompiling as the shape changes.
6158*da0073e9SAndroid Build Coastguard Worker            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
6159*da0073e9SAndroid Build Coastguard Worker
6160*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ref_fn(values, like_values), res_fn(values, like_values))
6161*da0073e9SAndroid Build Coastguard Worker
6162*da0073e9SAndroid Build Coastguard Worker        def fn(values, same_size):
6163*da0073e9SAndroid Build Coastguard Worker            return values + same_size
6164*da0073e9SAndroid Build Coastguard Worker
6165*da0073e9SAndroid Build Coastguard Worker        compile_counter = torch._dynamo.testing.CompileCounter()
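        # CompileCounter.frame_count records how many times Dynamo compiles fn; the
        # assertions below check that it stabilizes after at most one dynamic-shape recompile.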
6166*da0073e9SAndroid Build Coastguard Worker
6167*da0073e9SAndroid Build Coastguard Worker        compiled_fn = torch._dynamo.optimize(compile_counter, nopython=True)(fn)
6168*da0073e9SAndroid Build Coastguard Worker        check_results(fn, compiled_fn, generate_inp(18))
6169*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(compile_counter.frame_count, 1)
6170*da0073e9SAndroid Build Coastguard Worker
6171*da0073e9SAndroid Build Coastguard Worker        check_results(fn, compiled_fn, generate_inp(19))
6172*da0073e9SAndroid Build Coastguard Worker        # we'll probably recompile here with dynamic shapes - it's okay if not though.
6173*da0073e9SAndroid Build Coastguard Worker        frame_count_2 = compile_counter.frame_count
6174*da0073e9SAndroid Build Coastguard Worker        self.assertIn(frame_count_2, [1, 2])
6175*da0073e9SAndroid Build Coastguard Worker
6176*da0073e9SAndroid Build Coastguard Worker        # make sure that by now we've already compiled with dynamic shapes, so additional
6177*da0073e9SAndroid Build Coastguard Worker        # shapes should not trigger additional recompiles.
6178*da0073e9SAndroid Build Coastguard Worker        check_results(fn, compiled_fn, generate_inp(20))
6179*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(compile_counter.frame_count, frame_count_2)
6180*da0073e9SAndroid Build Coastguard Worker
6181*da0073e9SAndroid Build Coastguard Worker    # Note 1: Math fallback doesn't work with bfloat16 on CUDA
6182*da0073e9SAndroid Build Coastguard Worker    # Note 2: ROCm doesn't support flash attention or mem_efficient attention for NT
6183*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
6184*da0073e9SAndroid Build Coastguard Worker        TEST_WITH_ROCM,
6185*da0073e9SAndroid Build Coastguard Worker        "ROCm doesn't support flash attention or mem_efficient attention for NT",
6186*da0073e9SAndroid Build Coastguard Worker    )
6187*da0073e9SAndroid Build Coastguard Worker    @dtypes(
6188*da0073e9SAndroid Build Coastguard Worker        *(
6189*da0073e9SAndroid Build Coastguard Worker            [torch.float16, torch.bfloat16, torch.float32]
6190*da0073e9SAndroid Build Coastguard Worker            if SM80OrLater
6191*da0073e9SAndroid Build Coastguard Worker            else [torch.float16, torch.float32]
6192*da0073e9SAndroid Build Coastguard Worker        )
6193*da0073e9SAndroid Build Coastguard Worker    )
6194*da0073e9SAndroid Build Coastguard Worker    def test_sdpa(self, device, dtype):
6195*da0073e9SAndroid Build Coastguard Worker        batch_size = 1
6196*da0073e9SAndroid Build Coastguard Worker        emb_dims = 128
6197*da0073e9SAndroid Build Coastguard Worker        n_heads = 8
6198*da0073e9SAndroid Build Coastguard Worker        head_dims = emb_dims // n_heads
6199*da0073e9SAndroid Build Coastguard Worker
6200*da0073e9SAndroid Build Coastguard Worker        sen1 = torch.randn(11, emb_dims, dtype=dtype, device=device)
6201*da0073e9SAndroid Build Coastguard Worker        sen2 = torch.randn(13, emb_dims, dtype=dtype, device=device)
6202*da0073e9SAndroid Build Coastguard Worker
6203*da0073e9SAndroid Build Coastguard Worker        query = torch.nn.Linear(
6204*da0073e9SAndroid Build Coastguard Worker            emb_dims, emb_dims, bias=False, device=device, dtype=dtype
6205*da0073e9SAndroid Build Coastguard Worker        )
6206*da0073e9SAndroid Build Coastguard Worker        key = torch.nn.Linear(
6207*da0073e9SAndroid Build Coastguard Worker            emb_dims, emb_dims, bias=False, device=device, dtype=dtype
6208*da0073e9SAndroid Build Coastguard Worker        )
6209*da0073e9SAndroid Build Coastguard Worker        value = torch.nn.Linear(
6210*da0073e9SAndroid Build Coastguard Worker            emb_dims, emb_dims, bias=False, device=device, dtype=dtype
6211*da0073e9SAndroid Build Coastguard Worker        )
6212*da0073e9SAndroid Build Coastguard Worker
6213*da0073e9SAndroid Build Coastguard Worker        # Simplest case: 1 sentence, no batching
6214*da0073e9SAndroid Build Coastguard Worker        x_d1 = sen1.unsqueeze(0)
6215*da0073e9SAndroid Build Coastguard Worker        x_nt = torch.nested.as_nested_tensor([sen1], layout=torch.jagged)
6216*da0073e9SAndroid Build Coastguard Worker
6217*da0073e9SAndroid Build Coastguard Worker        # See note below for why we detach here.
6218*da0073e9SAndroid Build Coastguard Worker        q_d1 = (
6219*da0073e9SAndroid Build Coastguard Worker            query(x_d1)
6220*da0073e9SAndroid Build Coastguard Worker            .view(batch_size, -1, n_heads, head_dims)
6221*da0073e9SAndroid Build Coastguard Worker            .detach()
6222*da0073e9SAndroid Build Coastguard Worker            .requires_grad_(True)
6223*da0073e9SAndroid Build Coastguard Worker        )
6224*da0073e9SAndroid Build Coastguard Worker        q_d1_t = q_d1.transpose(1, 2)
6225*da0073e9SAndroid Build Coastguard Worker        k_d1 = (
6226*da0073e9SAndroid Build Coastguard Worker            key(x_d1)
6227*da0073e9SAndroid Build Coastguard Worker            .view(batch_size, -1, n_heads, head_dims)
6228*da0073e9SAndroid Build Coastguard Worker            .detach()
6229*da0073e9SAndroid Build Coastguard Worker            .requires_grad_(True)
6230*da0073e9SAndroid Build Coastguard Worker        )
6231*da0073e9SAndroid Build Coastguard Worker        k_d1_t = k_d1.transpose(1, 2)
6232*da0073e9SAndroid Build Coastguard Worker        v_d1 = (
6233*da0073e9SAndroid Build Coastguard Worker            value(x_d1)
6234*da0073e9SAndroid Build Coastguard Worker            .view(batch_size, -1, n_heads, head_dims)
6235*da0073e9SAndroid Build Coastguard Worker            .detach()
6236*da0073e9SAndroid Build Coastguard Worker            .requires_grad_(True)
6237*da0073e9SAndroid Build Coastguard Worker        )
6238*da0073e9SAndroid Build Coastguard Worker        v_d1_t = v_d1.transpose(1, 2)
6239*da0073e9SAndroid Build Coastguard Worker
6240*da0073e9SAndroid Build Coastguard Worker        q_nt = (
6241*da0073e9SAndroid Build Coastguard Worker            query(x_nt)
6242*da0073e9SAndroid Build Coastguard Worker            .view(*x_nt.size()[0:2], n_heads, head_dims)
6243*da0073e9SAndroid Build Coastguard Worker            .detach()
6244*da0073e9SAndroid Build Coastguard Worker            .requires_grad_(True)
6245*da0073e9SAndroid Build Coastguard Worker        )
6246*da0073e9SAndroid Build Coastguard Worker        q_nt_t = q_nt.transpose(1, 2)
6247*da0073e9SAndroid Build Coastguard Worker        k_nt = (
6248*da0073e9SAndroid Build Coastguard Worker            key(x_nt)
6249*da0073e9SAndroid Build Coastguard Worker            .view(*x_nt.size()[0:2], n_heads, head_dims)
6250*da0073e9SAndroid Build Coastguard Worker            .detach()
6251*da0073e9SAndroid Build Coastguard Worker            .requires_grad_(True)
6252*da0073e9SAndroid Build Coastguard Worker        )
6253*da0073e9SAndroid Build Coastguard Worker        k_nt_t = k_nt.transpose(1, 2)
6254*da0073e9SAndroid Build Coastguard Worker        v_nt = (
6255*da0073e9SAndroid Build Coastguard Worker            value(x_nt)
6256*da0073e9SAndroid Build Coastguard Worker            .view(*x_nt.size()[0:2], n_heads, head_dims)
6257*da0073e9SAndroid Build Coastguard Worker            .detach()
6258*da0073e9SAndroid Build Coastguard Worker            .requires_grad_(True)
6259*da0073e9SAndroid Build Coastguard Worker        )
6260*da0073e9SAndroid Build Coastguard Worker        v_nt_t = v_nt.transpose(1, 2)
6261*da0073e9SAndroid Build Coastguard Worker
6262*da0073e9SAndroid Build Coastguard Worker        # High Precision Math Reference
6263*da0073e9SAndroid Build Coastguard Worker        q_d1_f32 = q_d1.to(torch.float32)
6264*da0073e9SAndroid Build Coastguard Worker        k_d1_f32 = k_d1.to(torch.float32)
6265*da0073e9SAndroid Build Coastguard Worker        v_d1_f32 = v_d1.to(torch.float32)
6266*da0073e9SAndroid Build Coastguard Worker        q_d1_f32_t = q_d1_f32.transpose(1, 2)
6267*da0073e9SAndroid Build Coastguard Worker        k_d1_f32_t = k_d1_f32.transpose(1, 2)
6268*da0073e9SAndroid Build Coastguard Worker        v_d1_f32_t = v_d1_f32.transpose(1, 2)
6269*da0073e9SAndroid Build Coastguard Worker        out_ref = torch.ops.aten._scaled_dot_product_attention_math(
6270*da0073e9SAndroid Build Coastguard Worker            q_d1_f32_t, k_d1_f32_t, v_d1_f32_t
6271*da0073e9SAndroid Build Coastguard Worker        )[0]
6272*da0073e9SAndroid Build Coastguard Worker        grads_ref = torch.autograd.grad(out_ref.sum(), (q_d1_f32, k_d1_f32, v_d1_f32))
6273*da0073e9SAndroid Build Coastguard Worker
6274*da0073e9SAndroid Build Coastguard Worker        # Low Precision Math Reference
6275*da0073e9SAndroid Build Coastguard Worker        out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
6276*da0073e9SAndroid Build Coastguard Worker            q_d1_t, k_d1_t, v_d1_t
6277*da0073e9SAndroid Build Coastguard Worker        )[0]
6278*da0073e9SAndroid Build Coastguard Worker        grads_lp_ref = torch.autograd.grad(out_lp_ref.sum(), (q_d1, k_d1, v_d1))
6279*da0073e9SAndroid Build Coastguard Worker
6280*da0073e9SAndroid Build Coastguard Worker        # Compute tolerances
6281*da0073e9SAndroid Build Coastguard Worker        output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref)
6282*da0073e9SAndroid Build Coastguard Worker        grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(grads_ref[0], grads_lp_ref[0])
6283*da0073e9SAndroid Build Coastguard Worker        grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(grads_ref[1], grads_lp_ref[1])
6284*da0073e9SAndroid Build Coastguard Worker        grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(grads_ref[2], grads_lp_ref[2])
6285*da0073e9SAndroid Build Coastguard Worker        grad_atols = [grad_q_ref_atol, grad_k_ref_atol, grad_v_ref_atol]
6286*da0073e9SAndroid Build Coastguard Worker        grad_rtols = [grad_q_ref_rtol, grad_k_ref_rtol, grad_v_ref_rtol]
6287*da0073e9SAndroid Build Coastguard Worker
6288*da0073e9SAndroid Build Coastguard Worker        attn_d1 = torch.nn.functional.scaled_dot_product_attention(
6289*da0073e9SAndroid Build Coastguard Worker            q_d1_t, k_d1_t, v_d1_t
6290*da0073e9SAndroid Build Coastguard Worker        ).transpose(1, 2)
6291*da0073e9SAndroid Build Coastguard Worker        attn_nt = torch.nn.functional.scaled_dot_product_attention(
6292*da0073e9SAndroid Build Coastguard Worker            q_nt_t, k_nt_t, v_nt_t
6293*da0073e9SAndroid Build Coastguard Worker        ).transpose(1, 2)
6294*da0073e9SAndroid Build Coastguard Worker
6295*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
6296*da0073e9SAndroid Build Coastguard Worker            attn_d1,
6297*da0073e9SAndroid Build Coastguard Worker            attn_nt.unbind()[0].unsqueeze(0),
6298*da0073e9SAndroid Build Coastguard Worker            atol=output_ref_atol,
6299*da0073e9SAndroid Build Coastguard Worker            rtol=output_ref_rtol,
6300*da0073e9SAndroid Build Coastguard Worker        )
6301*da0073e9SAndroid Build Coastguard Worker
6302*da0073e9SAndroid Build Coastguard Worker        # Simple case: 2 sentences, no extra params
6303*da0073e9SAndroid Build Coastguard Worker        x_d2 = sen2.unsqueeze(0)
6304*da0073e9SAndroid Build Coastguard Worker        x_nt = torch.nested.as_nested_tensor([sen1, sen2], layout=torch.jagged)
6305*da0073e9SAndroid Build Coastguard Worker
6306*da0073e9SAndroid Build Coastguard Worker        # NB: we make sure the leaf tensor we compute gradients for is the view-ed tensor before
6307*da0073e9SAndroid Build Coastguard Worker        # it is transposed. This is because today we cannot backward through view or unbind a
6308*da0073e9SAndroid Build Coastguard Worker        # transposed tensor.
6309*da0073e9SAndroid Build Coastguard Worker        q_d2 = (
6310*da0073e9SAndroid Build Coastguard Worker            query(x_d2)
6311*da0073e9SAndroid Build Coastguard Worker            .view(batch_size, -1, n_heads, head_dims)
6312*da0073e9SAndroid Build Coastguard Worker            .detach()
6313*da0073e9SAndroid Build Coastguard Worker            .requires_grad_(True)
6314*da0073e9SAndroid Build Coastguard Worker        )
6315*da0073e9SAndroid Build Coastguard Worker        q_d2_t = q_d2.transpose(1, 2)
6316*da0073e9SAndroid Build Coastguard Worker        k_d2 = (
6317*da0073e9SAndroid Build Coastguard Worker            key(x_d2)
6318*da0073e9SAndroid Build Coastguard Worker            .view(batch_size, -1, n_heads, head_dims)
6319*da0073e9SAndroid Build Coastguard Worker            .detach()
6320*da0073e9SAndroid Build Coastguard Worker            .requires_grad_(True)
6321*da0073e9SAndroid Build Coastguard Worker        )
6322*da0073e9SAndroid Build Coastguard Worker        k_d2_t = k_d2.transpose(1, 2)
6323*da0073e9SAndroid Build Coastguard Worker        v_d2 = (
6324*da0073e9SAndroid Build Coastguard Worker            value(x_d2)
6325*da0073e9SAndroid Build Coastguard Worker            .view(batch_size, -1, n_heads, head_dims)
6326*da0073e9SAndroid Build Coastguard Worker            .detach()
6327*da0073e9SAndroid Build Coastguard Worker            .requires_grad_(True)
6328*da0073e9SAndroid Build Coastguard Worker        )
6329*da0073e9SAndroid Build Coastguard Worker        v_d2_t = v_d2.transpose(1, 2)
6330*da0073e9SAndroid Build Coastguard Worker
6331*da0073e9SAndroid Build Coastguard Worker        q_nt = (
6332*da0073e9SAndroid Build Coastguard Worker            query(x_nt)
6333*da0073e9SAndroid Build Coastguard Worker            .view(*x_nt.size()[0:2], n_heads, head_dims)
6334*da0073e9SAndroid Build Coastguard Worker            .detach()
6335*da0073e9SAndroid Build Coastguard Worker            .requires_grad_(True)
6336*da0073e9SAndroid Build Coastguard Worker        )
6337*da0073e9SAndroid Build Coastguard Worker        q_nt_t = q_nt.transpose(1, 2)
6338*da0073e9SAndroid Build Coastguard Worker        k_nt = (
6339*da0073e9SAndroid Build Coastguard Worker            key(x_nt)
6340*da0073e9SAndroid Build Coastguard Worker            .view(*x_nt.size()[0:2], n_heads, head_dims)
6341*da0073e9SAndroid Build Coastguard Worker            .detach()
6342*da0073e9SAndroid Build Coastguard Worker            .requires_grad_(True)
6343*da0073e9SAndroid Build Coastguard Worker        )
6344*da0073e9SAndroid Build Coastguard Worker        k_nt_t = k_nt.transpose(1, 2)
6345*da0073e9SAndroid Build Coastguard Worker        v_nt = (
6346*da0073e9SAndroid Build Coastguard Worker            value(x_nt)
6347*da0073e9SAndroid Build Coastguard Worker            .view(*x_nt.size()[0:2], n_heads, head_dims)
6348*da0073e9SAndroid Build Coastguard Worker            .detach()
6349*da0073e9SAndroid Build Coastguard Worker            .requires_grad_(True)
6350*da0073e9SAndroid Build Coastguard Worker        )
6351*da0073e9SAndroid Build Coastguard Worker        v_nt_t = v_nt.transpose(1, 2)
6352*da0073e9SAndroid Build Coastguard Worker
6353*da0073e9SAndroid Build Coastguard Worker        attn_d2 = torch.nn.functional.scaled_dot_product_attention(
6354*da0073e9SAndroid Build Coastguard Worker            q_d2_t, k_d2_t, v_d2_t
6355*da0073e9SAndroid Build Coastguard Worker        ).transpose(1, 2)
6356*da0073e9SAndroid Build Coastguard Worker        d1_grads = torch.autograd.grad(attn_d1.sum(), (q_d1, k_d1, v_d1))
6357*da0073e9SAndroid Build Coastguard Worker        d2_grads = torch.autograd.grad(attn_d2.sum(), (q_d2, k_d2, v_d2))
6358*da0073e9SAndroid Build Coastguard Worker
6359*da0073e9SAndroid Build Coastguard Worker        # Simple case 3: batch_size = 1, seq_len = 1
6360*da0073e9SAndroid Build Coastguard Worker        q_3 = torch.randn(1, 8, 16, dtype=dtype, device=device)
6361*da0073e9SAndroid Build Coastguard Worker        q_nt_3 = torch.nested.as_nested_tensor([q_3], layout=torch.jagged)
6362*da0073e9SAndroid Build Coastguard Worker        q_nt_3 = q_nt_3.transpose(1, 2)
6363*da0073e9SAndroid Build Coastguard Worker        attn_out = torch.nn.functional.scaled_dot_product_attention(
6364*da0073e9SAndroid Build Coastguard Worker            q_nt_3, q_nt_3, q_nt_3
6365*da0073e9SAndroid Build Coastguard Worker        )
6366*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(attn_out.shape, q_nt_3.shape)
6367*da0073e9SAndroid Build Coastguard Worker
6368*da0073e9SAndroid Build Coastguard Worker        def check_forward_backward():
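            # Compare the NT SDPA output and grads against the per-sentence dense
            # references, using the tolerances derived from the math fallback above.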
6369*da0073e9SAndroid Build Coastguard Worker            attn_nt = torch.nn.functional.scaled_dot_product_attention(
6370*da0073e9SAndroid Build Coastguard Worker                q_nt_t, k_nt_t, v_nt_t
6371*da0073e9SAndroid Build Coastguard Worker            ).transpose(1, 2)
6372*da0073e9SAndroid Build Coastguard Worker
6373*da0073e9SAndroid Build Coastguard Worker            attn_nts = attn_nt.unbind()
6374*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
6375*da0073e9SAndroid Build Coastguard Worker                attn_d1,
6376*da0073e9SAndroid Build Coastguard Worker                attn_nts[0].unsqueeze(0),
6377*da0073e9SAndroid Build Coastguard Worker                atol=output_ref_atol,
6378*da0073e9SAndroid Build Coastguard Worker                rtol=output_ref_rtol,
6379*da0073e9SAndroid Build Coastguard Worker            )
6380*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
6381*da0073e9SAndroid Build Coastguard Worker                attn_d2,
6382*da0073e9SAndroid Build Coastguard Worker                attn_nts[1].unsqueeze(0),
6383*da0073e9SAndroid Build Coastguard Worker                atol=output_ref_atol,
6384*da0073e9SAndroid Build Coastguard Worker                rtol=output_ref_rtol,
6385*da0073e9SAndroid Build Coastguard Worker            )
6386*da0073e9SAndroid Build Coastguard Worker
6387*da0073e9SAndroid Build Coastguard Worker            nt_grads = torch.autograd.grad(attn_nt.values().sum(), (q_nt, k_nt, v_nt))
6388*da0073e9SAndroid Build Coastguard Worker            for nt_grad, d1_grad, d2_grad, grad_atol, grad_rtol in zip(
6389*da0073e9SAndroid Build Coastguard Worker                nt_grads, d1_grads, d2_grads, grad_atols, grad_rtols
6390*da0073e9SAndroid Build Coastguard Worker            ):
6391*da0073e9SAndroid Build Coastguard Worker                unbound_nt_grads = nt_grad.unbind()
6392*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
6393*da0073e9SAndroid Build Coastguard Worker                    d1_grad,
6394*da0073e9SAndroid Build Coastguard Worker                    unbound_nt_grads[0].unsqueeze(0),
6395*da0073e9SAndroid Build Coastguard Worker                    atol=grad_atol,
6396*da0073e9SAndroid Build Coastguard Worker                    rtol=grad_rtol,
6397*da0073e9SAndroid Build Coastguard Worker                )
6398*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
6399*da0073e9SAndroid Build Coastguard Worker                    d2_grad,
6400*da0073e9SAndroid Build Coastguard Worker                    unbound_nt_grads[1].unsqueeze(0),
6401*da0073e9SAndroid Build Coastguard Worker                    atol=grad_atol,
6402*da0073e9SAndroid Build Coastguard Worker                    rtol=grad_rtol,
6403*da0073e9SAndroid Build Coastguard Worker                )
6404*da0073e9SAndroid Build Coastguard Worker
6405*da0073e9SAndroid Build Coastguard Worker        # Default
6406*da0073e9SAndroid Build Coastguard Worker        check_forward_backward()
6407*da0073e9SAndroid Build Coastguard Worker
6408*da0073e9SAndroid Build Coastguard Worker        # Test that the dispatcher works by enabling only the mem-efficient and math kernels (they are safe on all devices)
6409*da0073e9SAndroid Build Coastguard Worker        with torch.backends.cuda.sdp_kernel(
6410*da0073e9SAndroid Build Coastguard Worker            enable_flash=False, enable_mem_efficient=True, enable_math=True
6411*da0073e9SAndroid Build Coastguard Worker        ):
6412*da0073e9SAndroid Build Coastguard Worker            check_forward_backward()
6413*da0073e9SAndroid Build Coastguard Worker
6414*da0073e9SAndroid Build Coastguard Worker        # Test math fallback
6415*da0073e9SAndroid Build Coastguard Worker        with torch.backends.cuda.sdp_kernel(
6416*da0073e9SAndroid Build Coastguard Worker            enable_flash=False, enable_mem_efficient=False, enable_math=True
6417*da0073e9SAndroid Build Coastguard Worker        ):
6418*da0073e9SAndroid Build Coastguard Worker            # Math fallback doesn't work with bfloat16 on CUDA because
6419*da0073e9SAndroid Build Coastguard Worker            # "group_gemm_dispatch" not implemented for 'BFloat16'
6420*da0073e9SAndroid Build Coastguard Worker            if not (str(device).startswith("cuda") and dtype == torch.bfloat16):
6421*da0073e9SAndroid Build Coastguard Worker                check_forward_backward()
6422*da0073e9SAndroid Build Coastguard Worker
6423*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("SDPA test compiles internally")
6424*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
6425*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
6426*da0073e9SAndroid Build Coastguard Worker    # Guarding with sqrt() doesn't work on ROCm?
6427*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfRocm
6428*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
6429*da0073e9SAndroid Build Coastguard Worker    @dtypes(
6430*da0073e9SAndroid Build Coastguard Worker        *(
6431*da0073e9SAndroid Build Coastguard Worker            [torch.float16, torch.bfloat16, torch.float32]
6432*da0073e9SAndroid Build Coastguard Worker            if SM80OrLater
6433*da0073e9SAndroid Build Coastguard Worker            else [torch.float16, torch.float32]
6434*da0073e9SAndroid Build Coastguard Worker        )
6435*da0073e9SAndroid Build Coastguard Worker    )
6436*da0073e9SAndroid Build Coastguard Worker    def test_sdpa_compile(self, device, dtype):
6437*da0073e9SAndroid Build Coastguard Worker        batch_size = 1
6438*da0073e9SAndroid Build Coastguard Worker        emb_dims = 1024
6439*da0073e9SAndroid Build Coastguard Worker        n_heads = 8
6440*da0073e9SAndroid Build Coastguard Worker        head_dims = emb_dims // n_heads
6441*da0073e9SAndroid Build Coastguard Worker
6442*da0073e9SAndroid Build Coastguard Worker        sen1 = torch.randn(11, emb_dims, dtype=dtype, device=device)
6443*da0073e9SAndroid Build Coastguard Worker        sen2 = torch.randn(13, emb_dims, dtype=dtype, device=device)
6444*da0073e9SAndroid Build Coastguard Worker
6445*da0073e9SAndroid Build Coastguard Worker        query = torch.nn.Linear(
6446*da0073e9SAndroid Build Coastguard Worker            emb_dims, emb_dims, bias=False, device=device, dtype=dtype
6447*da0073e9SAndroid Build Coastguard Worker        )
6448*da0073e9SAndroid Build Coastguard Worker        key = torch.nn.Linear(
6449*da0073e9SAndroid Build Coastguard Worker            emb_dims, emb_dims, bias=False, device=device, dtype=dtype
6450*da0073e9SAndroid Build Coastguard Worker        )
6451*da0073e9SAndroid Build Coastguard Worker        value = torch.nn.Linear(
6452*da0073e9SAndroid Build Coastguard Worker            emb_dims, emb_dims, bias=False, device=device, dtype=dtype
6453*da0073e9SAndroid Build Coastguard Worker        )
6454*da0073e9SAndroid Build Coastguard Worker
6455*da0073e9SAndroid Build Coastguard Worker        # Dense inputs: each sentence unsqueezed to a batch of 1; NT input: both sentences as one jagged batch
6456*da0073e9SAndroid Build Coastguard Worker        x_d1 = sen1.unsqueeze(0)
6457*da0073e9SAndroid Build Coastguard Worker        x_d2 = sen2.unsqueeze(0)
6458*da0073e9SAndroid Build Coastguard Worker        x_nt = torch.nested.as_nested_tensor([sen1, sen2], layout=torch.jagged)
6459*da0073e9SAndroid Build Coastguard Worker
6460*da0073e9SAndroid Build Coastguard Worker        q_d1 = query(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
6461*da0073e9SAndroid Build Coastguard Worker        k_d1 = key(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
6462*da0073e9SAndroid Build Coastguard Worker        v_d1 = value(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
6463*da0073e9SAndroid Build Coastguard Worker        q_d2 = query(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
6464*da0073e9SAndroid Build Coastguard Worker        k_d2 = key(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
6465*da0073e9SAndroid Build Coastguard Worker        v_d2 = value(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
6466*da0073e9SAndroid Build Coastguard Worker
6467*da0073e9SAndroid Build Coastguard Worker        q_nt = (
6468*da0073e9SAndroid Build Coastguard Worker            query(x_nt)
6469*da0073e9SAndroid Build Coastguard Worker            .view(*x_nt.size()[0:2], n_heads, head_dims)
6470*da0073e9SAndroid Build Coastguard Worker            .detach()
6471*da0073e9SAndroid Build Coastguard Worker            .transpose(1, 2)
6472*da0073e9SAndroid Build Coastguard Worker        )
6473*da0073e9SAndroid Build Coastguard Worker        k_nt = (
6474*da0073e9SAndroid Build Coastguard Worker            key(x_nt)
6475*da0073e9SAndroid Build Coastguard Worker            .view(*x_nt.size()[0:2], n_heads, head_dims)
6476*da0073e9SAndroid Build Coastguard Worker            .detach()
6477*da0073e9SAndroid Build Coastguard Worker            .transpose(1, 2)
6478*da0073e9SAndroid Build Coastguard Worker        )
6479*da0073e9SAndroid Build Coastguard Worker        v_nt = (
6480*da0073e9SAndroid Build Coastguard Worker            value(x_nt)
6481*da0073e9SAndroid Build Coastguard Worker            .view(*x_nt.size()[0:2], n_heads, head_dims)
6482*da0073e9SAndroid Build Coastguard Worker            .detach()
6483*da0073e9SAndroid Build Coastguard Worker            .transpose(1, 2)
6484*da0073e9SAndroid Build Coastguard Worker        )
6485*da0073e9SAndroid Build Coastguard Worker
6486*da0073e9SAndroid Build Coastguard Worker        # High Precision Math Reference
6487*da0073e9SAndroid Build Coastguard Worker        q_d1_f32 = q_d1.to(torch.float32)
6488*da0073e9SAndroid Build Coastguard Worker        k_d1_f32 = k_d1.to(torch.float32)
6489*da0073e9SAndroid Build Coastguard Worker        v_d1_f32 = v_d1.to(torch.float32)
6490*da0073e9SAndroid Build Coastguard Worker        out_ref = torch.ops.aten._scaled_dot_product_attention_math(
6491*da0073e9SAndroid Build Coastguard Worker            q_d1_f32, k_d1_f32, v_d1_f32
6492*da0073e9SAndroid Build Coastguard Worker        )[0]
6493*da0073e9SAndroid Build Coastguard Worker        # Low Precision Math Reference
6494*da0073e9SAndroid Build Coastguard Worker        out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
6495*da0073e9SAndroid Build Coastguard Worker            q_d1, k_d1, v_d1
6496*da0073e9SAndroid Build Coastguard Worker        )[0]
6497*da0073e9SAndroid Build Coastguard Worker        output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref)
6498*da0073e9SAndroid Build Coastguard Worker
6499*da0073e9SAndroid Build Coastguard Worker        attn_d1 = torch.nn.functional.scaled_dot_product_attention(
6500*da0073e9SAndroid Build Coastguard Worker            q_d1, k_d1, v_d1
6501*da0073e9SAndroid Build Coastguard Worker        ).transpose(1, 2)
6502*da0073e9SAndroid Build Coastguard Worker        attn_d2 = torch.nn.functional.scaled_dot_product_attention(
6503*da0073e9SAndroid Build Coastguard Worker            q_d2, k_d2, v_d2
6504*da0073e9SAndroid Build Coastguard Worker        ).transpose(1, 2)
6505*da0073e9SAndroid Build Coastguard Worker
6506*da0073e9SAndroid Build Coastguard Worker        compiled_sdpa = torch.compile(torch.nn.functional.scaled_dot_product_attention)
6507*da0073e9SAndroid Build Coastguard Worker        attn_nt = compiled_sdpa(q_nt, k_nt, v_nt).transpose(1, 2)
6508*da0073e9SAndroid Build Coastguard Worker
6509*da0073e9SAndroid Build Coastguard Worker        attn_nts = attn_nt.unbind()
6510*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
6511*da0073e9SAndroid Build Coastguard Worker            attn_d1,
6512*da0073e9SAndroid Build Coastguard Worker            attn_nts[0].unsqueeze(0),
6513*da0073e9SAndroid Build Coastguard Worker            atol=output_ref_atol,
6514*da0073e9SAndroid Build Coastguard Worker            rtol=output_ref_rtol,
6515*da0073e9SAndroid Build Coastguard Worker        )
6516*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
6517*da0073e9SAndroid Build Coastguard Worker            attn_d2,
6518*da0073e9SAndroid Build Coastguard Worker            attn_nts[1].unsqueeze(0),
6519*da0073e9SAndroid Build Coastguard Worker            atol=output_ref_atol,
6520*da0073e9SAndroid Build Coastguard Worker            rtol=output_ref_rtol,
6521*da0073e9SAndroid Build Coastguard Worker        )
6522*da0073e9SAndroid Build Coastguard Worker
6523*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32, torch.double, torch.half)
6524*da0073e9SAndroid Build Coastguard Worker    def test_sdpa_with_constant_sequence_length(self, device, dtype):
6525*da0073e9SAndroid Build Coastguard Worker        # shape (B, P*, S, D)
6526*da0073e9SAndroid Build Coastguard Worker        # B: batch size
6527*da0073e9SAndroid Build Coastguard Worker        # P*: ragged number of prompts
6528*da0073e9SAndroid Build Coastguard Worker        # S: (constant) sequence length
6529*da0073e9SAndroid Build Coastguard Worker        # D: embedding size
6530*da0073e9SAndroid Build Coastguard Worker        query = random_nt_from_dims(
6531*da0073e9SAndroid Build Coastguard Worker            [4, None, 8, 10],
6532*da0073e9SAndroid Build Coastguard Worker            device=device,
6533*da0073e9SAndroid Build Coastguard Worker            dtype=dtype,
6534*da0073e9SAndroid Build Coastguard Worker            layout=torch.jagged,
6535*da0073e9SAndroid Build Coastguard Worker            requires_grad=True,
6536*da0073e9SAndroid Build Coastguard Worker        )
6537*da0073e9SAndroid Build Coastguard Worker        key = random_nt_from_similar(query)
6538*da0073e9SAndroid Build Coastguard Worker        value = random_nt_from_similar(query)
6539*da0073e9SAndroid Build Coastguard Worker        output = F.scaled_dot_product_attention(query, key, value)
6540*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(output, NestedTensor))
6541*da0073e9SAndroid Build Coastguard Worker        output.values().sum().backward()
6542*da0073e9SAndroid Build Coastguard Worker
6543*da0073e9SAndroid Build Coastguard Worker        query_dense = query.clone().detach().requires_grad_(True)
6544*da0073e9SAndroid Build Coastguard Worker        # should be equivalent to just running the buffers through
6545*da0073e9SAndroid Build Coastguard Worker        output_dense = F.scaled_dot_product_attention(
6546*da0073e9SAndroid Build Coastguard Worker            query_dense.values(), key.values(), value.values()
6547*da0073e9SAndroid Build Coastguard Worker        )
6548*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.disable(self.assertEqual)(output._values, output_dense)
6549*da0073e9SAndroid Build Coastguard Worker        output_dense.sum().backward()
6550*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.disable(self.assertEqual)(query.grad, query_dense.grad)
6551*da0073e9SAndroid Build Coastguard Worker
6552*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
6553*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
6554*da0073e9SAndroid Build Coastguard Worker        not PLATFORM_SUPPORTS_FUSED_ATTENTION,
6555*da0073e9SAndroid Build Coastguard Worker        "Platform doesn't support flash or mem-efficient attention",
6556*da0073e9SAndroid Build Coastguard Worker    )
6557*da0073e9SAndroid Build Coastguard Worker    @dtypes(
6558*da0073e9SAndroid Build Coastguard Worker        *(
6559*da0073e9SAndroid Build Coastguard Worker            [torch.float16, torch.bfloat16, torch.float32]
6560*da0073e9SAndroid Build Coastguard Worker            if SM80OrLater
6561*da0073e9SAndroid Build Coastguard Worker            else [torch.float16, torch.float32]
6562*da0073e9SAndroid Build Coastguard Worker        )
6563*da0073e9SAndroid Build Coastguard Worker    )
6564*da0073e9SAndroid Build Coastguard Worker    def test_sdpa_with_packed_in_proj(self, device, dtype):
6565*da0073e9SAndroid Build Coastguard Worker        # shape (B, *, D)
6566*da0073e9SAndroid Build Coastguard Worker        input_packed = random_nt_from_dims(
6567*da0073e9SAndroid Build Coastguard Worker            [5, None, 10], device=device, dtype=dtype, layout=torch.jagged
6568*da0073e9SAndroid Build Coastguard Worker        )
6569*da0073e9SAndroid Build Coastguard Worker
6570*da0073e9SAndroid Build Coastguard Worker        # Do input projection.
6571*da0073e9SAndroid Build Coastguard Worker        num_heads = 2
6572*da0073e9SAndroid Build Coastguard Worker        # should be multiple of 4 for efficient kernels (e.g. flash / mem-efficient)
6573*da0073e9SAndroid Build Coastguard Worker        head_dim = 8
6574*da0073e9SAndroid Build Coastguard Worker        qkv_linear = torch.nn.Linear(10, num_heads * head_dim * 3).to(
6575*da0073e9SAndroid Build Coastguard Worker            device=device, dtype=dtype
6576*da0073e9SAndroid Build Coastguard Worker        )
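        # the single linear packs q, k, and v: num_heads * head_dim * 3 = 48 output features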
6577*da0073e9SAndroid Build Coastguard Worker
6578*da0073e9SAndroid Build Coastguard Worker        def in_proj(input_packed, qkv_linear=qkv_linear):
6579*da0073e9SAndroid Build Coastguard Worker            qkv_post_proj = qkv_linear(input_packed)
6580*da0073e9SAndroid Build Coastguard Worker            # these are non-contiguous to trigger _is_safe_to_get_storage_as_tensor()
6581*da0073e9SAndroid Build Coastguard Worker            q, k, v = qkv_post_proj.chunk(3, dim=-1)
6582*da0073e9SAndroid Build Coastguard Worker            q = q.unflatten(-1, [num_heads, head_dim]).transpose(-2, -3)
6583*da0073e9SAndroid Build Coastguard Worker            k = k.unflatten(-1, [num_heads, head_dim]).transpose(-2, -3)
6584*da0073e9SAndroid Build Coastguard Worker            v = v.unflatten(-1, [num_heads, head_dim]).transpose(-2, -3)
6585*da0073e9SAndroid Build Coastguard Worker            return q, k, v
6586*da0073e9SAndroid Build Coastguard Worker
6587*da0073e9SAndroid Build Coastguard Worker        q, k, v = in_proj(input_packed)
6588*da0073e9SAndroid Build Coastguard Worker        output = F.scaled_dot_product_attention(q, k, v, attn_mask=None)
6589*da0073e9SAndroid Build Coastguard Worker
6590*da0073e9SAndroid Build Coastguard Worker        # compare to individually running unbound components through
6591*da0073e9SAndroid Build Coastguard Worker        for in_component, out_component in zip(
6592*da0073e9SAndroid Build Coastguard Worker            input_packed.unbind(), output.transpose(-2, -3).unbind()
6593*da0073e9SAndroid Build Coastguard Worker        ):
6594*da0073e9SAndroid Build Coastguard Worker            q, k, v = in_proj(in_component)
6595*da0073e9SAndroid Build Coastguard Worker            out = F.scaled_dot_product_attention(q, k, v).transpose(-2, -3)
6596*da0073e9SAndroid Build Coastguard Worker
6597*da0073e9SAndroid Build Coastguard Worker            # Low Precision Math Reference
6598*da0073e9SAndroid Build Coastguard Worker            out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(q, k, v)[
6599*da0073e9SAndroid Build Coastguard Worker                0
6600*da0073e9SAndroid Build Coastguard Worker            ].transpose(-2, -3)
6601*da0073e9SAndroid Build Coastguard Worker            output_ref_atol, output_ref_rtol = get_tolerances(
6602*da0073e9SAndroid Build Coastguard Worker                out, out_lp_ref, fudge_factor=2
6603*da0073e9SAndroid Build Coastguard Worker            )
6604*da0073e9SAndroid Build Coastguard Worker
6605*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
6606*da0073e9SAndroid Build Coastguard Worker                out, out_component, atol=output_ref_atol, rtol=output_ref_rtol
6607*da0073e9SAndroid Build Coastguard Worker            )
6608*da0073e9SAndroid Build Coastguard Worker
6609*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("SDPA test compiles internally")
6610*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
6611*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
6612*da0073e9SAndroid Build Coastguard Worker    # mha_varlen_fwd not supported on ROCm
6613*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfRocm
6614*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
6615*da0073e9SAndroid Build Coastguard Worker    @dtypes(
6616*da0073e9SAndroid Build Coastguard Worker        *(
6617*da0073e9SAndroid Build Coastguard Worker            [torch.float16, torch.bfloat16, torch.float32]
6618*da0073e9SAndroid Build Coastguard Worker            if SM80OrLater
6619*da0073e9SAndroid Build Coastguard Worker            else [torch.float16, torch.float32]
6620*da0073e9SAndroid Build Coastguard Worker        )
6621*da0073e9SAndroid Build Coastguard Worker    )
6622*da0073e9SAndroid Build Coastguard Worker    def test_sdpa_backwards(self, device, dtype):
6623*da0073e9SAndroid Build Coastguard Worker        values = torch.randn(9, 3, 256, requires_grad=True, device=device, dtype=dtype)
6624*da0073e9SAndroid Build Coastguard Worker        offsets = torch.tensor([0, 1, 3, 5, 9], device=device, dtype=torch.int64)
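        # offsets [0, 1, 3, 5, 9] describe 4 jagged sequences of lengths 1, 2, 2, 4 over the 9 rows of values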
6625*da0073e9SAndroid Build Coastguard Worker
6626*da0073e9SAndroid Build Coastguard Worker        @torch.compile
6627*da0073e9SAndroid Build Coastguard Worker        def f(values, offsets):
6628*da0073e9SAndroid Build Coastguard Worker            nt = convert_jagged_to_nested_tensor(values, offsets, max_length=4)
6629*da0073e9SAndroid Build Coastguard Worker            nt = nt.transpose(-2, -3)
6630*da0073e9SAndroid Build Coastguard Worker            # purposefully graph break to trigger view replay for subclass view input
6631*da0073e9SAndroid Build Coastguard Worker            torch.tensor(1).item()
6632*da0073e9SAndroid Build Coastguard Worker            output = F.scaled_dot_product_attention(nt, nt, nt).transpose(-2, -3)
6633*da0073e9SAndroid Build Coastguard Worker            return convert_nt_to_jagged(output)
6634*da0073e9SAndroid Build Coastguard Worker
6635*da0073e9SAndroid Build Coastguard Worker        output = f(values, offsets)
6636*da0073e9SAndroid Build Coastguard Worker        output.sum().backward()
6637*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(values.grad, torch.ones_like(values))
6638*da0073e9SAndroid Build Coastguard Worker
6639*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
6640*da0073e9SAndroid Build Coastguard Worker        not PLATFORM_SUPPORTS_FUSED_ATTENTION,
6641*da0073e9SAndroid Build Coastguard Worker        "Platform doesn't support flash or mem-efficient attention",
6642*da0073e9SAndroid Build Coastguard Worker    )
6643*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
6644*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfRocm
6645*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
6646*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo()
6647*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
6648*da0073e9SAndroid Build Coastguard Worker    def test_sdpa_autocast(self, device):
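        # Autocast should cast jagged NT q/k/v the same way it casts the equivalent dense
        # tensors, in both eager and compiled modes, for outputs as well as input grads.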
6649*da0073e9SAndroid Build Coastguard Worker        def fn_nt(values32, values16, offsets):
6650*da0073e9SAndroid Build Coastguard Worker            nt32 = convert_jagged_to_nested_tensor(values32, offsets, max_length=16)
6651*da0073e9SAndroid Build Coastguard Worker            nt16 = convert_jagged_to_nested_tensor(values16, offsets, max_length=16)
6652*da0073e9SAndroid Build Coastguard Worker            nt32 = nt32.transpose(1, 2)
6653*da0073e9SAndroid Build Coastguard Worker            nt16 = nt16.transpose(1, 2)
6654*da0073e9SAndroid Build Coastguard Worker            return F.scaled_dot_product_attention(nt32, nt16, nt32)
6655*da0073e9SAndroid Build Coastguard Worker
6656*da0073e9SAndroid Build Coastguard Worker        def fn_dense(x32, x16):
6657*da0073e9SAndroid Build Coastguard Worker            x32 = x32.view(8, 16, 4, 16).transpose(1, 2)
6658*da0073e9SAndroid Build Coastguard Worker            x16 = x16.view(8, 16, 4, 16).transpose(1, 2)
6659*da0073e9SAndroid Build Coastguard Worker            return F.scaled_dot_product_attention(x32, x16, x32)
6660*da0073e9SAndroid Build Coastguard Worker
6661*da0073e9SAndroid Build Coastguard Worker        values32 = torch.randn((8 * 16, 4, 16), device=device, dtype=torch.float32)
6662*da0073e9SAndroid Build Coastguard Worker        values16 = torch.randn((8 * 16, 4, 16), device=device, dtype=torch.float16)
6663*da0073e9SAndroid Build Coastguard Worker        offsets = torch.arange(0, 8 * 16 + 1, 16, device=device, dtype=torch.int32)
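        # constant-length jagged layout: 8 sequences of 16 rows each, matching the dense (8, 16, 4, 16) view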
6664*da0073e9SAndroid Build Coastguard Worker
6665*da0073e9SAndroid Build Coastguard Worker        x32 = values32.clone()
6666*da0073e9SAndroid Build Coastguard Worker        x16 = values16.clone()
6667*da0073e9SAndroid Build Coastguard Worker
6668*da0073e9SAndroid Build Coastguard Worker        with torch.autocast(device_type="cuda", dtype=torch.float16):
6669*da0073e9SAndroid Build Coastguard Worker            out_dense_eager = fn_dense(x32, x16)
6670*da0073e9SAndroid Build Coastguard Worker            out_dense_compiled = torch.compile(fn_dense)(x32, x16)
6671*da0073e9SAndroid Build Coastguard Worker            out_nt_eager = fn_nt(values32, values16, offsets)
6672*da0073e9SAndroid Build Coastguard Worker            out_nt_compiled = torch.compile(fn_nt)(values32, values16, offsets)
6673*da0073e9SAndroid Build Coastguard Worker
6674*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_dense_eager, out_dense_compiled)
6675*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
6676*da0073e9SAndroid Build Coastguard Worker            out_dense_eager.transpose(1, 2),
6677*da0073e9SAndroid Build Coastguard Worker            out_nt_eager.values().transpose(0, 1).view(8, 16, 4, 16),
6678*da0073e9SAndroid Build Coastguard Worker        )
6679*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
6680*da0073e9SAndroid Build Coastguard Worker            out_dense_eager.transpose(1, 2),
6681*da0073e9SAndroid Build Coastguard Worker            out_nt_compiled.values().transpose(0, 1).view(8, 16, 4, 16),
6682*da0073e9SAndroid Build Coastguard Worker        )
6683*da0073e9SAndroid Build Coastguard Worker
6684*da0073e9SAndroid Build Coastguard Worker        def get_values():
6685*da0073e9SAndroid Build Coastguard Worker            return tuple(
6686*da0073e9SAndroid Build Coastguard Worker                x.clone().detach().requires_grad_(True) for x in (values32, values16)
6687*da0073e9SAndroid Build Coastguard Worker            )
6688*da0073e9SAndroid Build Coastguard Worker
6689*da0073e9SAndroid Build Coastguard Worker        v32_dense_eager, v16_dense_eager = get_values()
6690*da0073e9SAndroid Build Coastguard Worker        v32_dense_compile, v16_dense_compile = get_values()
6691*da0073e9SAndroid Build Coastguard Worker        v32_nt_eager, v16_nt_eager = get_values()
6692*da0073e9SAndroid Build Coastguard Worker        v32_nt_compile, v16_nt_compile = get_values()
6693*da0073e9SAndroid Build Coastguard Worker
6694*da0073e9SAndroid Build Coastguard Worker        with torch.autocast(device_type="cuda", dtype=torch.float16):
6695*da0073e9SAndroid Build Coastguard Worker            loss_dense_eager = fn_dense(v32_dense_eager, v16_dense_eager).sum()
6696*da0073e9SAndroid Build Coastguard Worker            loss_dense_compile = torch.compile(fn_dense)(
6697*da0073e9SAndroid Build Coastguard Worker                v32_dense_compile, v16_dense_compile
6698*da0073e9SAndroid Build Coastguard Worker            ).sum()
6699*da0073e9SAndroid Build Coastguard Worker            loss_nt_eager = fn_nt(v32_nt_eager, v16_nt_eager, offsets).values().sum()
6700*da0073e9SAndroid Build Coastguard Worker            loss_nt_compile = (
6701*da0073e9SAndroid Build Coastguard Worker                torch.compile(fn_nt)(v32_nt_compile, v16_nt_compile, offsets)
6702*da0073e9SAndroid Build Coastguard Worker                .values()
6703*da0073e9SAndroid Build Coastguard Worker                .sum()
6704*da0073e9SAndroid Build Coastguard Worker            )
6705*da0073e9SAndroid Build Coastguard Worker
6706*da0073e9SAndroid Build Coastguard Worker        loss_dense_eager.backward()
6707*da0073e9SAndroid Build Coastguard Worker        loss_dense_compile.backward()
6708*da0073e9SAndroid Build Coastguard Worker        loss_nt_eager.backward()
6709*da0073e9SAndroid Build Coastguard Worker        loss_nt_compile.backward()
6710*da0073e9SAndroid Build Coastguard Worker
6711*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v32_dense_eager.grad, v32_dense_compile.grad)
6712*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v32_dense_eager.grad, v32_nt_eager.grad)
6713*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v32_dense_eager.grad, v32_nt_compile.grad)
6714*da0073e9SAndroid Build Coastguard Worker
6715*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v16_dense_eager.grad, v16_dense_compile.grad)
6716*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v16_dense_eager.grad, v16_nt_eager.grad)
6717*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v16_dense_eager.grad, v16_nt_compile.grad)
6718*da0073e9SAndroid Build Coastguard Worker
6719*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
6720*da0073e9SAndroid Build Coastguard Worker        not PLATFORM_SUPPORTS_FUSED_ATTENTION,
6721*da0073e9SAndroid Build Coastguard Worker        "Platform doesn't support flash or mem-efficient attention",
6722*da0073e9SAndroid Build Coastguard Worker    )
6723*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
6724*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfRocm
6725*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
6726*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo()
6727*da0073e9SAndroid Build Coastguard Worker    def test_sdpa_flop_counter(self, device):
6728*da0073e9SAndroid Build Coastguard Worker        from torch.utils.flop_counter import FlopCounterMode
6729*da0073e9SAndroid Build Coastguard Worker
6730*da0073e9SAndroid Build Coastguard Worker        def get_flops(nt):
6731*da0073e9SAndroid Build Coastguard Worker            flop_counter = FlopCounterMode(display=False)
6732*da0073e9SAndroid Build Coastguard Worker            with flop_counter:
6733*da0073e9SAndroid Build Coastguard Worker                ret = torch.nn.functional.scaled_dot_product_attention(nt, nt, nt)
6734*da0073e9SAndroid Build Coastguard Worker                ret.values().sum().backward()
6735*da0073e9SAndroid Build Coastguard Worker            return flop_counter.get_total_flops()
6736*da0073e9SAndroid Build Coastguard Worker
6737*da0073e9SAndroid Build Coastguard Worker        values = torch.randn(
6738*da0073e9SAndroid Build Coastguard Worker            (8 * 16, 4, 16), requires_grad=True, device=device, dtype=torch.float16
6739*da0073e9SAndroid Build Coastguard Worker        )
6740*da0073e9SAndroid Build Coastguard Worker        offsets = torch.arange(0, 8 * 16 + 1, 16, device=device, dtype=torch.int32)
6741*da0073e9SAndroid Build Coastguard Worker        nt = convert_jagged_to_nested_tensor(values, offsets, max_length=16)
6742*da0073e9SAndroid Build Coastguard Worker
6743*da0073e9SAndroid Build Coastguard Worker        values_meta = torch.randn(
6744*da0073e9SAndroid Build Coastguard Worker            (8 * 16, 4, 16), requires_grad=True, device="meta", dtype=torch.float16
6745*da0073e9SAndroid Build Coastguard Worker        )
6746*da0073e9SAndroid Build Coastguard Worker        offsets_meta = torch.arange(0, 8 * 16 + 1, 16, device="meta", dtype=torch.int32)
6747*da0073e9SAndroid Build Coastguard Worker        nt_meta = convert_jagged_to_nested_tensor(values_meta, offsets_meta, max_length=16)
6748*da0073e9SAndroid Build Coastguard Worker
6749*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(get_flops(nt), get_flops(nt_meta))
6750*da0073e9SAndroid Build Coastguard Worker
6751*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo()
6752*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_activation_checkpoint(self, device):
6753*da0073e9SAndroid Build Coastguard Worker        values = torch.randn(
6754*da0073e9SAndroid Build Coastguard Worker            9, 3, 256, requires_grad=True, device=device, dtype=torch.float32
6755*da0073e9SAndroid Build Coastguard Worker        )
6756*da0073e9SAndroid Build Coastguard Worker        lengths = torch.tensor([1, 2, 3, 3], device=device, dtype=torch.int64)
6757*da0073e9SAndroid Build Coastguard Worker        offsets = F.pad(lengths, pad=(1, 0)).cumsum(dim=0)
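        # lengths [1, 2, 3, 3] -> offsets [0, 1, 3, 6, 9], matching the 9 rows in values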
6758*da0073e9SAndroid Build Coastguard Worker
6759*da0073e9SAndroid Build Coastguard Worker        def fn(values, offsets):
6760*da0073e9SAndroid Build Coastguard Worker            nt = convert_jagged_to_nested_tensor(values, offsets, max_length=4)
6761*da0073e9SAndroid Build Coastguard Worker            return convert_nt_to_jagged(nt).sum()
6762*da0073e9SAndroid Build Coastguard Worker
6763*da0073e9SAndroid Build Coastguard Worker        checkpoint(fn, values, offsets, use_reentrant=False).backward()
6764*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(values.grad)
6765*da0073e9SAndroid Build Coastguard Worker
6766*da0073e9SAndroid Build Coastguard Worker        context_fn = partial(
6767*da0073e9SAndroid Build Coastguard Worker            create_selective_checkpoint_contexts, [torch.ops.aten.cumsum.default]
6768*da0073e9SAndroid Build Coastguard Worker        )
6769*da0073e9SAndroid Build Coastguard Worker
6770*da0073e9SAndroid Build Coastguard Worker        values.grad = None
6771*da0073e9SAndroid Build Coastguard Worker
6772*da0073e9SAndroid Build Coastguard Worker        def fn(values, lengths):
6773*da0073e9SAndroid Build Coastguard Worker            offsets = F.pad(lengths, pad=(1, 0)).cumsum(dim=0)
6774*da0073e9SAndroid Build Coastguard Worker            nt = convert_jagged_to_nested_tensor(values, offsets, max_length=4)
6775*da0073e9SAndroid Build Coastguard Worker            return convert_nt_to_jagged(nt).sum()
6776*da0073e9SAndroid Build Coastguard Worker
6777*da0073e9SAndroid Build Coastguard Worker        checkpoint(
6778*da0073e9SAndroid Build Coastguard Worker            fn, values, lengths, use_reentrant=False, context_fn=context_fn
6779*da0073e9SAndroid Build Coastguard Worker        ).backward()
6780*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(values.grad)
6781*da0073e9SAndroid Build Coastguard Worker
6782*da0073e9SAndroid Build Coastguard Worker    # Internally-defined NT use cases are lifted here for maximum test realism.
6783*da0073e9SAndroid Build Coastguard Worker    # TODO: Remove these when ViewNestedFromBuffer, etc. are deprecated.
6784*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfRocm  # not needed
6785*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("compiles internally")
6786*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
6787*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
6788*da0073e9SAndroid Build Coastguard Worker    @parametrize("use_legacy_api", [True, False])
6789*da0073e9SAndroid Build Coastguard Worker    @skipCPUIf(True, "SDPA Math NT fallback causes failure: see issue #133644")
6790*da0073e9SAndroid Build Coastguard Worker    def test_dummy_mha_with_nt(self, device, use_legacy_api):
6791*da0073e9SAndroid Build Coastguard Worker        bs = 3
6792*da0073e9SAndroid Build Coastguard Worker        d1 = 2
6793*da0073e9SAndroid Build Coastguard Worker        d2 = 4
6794*da0073e9SAndroid Build Coastguard Worker        d3 = 16
6795*da0073e9SAndroid Build Coastguard Worker        n_heads = 2
6796*da0073e9SAndroid Build Coastguard Worker        d_head = d3 // n_heads
6797*da0073e9SAndroid Build Coastguard Worker        max_length_1 = 10
6798*da0073e9SAndroid Build Coastguard Worker        max_length_2 = 20
6799*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(0)
6800*da0073e9SAndroid Build Coastguard Worker
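        # Dummy MHA module: query comes in dense with shape (bs, d1, d3) while value
        # comes in jagged form with shape (total_length, d2); the linear layer projects
        # value to d3, and the key / value NJTs are built with different max lengths so
        # that _max_seqlen caching can be checked on the outputs.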
6801*da0073e9SAndroid Build Coastguard Worker        class mha(torch.nn.Module):
6802*da0073e9SAndroid Build Coastguard Worker            def __init__(self, use_legacy_api) -> None:
6803*da0073e9SAndroid Build Coastguard Worker                super().__init__()
6804*da0073e9SAndroid Build Coastguard Worker                torch.manual_seed(0)
6805*da0073e9SAndroid Build Coastguard Worker                self.linear = torch.nn.Linear(d2, d3, device=device)
6806*da0073e9SAndroid Build Coastguard Worker                self.use_legacy_api = use_legacy_api
6807*da0073e9SAndroid Build Coastguard Worker
6808*da0073e9SAndroid Build Coastguard Worker            def forward(self, query, value, offsets):
6809*da0073e9SAndroid Build Coastguard Worker                value = self.linear(value)
6810*da0073e9SAndroid Build Coastguard Worker                if self.use_legacy_api:
6811*da0073e9SAndroid Build Coastguard Worker                    key = convert_jagged_to_nested_tensor_legacy(
6812*da0073e9SAndroid Build Coastguard Worker                        value, offsets, max_length_1
6813*da0073e9SAndroid Build Coastguard Worker                    )
6814*da0073e9SAndroid Build Coastguard Worker                    value = convert_jagged_to_nested_tensor_legacy(
6815*da0073e9SAndroid Build Coastguard Worker                        value, offsets, max_length_2
6816*da0073e9SAndroid Build Coastguard Worker                    )
6817*da0073e9SAndroid Build Coastguard Worker                    query = convert_dense_to_nested_tensor_legacy(query)
6818*da0073e9SAndroid Build Coastguard Worker                else:
6819*da0073e9SAndroid Build Coastguard Worker                    key = convert_jagged_to_nested_tensor(value, offsets, max_length_1)
6820*da0073e9SAndroid Build Coastguard Worker                    value = convert_jagged_to_nested_tensor(
6821*da0073e9SAndroid Build Coastguard Worker                        value, offsets, max_length_2
6822*da0073e9SAndroid Build Coastguard Worker                    )
6823*da0073e9SAndroid Build Coastguard Worker                    query = convert_dense_to_nested_tensor(query)
6824*da0073e9SAndroid Build Coastguard Worker                q = query.view(bs, -1, n_heads, d_head).transpose(1, 2)
6825*da0073e9SAndroid Build Coastguard Worker                k = key.view(bs, -1, n_heads, d_head).transpose(1, 2)
6826*da0073e9SAndroid Build Coastguard Worker                v = value.view(bs, -1, n_heads, d_head).transpose(1, 2)
6827*da0073e9SAndroid Build Coastguard Worker
6828*da0073e9SAndroid Build Coastguard Worker                with torch.nn.attention.sdpa_kernel(
6829*da0073e9SAndroid Build Coastguard Worker                    [
6830*da0073e9SAndroid Build Coastguard Worker                        torch.nn.attention.SDPBackend.FLASH_ATTENTION,
6831*da0073e9SAndroid Build Coastguard Worker                        torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
6832*da0073e9SAndroid Build Coastguard Worker                    ]
6833*da0073e9SAndroid Build Coastguard Worker                ):
6834*da0073e9SAndroid Build Coastguard Worker                    attn_output = torch.nn.functional.scaled_dot_product_attention(
6835*da0073e9SAndroid Build Coastguard Worker                        q,
6836*da0073e9SAndroid Build Coastguard Worker                        k,
6837*da0073e9SAndroid Build Coastguard Worker                        v,
6838*da0073e9SAndroid Build Coastguard Worker                        attn_mask=None,
6839*da0073e9SAndroid Build Coastguard Worker                        dropout_p=0.0,
6840*da0073e9SAndroid Build Coastguard Worker                        is_causal=False,
6841*da0073e9SAndroid Build Coastguard Worker                    )
6842*da0073e9SAndroid Build Coastguard Worker                attn_output = attn_output.transpose(1, 2)
6843*da0073e9SAndroid Build Coastguard Worker                if self.use_legacy_api:
6844*da0073e9SAndroid Build Coastguard Worker                    attn_output = convert_nt_to_jagged_legacy(attn_output)
6845*da0073e9SAndroid Build Coastguard Worker                else:
6846*da0073e9SAndroid Build Coastguard Worker                    attn_output = convert_nt_to_jagged(attn_output)
6847*da0073e9SAndroid Build Coastguard Worker                return attn_output, key._max_seqlen, value._max_seqlen
6848*da0073e9SAndroid Build Coastguard Worker
6849*da0073e9SAndroid Build Coastguard Worker        query = torch.rand(bs, d1, d3, device=device)
6850*da0073e9SAndroid Build Coastguard Worker        value = torch.rand(30, d2, requires_grad=True, device=device)
6851*da0073e9SAndroid Build Coastguard Worker        # total_length must be greater than max_length, otherwise flash_attn backward will fail
6852*da0073e9SAndroid Build Coastguard Worker        offsets = torch.tensor([0, 2, 3, 30], device=device)
6853*da0073e9SAndroid Build Coastguard Worker
6854*da0073e9SAndroid Build Coastguard Worker        m = mha(use_legacy_api)
6855*da0073e9SAndroid Build Coastguard Worker        symbolic_traced: torch.fx.GraphModule = torch.fx.symbolic_trace(m)
6856*da0073e9SAndroid Build Coastguard Worker        m = torch.compile(symbolic_traced)
6857*da0073e9SAndroid Build Coastguard Worker        attn_output, cached_key_max_seqlen, cached_value_max_seqlen = m(
6858*da0073e9SAndroid Build Coastguard Worker            query, value, offsets
6859*da0073e9SAndroid Build Coastguard Worker        )
6860*da0073e9SAndroid Build Coastguard Worker        loss = attn_output.sum()
6861*da0073e9SAndroid Build Coastguard Worker        # Check that the NT path can be fx-traced and compiled with torch.compile, and that backward works
6862*da0073e9SAndroid Build Coastguard Worker        loss.backward()
6863*da0073e9SAndroid Build Coastguard Worker
6864*da0073e9SAndroid Build Coastguard Worker        # Check that value.requires_grad is not lost after tracing and compiling
6865*da0073e9SAndroid Build Coastguard Worker        value_grad = value.grad  # save for comparison later
6866*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(value_grad)
6867*da0073e9SAndroid Build Coastguard Worker        # check that max_seqlen is cached properly
6868*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cached_key_max_seqlen, max_length_1)
6869*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cached_value_max_seqlen, max_length_2)
6870*da0073e9SAndroid Build Coastguard Worker
6871*da0073e9SAndroid Build Coastguard Worker        # check that the output is numerically equivalent to eager mode
6872*da0073e9SAndroid Build Coastguard Worker        m_eager = mha(use_legacy_api)
6873*da0073e9SAndroid Build Coastguard Worker
6874*da0073e9SAndroid Build Coastguard Worker        value.grad = None
6875*da0073e9SAndroid Build Coastguard Worker        attn_output_eager, _, _ = m_eager(query, value, offsets)
6876*da0073e9SAndroid Build Coastguard Worker        attn_output_eager.sum().backward()
6877*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(attn_output_eager, attn_output))
6878*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(value_grad, value.grad))
6879*da0073e9SAndroid Build Coastguard Worker
6880*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
6881*da0073e9SAndroid Build Coastguard Worker    def test_apply_(self, device, dtype):
6882*da0073e9SAndroid Build Coastguard Worker        nt = random_nt_from_dims(
6883*da0073e9SAndroid Build Coastguard Worker            [5, None, 10],
6884*da0073e9SAndroid Build Coastguard Worker            device=device,
6885*da0073e9SAndroid Build Coastguard Worker            dtype=dtype,
6886*da0073e9SAndroid Build Coastguard Worker            layout=torch.jagged,
6887*da0073e9SAndroid Build Coastguard Worker            requires_grad=True,
6888*da0073e9SAndroid Build Coastguard Worker        )
6889*da0073e9SAndroid Build Coastguard Worker
6890*da0073e9SAndroid Build Coastguard Worker        def f(x):
6891*da0073e9SAndroid Build Coastguard Worker            return x * 2
6892*da0073e9SAndroid Build Coastguard Worker
6893*da0073e9SAndroid Build Coastguard Worker        if device != "cpu":
6894*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
6895*da0073e9SAndroid Build Coastguard Worker                TypeError, "apply_ is only implemented on CPU tensors"
6896*da0073e9SAndroid Build Coastguard Worker            ):
6897*da0073e9SAndroid Build Coastguard Worker                nt.apply_(f)
6898*da0073e9SAndroid Build Coastguard Worker            return
6899*da0073e9SAndroid Build Coastguard Worker
6900*da0073e9SAndroid Build Coastguard Worker        before = nt._values.clone().detach()
6901*da0073e9SAndroid Build Coastguard Worker
6902*da0073e9SAndroid Build Coastguard Worker        nt.apply_(f)
6903*da0073e9SAndroid Build Coastguard Worker        expected = f(before)
6904*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, nt._values)
6905*da0073e9SAndroid Build Coastguard Worker        # apply_ should modify values in place without adding to the autograd graph
6906*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(nt.grad)
6907*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(nt._values.grad_fn)
6908*da0073e9SAndroid Build Coastguard Worker
6909*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float64, torch.float32, torch.half)
6910*da0073e9SAndroid Build Coastguard Worker    def test_jagged_padded_dense_conversion_kernels(self, device, dtype):
6911*da0073e9SAndroid Build Coastguard Worker        values = torch.randn(10, 5, device=device, dtype=dtype)
6912*da0073e9SAndroid Build Coastguard Worker        offsets = torch.tensor([0, 1, 3, 8, 10], device=device, dtype=torch.int64)
6913*da0073e9SAndroid Build Coastguard Worker        max_length = offsets.diff().max().item()
6914*da0073e9SAndroid Build Coastguard Worker        padding_value = 1.3
6915*da0073e9SAndroid Build Coastguard Worker
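        # e.g. offsets [0, 1, 3, 8, 10] -> lengths [1, 2, 5, 2], so the padded dense
        # output has shape (4, max_length=5, 5), with rows beyond each length filled
        # with padding_value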
6916*da0073e9SAndroid Build Coastguard Worker        # convert jagged -> padded dense
6917*da0073e9SAndroid Build Coastguard Worker        padded = torch.ops.aten._jagged_to_padded_dense_forward(
6918*da0073e9SAndroid Build Coastguard Worker            values, [offsets], [max_length], padding_value
6919*da0073e9SAndroid Build Coastguard Worker        )
6920*da0073e9SAndroid Build Coastguard Worker
6921*da0073e9SAndroid Build Coastguard Worker        batch_size = offsets.shape[0] - 1
6922*da0073e9SAndroid Build Coastguard Worker        expected_padded_shape = (batch_size, max_length, values.shape[-1])
6923*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(padded.shape, expected_padded_shape)
6924*da0073e9SAndroid Build Coastguard Worker
6925*da0073e9SAndroid Build Coastguard Worker        # convert padded dense -> jagged
6926*da0073e9SAndroid Build Coastguard Worker        total_L = values.shape[0]
6927*da0073e9SAndroid Build Coastguard Worker        output_jagged = torch.ops.aten._padded_dense_to_jagged_forward(
6928*da0073e9SAndroid Build Coastguard Worker            padded, [offsets], total_L
6929*da0073e9SAndroid Build Coastguard Worker        )
6930*da0073e9SAndroid Build Coastguard Worker
6931*da0073e9SAndroid Build Coastguard Worker        # should be equivalent to the original values
6932*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(values, output_jagged)
6933*da0073e9SAndroid Build Coastguard Worker
6934*da0073e9SAndroid Build Coastguard Worker        # success case: truncate to max length as needed
6935*da0073e9SAndroid Build Coastguard Worker        trunc_max_length = max_length - 1
6936*da0073e9SAndroid Build Coastguard Worker        trunc_padded = torch.ops.aten._jagged_to_padded_dense_forward(
6937*da0073e9SAndroid Build Coastguard Worker            values, [offsets], [trunc_max_length], padding_value
6938*da0073e9SAndroid Build Coastguard Worker        )
6939*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(padded[:, :trunc_max_length, :], trunc_padded)
6940*da0073e9SAndroid Build Coastguard Worker
6941*da0073e9SAndroid Build Coastguard Worker        # specific to CPU impls
6942*da0073e9SAndroid Build Coastguard Worker        if device == "cpu":
6943*da0073e9SAndroid Build Coastguard Worker            # error case: multiple offsets on CPU, since the CPU kernels only support a single jagged dim for now
6944*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
6945*da0073e9SAndroid Build Coastguard Worker                RuntimeError, "only a single jagged dim is supported"
6946*da0073e9SAndroid Build Coastguard Worker            ):
6947*da0073e9SAndroid Build Coastguard Worker                torch.ops.aten._jagged_to_padded_dense_forward(
6948*da0073e9SAndroid Build Coastguard Worker                    values, [offsets, offsets], [max_length, max_length], padding_value
6949*da0073e9SAndroid Build Coastguard Worker                )
6950*da0073e9SAndroid Build Coastguard Worker
6951*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
6952*da0073e9SAndroid Build Coastguard Worker                RuntimeError, "only a single jagged dim is supported"
6953*da0073e9SAndroid Build Coastguard Worker            ):
6954*da0073e9SAndroid Build Coastguard Worker                torch.ops.aten._padded_dense_to_jagged_forward(
6955*da0073e9SAndroid Build Coastguard Worker                    padded, [offsets, offsets], total_L
6956*da0073e9SAndroid Build Coastguard Worker                )
6957*da0073e9SAndroid Build Coastguard Worker
6958*da0073e9SAndroid Build Coastguard Worker            # error case: > 1D offsets
6959*da0073e9SAndroid Build Coastguard Worker            offsets2d = offsets.unsqueeze(-1)
6960*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "expected 1D offsets"):
6961*da0073e9SAndroid Build Coastguard Worker                torch.ops.aten._jagged_to_padded_dense_forward(
6962*da0073e9SAndroid Build Coastguard Worker                    values, [offsets2d], [max_length], padding_value
6963*da0073e9SAndroid Build Coastguard Worker                )
6964*da0073e9SAndroid Build Coastguard Worker
6965*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "expected 1D offsets"):
6966*da0073e9SAndroid Build Coastguard Worker                torch.ops.aten._padded_dense_to_jagged_forward(
6967*da0073e9SAndroid Build Coastguard Worker                    padded, [offsets2d], total_L
6968*da0073e9SAndroid Build Coastguard Worker                )
6969*da0073e9SAndroid Build Coastguard Worker
6970*da0073e9SAndroid Build Coastguard Worker            # error case: final offset != total_L
6971*da0073e9SAndroid Build Coastguard Worker            offsets_wrong = offsets.clone().detach()
6972*da0073e9SAndroid Build Coastguard Worker            offsets_wrong[-1] = total_L + 1
6973*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
6974*da0073e9SAndroid Build Coastguard Worker                RuntimeError, "final offset should match total_L value"
6975*da0073e9SAndroid Build Coastguard Worker            ):
6976*da0073e9SAndroid Build Coastguard Worker                torch.ops.aten._padded_dense_to_jagged_forward(
6977*da0073e9SAndroid Build Coastguard Worker                    padded, [offsets_wrong], total_L
6978*da0073e9SAndroid Build Coastguard Worker                )
6979*da0073e9SAndroid Build Coastguard Worker
6980*da0073e9SAndroid Build Coastguard Worker            # error case: 1D padded input
6981*da0073e9SAndroid Build Coastguard Worker            padded_wrong = padded.flatten().clone().detach()
6982*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "expected padded dim >= 2"):
6983*da0073e9SAndroid Build Coastguard Worker                torch.ops.aten._padded_dense_to_jagged_forward(
6984*da0073e9SAndroid Build Coastguard Worker                    padded_wrong, [offsets], total_L
6985*da0073e9SAndroid Build Coastguard Worker                )
6986*da0073e9SAndroid Build Coastguard Worker
6987*da0073e9SAndroid Build Coastguard Worker            # error case: a batch item has length > max_length
6988*da0073e9SAndroid Build Coastguard Worker            # max_length is 5 above; the offsets below give a batch item of length 7
6989*da0073e9SAndroid Build Coastguard Worker            offsets_wrong = torch.tensor(
6990*da0073e9SAndroid Build Coastguard Worker                [0, 1, 8, 9, 10], device=device, dtype=torch.int64
6991*da0073e9SAndroid Build Coastguard Worker            )
6992*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "found batch item of length"):
6993*da0073e9SAndroid Build Coastguard Worker                torch.ops.aten._padded_dense_to_jagged_forward(
6994*da0073e9SAndroid Build Coastguard Worker                    padded, [offsets_wrong], total_L
6995*da0073e9SAndroid Build Coastguard Worker                )
6996*da0073e9SAndroid Build Coastguard Worker
6997*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
6998*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Test compiles internally")
6999*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
7000*da0073e9SAndroid Build Coastguard Worker        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
7001*da0073e9SAndroid Build Coastguard Worker    )
7002*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
7003*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
7004*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfRocm
7005*da0073e9SAndroid Build Coastguard Worker    def test_compile_preserves_metadata_cache(self, device, dtype):
7006*da0073e9SAndroid Build Coastguard Worker        # shape (B, *, D)
7007*da0073e9SAndroid Build Coastguard Worker        nt = random_nt_from_dims(
7008*da0073e9SAndroid Build Coastguard Worker            [4, None, 3, 16],
7009*da0073e9SAndroid Build Coastguard Worker            device=device,
7010*da0073e9SAndroid Build Coastguard Worker            dtype=dtype,
7011*da0073e9SAndroid Build Coastguard Worker            layout=torch.jagged,
7012*da0073e9SAndroid Build Coastguard Worker            requires_grad=True,
7013*da0073e9SAndroid Build Coastguard Worker        )
7014*da0073e9SAndroid Build Coastguard Worker
7015*da0073e9SAndroid Build Coastguard Worker        # expect min / max seqlen to be stored here
7016*da0073e9SAndroid Build Coastguard Worker        cache = dict(nt._metadata_cache)
7017*da0073e9SAndroid Build Coastguard Worker
7018*da0073e9SAndroid Build Coastguard Worker        @torch.compile
7019*da0073e9SAndroid Build Coastguard Worker        def f(nt):
7020*da0073e9SAndroid Build Coastguard Worker            q = nt.transpose(-3, -2)
7021*da0073e9SAndroid Build Coastguard Worker            output = F.scaled_dot_product_attention(q, q, q).transpose(-3, -2)
7022*da0073e9SAndroid Build Coastguard Worker            return output
7023*da0073e9SAndroid Build Coastguard Worker
7024*da0073e9SAndroid Build Coastguard Worker        output = f(nt)
7025*da0073e9SAndroid Build Coastguard Worker        output.backward(torch.ones_like(output))
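        # the compiled forward + backward should leave the cached seqlen entries intact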
7026*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output._metadata_cache, cache)
7027*da0073e9SAndroid Build Coastguard Worker
7028*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
7029*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Test compiles internally")
7030*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
7031*da0073e9SAndroid Build Coastguard Worker        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
7032*da0073e9SAndroid Build Coastguard Worker    )
7033*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
7034*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
7035*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfRocm
7036*da0073e9SAndroid Build Coastguard Worker    def test_compile_with_dynamic_max_seq_len(self, device, dtype):
7037*da0073e9SAndroid Build Coastguard Worker        # shape (B, *, D)
7038*da0073e9SAndroid Build Coastguard Worker        # max seq len: 18
7039*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor(
7040*da0073e9SAndroid Build Coastguard Worker            [
7041*da0073e9SAndroid Build Coastguard Worker                torch.randn(2, 5),
7042*da0073e9SAndroid Build Coastguard Worker                torch.randn(3, 5),
7043*da0073e9SAndroid Build Coastguard Worker                torch.randn(18, 5),
7044*da0073e9SAndroid Build Coastguard Worker            ],
7045*da0073e9SAndroid Build Coastguard Worker            layout=torch.jagged,
7046*da0073e9SAndroid Build Coastguard Worker        )
7047*da0073e9SAndroid Build Coastguard Worker
7048*da0073e9SAndroid Build Coastguard Worker        # max seq len: 19
7049*da0073e9SAndroid Build Coastguard Worker        nt2 = torch.nested.nested_tensor(
7050*da0073e9SAndroid Build Coastguard Worker            [
7051*da0073e9SAndroid Build Coastguard Worker                torch.randn(2, 5),
7052*da0073e9SAndroid Build Coastguard Worker                torch.randn(3, 5),
7053*da0073e9SAndroid Build Coastguard Worker                torch.randn(19, 5),
7054*da0073e9SAndroid Build Coastguard Worker            ],
7055*da0073e9SAndroid Build Coastguard Worker            layout=torch.jagged,
7056*da0073e9SAndroid Build Coastguard Worker        )
7057*da0073e9SAndroid Build Coastguard Worker
7058*da0073e9SAndroid Build Coastguard Worker        def f(nt):
7059*da0073e9SAndroid Build Coastguard Worker            # TODO: Replace with public API when we can use @properties
7060*da0073e9SAndroid Build Coastguard Worker            return torch.ones_like(nt) * nt._get_max_seqlen()
7061*da0073e9SAndroid Build Coastguard Worker
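        # bumping the max seq len from 18 to 19 between inputs should not trigger a
        # recompile, regardless of the dynamic setting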
7062*da0073e9SAndroid Build Coastguard Worker        for dynamic in [False, True, None]:
7063*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))
7064*da0073e9SAndroid Build Coastguard Worker
7065*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
7066*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Test compiles internally")
7067*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
7068*da0073e9SAndroid Build Coastguard Worker        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
7069*da0073e9SAndroid Build Coastguard Worker    )
7070*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
7071*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
7072*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfRocm
7073*da0073e9SAndroid Build Coastguard Worker    def test_compile_with_dynamic_min_seq_len(self, device, dtype):
7074*da0073e9SAndroid Build Coastguard Worker        # shape (B, *, D)
7075*da0073e9SAndroid Build Coastguard Worker        # min seq len: 7
7076*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor(
7077*da0073e9SAndroid Build Coastguard Worker            [
7078*da0073e9SAndroid Build Coastguard Worker                torch.randn(7, 5),
7079*da0073e9SAndroid Build Coastguard Worker                torch.randn(8, 5),
7080*da0073e9SAndroid Build Coastguard Worker                torch.randn(9, 5),
7081*da0073e9SAndroid Build Coastguard Worker            ],
7082*da0073e9SAndroid Build Coastguard Worker            layout=torch.jagged,
7083*da0073e9SAndroid Build Coastguard Worker        )
7084*da0073e9SAndroid Build Coastguard Worker
7085*da0073e9SAndroid Build Coastguard Worker        # min seq len: 8
7086*da0073e9SAndroid Build Coastguard Worker        nt2 = torch.nested.nested_tensor(
7087*da0073e9SAndroid Build Coastguard Worker            [
7088*da0073e9SAndroid Build Coastguard Worker                torch.randn(8, 5),
7089*da0073e9SAndroid Build Coastguard Worker                torch.randn(9, 5),
7090*da0073e9SAndroid Build Coastguard Worker                torch.randn(10, 5),
7091*da0073e9SAndroid Build Coastguard Worker            ],
7092*da0073e9SAndroid Build Coastguard Worker            layout=torch.jagged,
7093*da0073e9SAndroid Build Coastguard Worker        )
7094*da0073e9SAndroid Build Coastguard Worker
7095*da0073e9SAndroid Build Coastguard Worker        def f(nt):
7096*da0073e9SAndroid Build Coastguard Worker            # TODO: Replace with public API when we can use @properties
7097*da0073e9SAndroid Build Coastguard Worker            return torch.ones_like(nt) * nt._get_min_seqlen()
7098*da0073e9SAndroid Build Coastguard Worker
7099*da0073e9SAndroid Build Coastguard Worker        for dynamic in [False, True, None]:
7100*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))
7101*da0073e9SAndroid Build Coastguard Worker
7102*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
7103*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Test compiles internally")
7104*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
7105*da0073e9SAndroid Build Coastguard Worker        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
7106*da0073e9SAndroid Build Coastguard Worker    )
7107*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
7108*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
7109*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfRocm
7110*da0073e9SAndroid Build Coastguard Worker    def test_compile_with_propagated_dynamic_max_seq_len(self, device, dtype):
7111*da0073e9SAndroid Build Coastguard Worker        # shape (B, *, D)
7112*da0073e9SAndroid Build Coastguard Worker        # max seq len: 18
7113*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor(
7114*da0073e9SAndroid Build Coastguard Worker            [
7115*da0073e9SAndroid Build Coastguard Worker                torch.randn(2, 5),
7116*da0073e9SAndroid Build Coastguard Worker                torch.randn(3, 5),
7117*da0073e9SAndroid Build Coastguard Worker                torch.randn(18, 5),
7118*da0073e9SAndroid Build Coastguard Worker            ],
7119*da0073e9SAndroid Build Coastguard Worker            layout=torch.jagged,
7120*da0073e9SAndroid Build Coastguard Worker        )
7121*da0073e9SAndroid Build Coastguard Worker
7122*da0073e9SAndroid Build Coastguard Worker        # max seq len: 19
7123*da0073e9SAndroid Build Coastguard Worker        nt2 = torch.nested.nested_tensor(
7124*da0073e9SAndroid Build Coastguard Worker            [
7125*da0073e9SAndroid Build Coastguard Worker                torch.randn(2, 5),
7126*da0073e9SAndroid Build Coastguard Worker                torch.randn(3, 5),
7127*da0073e9SAndroid Build Coastguard Worker                torch.randn(19, 5),
7128*da0073e9SAndroid Build Coastguard Worker            ],
7129*da0073e9SAndroid Build Coastguard Worker            layout=torch.jagged,
7130*da0073e9SAndroid Build Coastguard Worker        )
7131*da0073e9SAndroid Build Coastguard Worker
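        # the max seqlen metadata should propagate from nt through the pointwise ops
        # below to the derived NT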
7132*da0073e9SAndroid Build Coastguard Worker        def f(nt):
7133*da0073e9SAndroid Build Coastguard Worker            nt2 = nt.sin() + 1
7134*da0073e9SAndroid Build Coastguard Worker            # TODO: Replace with public API when we can use @properties
7135*da0073e9SAndroid Build Coastguard Worker            return torch.ones_like(nt2) * nt2._get_max_seqlen()
7136*da0073e9SAndroid Build Coastguard Worker
7137*da0073e9SAndroid Build Coastguard Worker        ref = f(nt)
7138*da0073e9SAndroid Build Coastguard Worker        output = torch.compile(f, fullgraph=True, dynamic=False)(nt)
7139*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, output)
7140*da0073e9SAndroid Build Coastguard Worker
7141*da0073e9SAndroid Build Coastguard Worker        for dynamic in [False, True, None]:
7142*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))
7143*da0073e9SAndroid Build Coastguard Worker
7144*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32, torch.double, torch.half)
7145*da0073e9SAndroid Build Coastguard Worker    def test_unbind_backward(self, device, dtype):
7146*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor(
7147*da0073e9SAndroid Build Coastguard Worker            [
7148*da0073e9SAndroid Build Coastguard Worker                torch.randn(2, 4, device=device),
7149*da0073e9SAndroid Build Coastguard Worker                torch.randn(5, 4, device=device),
7150*da0073e9SAndroid Build Coastguard Worker                torch.randn(3, 4, device=device),
7151*da0073e9SAndroid Build Coastguard Worker            ],
7152*da0073e9SAndroid Build Coastguard Worker            layout=torch.jagged,
7153*da0073e9SAndroid Build Coastguard Worker            requires_grad=True,
7154*da0073e9SAndroid Build Coastguard Worker        )
7155*da0073e9SAndroid Build Coastguard Worker
7156*da0073e9SAndroid Build Coastguard Worker        a, b, c = nt.unbind()
7157*da0073e9SAndroid Build Coastguard Worker        b.sum().backward()
7158*da0073e9SAndroid Build Coastguard Worker
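        # only the middle component contributed to the loss, so its grad should be all
        # ones and the other components all zeros; the check runs outside dynamo tracing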
7159*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.disable
7160*da0073e9SAndroid Build Coastguard Worker        def check(nt):
7161*da0073e9SAndroid Build Coastguard Worker            expected_grad = torch.zeros_like(nt)
7162*da0073e9SAndroid Build Coastguard Worker            expected_grad.unbind()[1].add_(1.0)
7163*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nt.grad, expected_grad)
7164*da0073e9SAndroid Build Coastguard Worker
7165*da0073e9SAndroid Build Coastguard Worker        check(nt)
7166*da0073e9SAndroid Build Coastguard Worker
7167*da0073e9SAndroid Build Coastguard Worker
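# Sets of OpInfo full names expected to fail in the corresponding NJT tests below;
# withXFails() turns membership in these sets into unittest expected failures.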
7168*da0073e9SAndroid Build Coastguard WorkerFORWARD_FAILURES = {
7169*da0073e9SAndroid Build Coastguard Worker    # === BEGIN NotImplementedError SECTION ===
7170*da0073e9SAndroid Build Coastguard Worker    # unary
7171*da0073e9SAndroid Build Coastguard Worker    "nn.functional.celu",
7172*da0073e9SAndroid Build Coastguard Worker    "nn.functional.elu",
7173*da0073e9SAndroid Build Coastguard Worker    "nn.functional.hardshrink",
7174*da0073e9SAndroid Build Coastguard Worker    "nn.functional.hardsigmoid",
7175*da0073e9SAndroid Build Coastguard Worker    "nn.functional.hardtanh",
7176*da0073e9SAndroid Build Coastguard Worker    "nn.functional.logsigmoid",
7177*da0073e9SAndroid Build Coastguard Worker    "nn.functional.mish",
7178*da0073e9SAndroid Build Coastguard Worker    "nn.functional.relu6",
7179*da0073e9SAndroid Build Coastguard Worker    "nn.functional.rrelu",
7180*da0073e9SAndroid Build Coastguard Worker    "nn.functional.selu",
7181*da0073e9SAndroid Build Coastguard Worker    "nn.functional.softplus",
7182*da0073e9SAndroid Build Coastguard Worker    "nn.functional.softshrink",
7183*da0073e9SAndroid Build Coastguard Worker    "nn.functional.threshold",
7184*da0073e9SAndroid Build Coastguard Worker    "rad2deg",
7185*da0073e9SAndroid Build Coastguard Worker    # binary
7186*da0073e9SAndroid Build Coastguard Worker    "__rsub__",
7187*da0073e9SAndroid Build Coastguard Worker    "complex",
7188*da0073e9SAndroid Build Coastguard Worker    "floor_divide",
7189*da0073e9SAndroid Build Coastguard Worker    "polar",
7190*da0073e9SAndroid Build Coastguard Worker    "rsub",
7191*da0073e9SAndroid Build Coastguard Worker    # reduction
7192*da0073e9SAndroid Build Coastguard Worker    "all",
7193*da0073e9SAndroid Build Coastguard Worker    "amax",
7194*da0073e9SAndroid Build Coastguard Worker    "amin",
7195*da0073e9SAndroid Build Coastguard Worker    "any",
7196*da0073e9SAndroid Build Coastguard Worker    "argmax",
7197*da0073e9SAndroid Build Coastguard Worker    "argmin",
7198*da0073e9SAndroid Build Coastguard Worker    "count_nonzero",
7199*da0073e9SAndroid Build Coastguard Worker    "linalg.vector_norm",
7200*da0073e9SAndroid Build Coastguard Worker    "nansum",
7201*da0073e9SAndroid Build Coastguard Worker    "std",
7202*da0073e9SAndroid Build Coastguard Worker    "std.unbiased",
7203*da0073e9SAndroid Build Coastguard Worker    "var",
7204*da0073e9SAndroid Build Coastguard Worker    "var.unbiased",
7205*da0073e9SAndroid Build Coastguard Worker    # === BEGIN UNSUPPORTED SECTION ===
7206*da0073e9SAndroid Build Coastguard Worker    # RuntimeError: mean(): not supported for NestedTensor on dim=1
7207*da0073e9SAndroid Build Coastguard Worker    "mean",
7208*da0073e9SAndroid Build Coastguard Worker    # ValueError: expects strided tensor (got torch.jagged tensor)
7209*da0073e9SAndroid Build Coastguard Worker    "masked.amax",
7210*da0073e9SAndroid Build Coastguard Worker    "masked.amin",
7211*da0073e9SAndroid Build Coastguard Worker    "masked.argmax",
7212*da0073e9SAndroid Build Coastguard Worker    "masked.argmin",
7213*da0073e9SAndroid Build Coastguard Worker    "masked.logsumexp",
7214*da0073e9SAndroid Build Coastguard Worker    "masked.mean",
7215*da0073e9SAndroid Build Coastguard Worker    "masked.norm",
7216*da0073e9SAndroid Build Coastguard Worker    "masked.prod",
7217*da0073e9SAndroid Build Coastguard Worker    "masked.std",
7218*da0073e9SAndroid Build Coastguard Worker    "masked.sum",
7219*da0073e9SAndroid Build Coastguard Worker    "masked.var",
7220*da0073e9SAndroid Build Coastguard Worker    # === BEGIN BUG SECTION ===
7221*da0073e9SAndroid Build Coastguard Worker    # Returns a tuple of Tensors so it doesn't work with NJT's unary pointwise logic
7222*da0073e9SAndroid Build Coastguard Worker    "frexp",
7223*da0073e9SAndroid Build Coastguard Worker    # Need to adjust sample input func to pass the right thing
7224*da0073e9SAndroid Build Coastguard Worker    "nn.functional.prelu",
7225*da0073e9SAndroid Build Coastguard Worker    # TypeError: fill() received an invalid combination of arguments
7226*da0073e9SAndroid Build Coastguard Worker    # got (NestedTensor), but expected one of:
7227*da0073e9SAndroid Build Coastguard Worker    # * (Tensor input, Tensor value)
7228*da0073e9SAndroid Build Coastguard Worker    # * (Tensor input, Number value)
7229*da0073e9SAndroid Build Coastguard Worker    "fill",
7230*da0073e9SAndroid Build Coastguard Worker    # RuntimeError: unsupported tensor layout: Jagged
7231*da0073e9SAndroid Build Coastguard Worker    "jiterator_binary",
7232*da0073e9SAndroid Build Coastguard Worker    "jiterator_binary_return_by_ref",
7233*da0073e9SAndroid Build Coastguard Worker    "jiterator_unary",
7234*da0073e9SAndroid Build Coastguard Worker    # Bug found: sum() with keepdim=True returns invalid shape
7235*da0073e9SAndroid Build Coastguard Worker    "sum",
7236*da0073e9SAndroid Build Coastguard Worker    # RuntimeError: prod(): keepdim=True must be set for NestedTensor
7237*da0073e9SAndroid Build Coastguard Worker    "prod",
7238*da0073e9SAndroid Build Coastguard Worker    # RuntimeError: "jagged_to_padded_dense" not implemented for 'Bool'
7239*da0073e9SAndroid Build Coastguard Worker    "nanmean",
7240*da0073e9SAndroid Build Coastguard Worker}
7241*da0073e9SAndroid Build Coastguard Worker
7242*da0073e9SAndroid Build Coastguard WorkerBACKWARD_FAILURES = {
7243*da0073e9SAndroid Build Coastguard Worker    *FORWARD_FAILURES,
7244*da0073e9SAndroid Build Coastguard Worker    # TODO: categorize these
7245*da0073e9SAndroid Build Coastguard Worker    "__rpow__",
7246*da0073e9SAndroid Build Coastguard Worker    "atanh",
7247*da0073e9SAndroid Build Coastguard Worker    "cdouble",
7248*da0073e9SAndroid Build Coastguard Worker    "cfloat",
7249*da0073e9SAndroid Build Coastguard Worker    "chalf",
7250*da0073e9SAndroid Build Coastguard Worker    "clamp_max",
7251*da0073e9SAndroid Build Coastguard Worker    "clamp_min",
7252*da0073e9SAndroid Build Coastguard Worker    "copysign",
7253*da0073e9SAndroid Build Coastguard Worker    "float_power",
7254*da0073e9SAndroid Build Coastguard Worker    "max.binary",
7255*da0073e9SAndroid Build Coastguard Worker    "maximum",
7256*da0073e9SAndroid Build Coastguard Worker    "min.binary",
7257*da0073e9SAndroid Build Coastguard Worker    "minimum",
7258*da0073e9SAndroid Build Coastguard Worker    "pow",
7259*da0073e9SAndroid Build Coastguard Worker    "sgn",
7260*da0073e9SAndroid Build Coastguard Worker    "sinc",
7261*da0073e9SAndroid Build Coastguard Worker    "special.i1",
7262*da0073e9SAndroid Build Coastguard Worker    "special.i1e",
7263*da0073e9SAndroid Build Coastguard Worker    # clone() on a "non-contiguous with holes" NJT allocates a new offsets -> new nested int
7264*da0073e9SAndroid Build Coastguard Worker    # RuntimeError: Function CloneBackward0 returned an invalid gradient at index 0 -
7265*da0073e9SAndroid Build Coastguard Worker    # got [3, j29, 5] but expected shape compatible with [3, j28, 5]
7266*da0073e9SAndroid Build Coastguard Worker    "clone",
7267*da0073e9SAndroid Build Coastguard Worker    # Calling into torch.ops.aten.size directly
7268*da0073e9SAndroid Build Coastguard Worker    "masked_select",
7269*da0073e9SAndroid Build Coastguard Worker}
7270*da0073e9SAndroid Build Coastguard Worker
7271*da0073e9SAndroid Build Coastguard WorkerCOMPILE_FORWARD_FAILURES = {
7272*da0073e9SAndroid Build Coastguard Worker    *FORWARD_FAILURES,
7273*da0073e9SAndroid Build Coastguard Worker    # clone() on non-contiguous with holes NJTs currently use unbind(), leading to
7274*da0073e9SAndroid Build Coastguard Worker    # data-dependent error in torch.compile
7275*da0073e9SAndroid Build Coastguard Worker    "clone",
7276*da0073e9SAndroid Build Coastguard Worker}
7277*da0073e9SAndroid Build Coastguard Worker
7278*da0073e9SAndroid Build Coastguard WorkerCOMPARE_TENSOR_COMPONENT_EQUALITY = {
7279*da0073e9SAndroid Build Coastguard Worker    # masked_select is expected to output a different shape
7280*da0073e9SAndroid Build Coastguard Worker    "masked_select",
7281*da0073e9SAndroid Build Coastguard Worker}
7282*da0073e9SAndroid Build Coastguard Worker
7283*da0073e9SAndroid Build Coastguard Worker
7284*da0073e9SAndroid Build Coastguard Workerdef withXFails(failure_list):
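    """Mark parametrized tests as expected failures when op.full_name is in failure_list."""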
7285*da0073e9SAndroid Build Coastguard Worker    return decorateIf(
7286*da0073e9SAndroid Build Coastguard Worker        unittest.expectedFailure,
7287*da0073e9SAndroid Build Coastguard Worker        lambda params: params["op"].full_name in failure_list,
7288*da0073e9SAndroid Build Coastguard Worker    )
7289*da0073e9SAndroid Build Coastguard Worker
7290*da0073e9SAndroid Build Coastguard Worker
7291*da0073e9SAndroid Build Coastguard Worker# OpInfo-based NJT tests. These tests use an NJT-specific op_db generated from the standard
7292*da0073e9SAndroid Build Coastguard Worker# op_db. Note that certain tradeoffs were made between coverage and time spent running tests:
7293*da0073e9SAndroid Build Coastguard Worker#   * All tests run with dtype=torch.float32 only
7294*da0073e9SAndroid Build Coastguard Workerclass TestNestedTensorOpInfo(NestedTensorTestCase):
7295*da0073e9SAndroid Build Coastguard Worker    # TODO: move this
7296*da0073e9SAndroid Build Coastguard Worker    def _gen_grad_outputs(self, out_val):
7297*da0073e9SAndroid Build Coastguard Worker        if isinstance(out_val, (list, tuple)):
7298*da0073e9SAndroid Build Coastguard Worker            return tuple(torch.ones_like(c) for c in out_val)
7299*da0073e9SAndroid Build Coastguard Worker        else:
7300*da0073e9SAndroid Build Coastguard Worker            return (torch.ones_like(out_val),)
7301*da0073e9SAndroid Build Coastguard Worker
7302*da0073e9SAndroid Build Coastguard Worker    @withXFails(FORWARD_FAILURES)
7303*da0073e9SAndroid Build Coastguard Worker    @ops([op for op in njt_op_db if op.supports_njt], allowed_dtypes=(torch.float32,))
7304*da0073e9SAndroid Build Coastguard Worker    def test_forward(self, device, dtype, op):
7305*da0073e9SAndroid Build Coastguard Worker        for sample in op.sample_inputs(device=device, dtype=dtype, requires_grad=False):
7306*da0073e9SAndroid Build Coastguard Worker            # compare to reference, but expect different nested int
7307*da0073e9SAndroid Build Coastguard Worker            out = op.op(sample.input, *sample.args, **sample.kwargs)
7308*da0073e9SAndroid Build Coastguard Worker            out_ref = op.ref(op, sample)
7309*da0073e9SAndroid Build Coastguard Worker            self.assertEqualIgnoringNestedInts(out, out_ref)
7310*da0073e9SAndroid Build Coastguard Worker
7311*da0073e9SAndroid Build Coastguard Worker    @withXFails(BACKWARD_FAILURES)
7312*da0073e9SAndroid Build Coastguard Worker    @ops(
7313*da0073e9SAndroid Build Coastguard Worker        [op for op in njt_op_db if op.supports_njt and op.supports_autograd],
7314*da0073e9SAndroid Build Coastguard Worker        allowed_dtypes=(torch.float32,),
7315*da0073e9SAndroid Build Coastguard Worker    )
7316*da0073e9SAndroid Build Coastguard Worker    def test_backward(self, device, dtype, op):
7317*da0073e9SAndroid Build Coastguard Worker        for sample in op.sample_inputs(device=device, dtype=dtype, requires_grad=True):
7318*da0073e9SAndroid Build Coastguard Worker            # compare to reference, but expect different nested int
7319*da0073e9SAndroid Build Coastguard Worker            out = op.op(sample.input, *sample.args, **sample.kwargs)
7320*da0073e9SAndroid Build Coastguard Worker            out_ref = op.ref(op, sample)
7321*da0073e9SAndroid Build Coastguard Worker            self.assertEqualIgnoringNestedInts(out, out_ref)
7322*da0073e9SAndroid Build Coastguard Worker
7323*da0073e9SAndroid Build Coastguard Worker            inps, _ = tree_flatten((sample.input, sample.args, sample.kwargs))
7324*da0073e9SAndroid Build Coastguard Worker            g_inps = [
7325*da0073e9SAndroid Build Coastguard Worker                inp
7326*da0073e9SAndroid Build Coastguard Worker                for inp in inps
7327*da0073e9SAndroid Build Coastguard Worker                if isinstance(inp, torch.Tensor) and inp.requires_grad
7328*da0073e9SAndroid Build Coastguard Worker            ]
7329*da0073e9SAndroid Build Coastguard Worker            if len(g_inps) > 0:
7330*da0073e9SAndroid Build Coastguard Worker                grads = torch.autograd.grad(
7331*da0073e9SAndroid Build Coastguard Worker                    out, inputs=g_inps, grad_outputs=self._gen_grad_outputs(out)
7332*da0073e9SAndroid Build Coastguard Worker                )
7333*da0073e9SAndroid Build Coastguard Worker
7334*da0073e9SAndroid Build Coastguard Worker                grads_ref = torch.autograd.grad(
7335*da0073e9SAndroid Build Coastguard Worker                    out_ref,
7336*da0073e9SAndroid Build Coastguard Worker                    inputs=g_inps,
7337*da0073e9SAndroid Build Coastguard Worker                    grad_outputs=self._gen_grad_outputs(out_ref),
7338*da0073e9SAndroid Build Coastguard Worker                )
7339*da0073e9SAndroid Build Coastguard Worker
7340*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(grads, grads_ref)
7341*da0073e9SAndroid Build Coastguard Worker
7342*da0073e9SAndroid Build Coastguard Worker    @withXFails(COMPILE_FORWARD_FAILURES)
7343*da0073e9SAndroid Build Coastguard Worker    @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
7344*da0073e9SAndroid Build Coastguard Worker    @ops([op for op in njt_op_db if op.supports_njt], allowed_dtypes=(torch.float32,))
7345*da0073e9SAndroid Build Coastguard Worker    def test_compile_forward(self, device, dtype, op):
7346*da0073e9SAndroid Build Coastguard Worker        for sample in op.sample_inputs(device=device, dtype=dtype, requires_grad=False):
7347*da0073e9SAndroid Build Coastguard Worker            torch.compiler.reset()
7348*da0073e9SAndroid Build Coastguard Worker
7349*da0073e9SAndroid Build Coastguard Worker            op_fn = op.op
7350*da0073e9SAndroid Build Coastguard Worker
7351*da0073e9SAndroid Build Coastguard Worker            def f(*args, **kwargs):
7352*da0073e9SAndroid Build Coastguard Worker                return op_fn(*args, **kwargs)
7353*da0073e9SAndroid Build Coastguard Worker
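            # "aot_eager_decomp_partition" is a debugging backend: AOTAutograd with
            # decompositions and partitioning, but eager kernels rather than Inductor
            # codegen, which is enough to exercise NJT tracing.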
7354*da0073e9SAndroid Build Coastguard Worker            compiled_f = torch.compile(
7355*da0073e9SAndroid Build Coastguard Worker                f, fullgraph=True, backend="aot_eager_decomp_partition"
7356*da0073e9SAndroid Build Coastguard Worker            )
7357*da0073e9SAndroid Build Coastguard Worker
7358*da0073e9SAndroid Build Coastguard Worker            out_ref = f(sample.input, *sample.args, **sample.kwargs)
7359*da0073e9SAndroid Build Coastguard Worker            out_compile = compiled_f(sample.input, *sample.args, **sample.kwargs)
7360*da0073e9SAndroid Build Coastguard Worker
7361*da0073e9SAndroid Build Coastguard Worker            if op.full_name in COMPARE_TENSOR_COMPONENT_EQUALITY:
7362*da0073e9SAndroid Build Coastguard Worker                self.assertEqualIgnoringNestedInts(out_compile, out_ref)
7363*da0073e9SAndroid Build Coastguard Worker            else:
7364*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out_compile, out_ref)
7365*da0073e9SAndroid Build Coastguard Worker
7366*da0073e9SAndroid Build Coastguard Worker    @withXFails(BACKWARD_FAILURES)
7367*da0073e9SAndroid Build Coastguard Worker    @ops(
7368*da0073e9SAndroid Build Coastguard Worker        [op for op in njt_op_db if op.supports_njt and op.supports_autograd],
7369*da0073e9SAndroid Build Coastguard Worker        allowed_dtypes=(torch.float32,),
7370*da0073e9SAndroid Build Coastguard Worker    )
7371*da0073e9SAndroid Build Coastguard Worker    @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
7372*da0073e9SAndroid Build Coastguard Worker    def test_compile_backward(self, device, dtype, op):
7373*da0073e9SAndroid Build Coastguard Worker        for sample in op.sample_inputs(device=device, dtype=dtype, requires_grad=True):
7374*da0073e9SAndroid Build Coastguard Worker            torch.compiler.reset()
7375*da0073e9SAndroid Build Coastguard Worker
7376*da0073e9SAndroid Build Coastguard Worker            op_fn = op.op
7377*da0073e9SAndroid Build Coastguard Worker
7378*da0073e9SAndroid Build Coastguard Worker            def f(*args, **kwargs):
7379*da0073e9SAndroid Build Coastguard Worker                return op_fn(*args, **kwargs)
7380*da0073e9SAndroid Build Coastguard Worker
7381*da0073e9SAndroid Build Coastguard Worker            compiled_f = torch.compile(
7382*da0073e9SAndroid Build Coastguard Worker                f, fullgraph=True, backend="aot_eager_decomp_partition"
7383*da0073e9SAndroid Build Coastguard Worker            )
7384*da0073e9SAndroid Build Coastguard Worker
7385*da0073e9SAndroid Build Coastguard Worker            out_ref = f(sample.input, *sample.args, **sample.kwargs)
7386*da0073e9SAndroid Build Coastguard Worker            out_compile = compiled_f(sample.input, *sample.args, **sample.kwargs)
7387*da0073e9SAndroid Build Coastguard Worker
7388*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out_compile, out_ref)
7389*da0073e9SAndroid Build Coastguard Worker
7390*da0073e9SAndroid Build Coastguard Worker            inps, _ = tree_flatten((sample.input, sample.args, sample.kwargs))
7391*da0073e9SAndroid Build Coastguard Worker            g_inps = [
7392*da0073e9SAndroid Build Coastguard Worker                inp
7393*da0073e9SAndroid Build Coastguard Worker                for inp in inps
7394*da0073e9SAndroid Build Coastguard Worker                if isinstance(inp, torch.Tensor) and inp.requires_grad
7395*da0073e9SAndroid Build Coastguard Worker            ]
7396*da0073e9SAndroid Build Coastguard Worker            if len(g_inps) > 0:
7397*da0073e9SAndroid Build Coastguard Worker                grads_compile = torch.autograd.grad(
7398*da0073e9SAndroid Build Coastguard Worker                    out_compile,
7399*da0073e9SAndroid Build Coastguard Worker                    inputs=g_inps,
7400*da0073e9SAndroid Build Coastguard Worker                    grad_outputs=self._gen_grad_outputs(out_compile),
7401*da0073e9SAndroid Build Coastguard Worker                )
7402*da0073e9SAndroid Build Coastguard Worker
7403*da0073e9SAndroid Build Coastguard Worker                grads_ref = torch.autograd.grad(
7404*da0073e9SAndroid Build Coastguard Worker                    out_ref, inputs=g_inps, grad_outputs=self._gen_grad_outputs(out_ref)
7405*da0073e9SAndroid Build Coastguard Worker                )
7406*da0073e9SAndroid Build Coastguard Worker
7407*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(grads_compile, grads_ref)
7408*da0073e9SAndroid Build Coastguard Worker
7409*da0073e9SAndroid Build Coastguard Worker
7410*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestNestedTensor)
7411*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestNestedTensorDeviceType, globals())
7412*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestNestedTensorAutograd, globals())
7413*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestNestedTensorSubclass, globals())
7414*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestNestedTensorOpInfo, globals())
7415*da0073e9SAndroid Build Coastguard Worker
7416*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__":
7417*da0073e9SAndroid Build Coastguard Worker    run_tests()
7418