# Owner(s): ["module: nestedtensor"]

import io
import itertools
import math
import sys
import unittest
from functools import partial
from typing import Optional, Tuple

import numpy as np

import torch
import torch._dynamo
import torch._dynamo.testing
import torch.nn
import torch.nn.functional as F
from torch.nested._internal.nested_tensor import (
    buffer_from_jagged,
    jagged_from_list,
    nested_view_from_values_offsets,
    NestedTensor,
    ViewNestedFromBuffer,
)
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FUSED_ATTENTION,
    SM70OrLater,
    SM80OrLater,
)
from torch.testing._internal.common_device_type import (
    dtypes,
    dtypesIfCUDA,
    instantiate_device_type_tests,
    onlyCPU,
    onlyCUDA,
    ops,
    PYTORCH_CUDA_MEMCHECK,
    skipCPUIf,
    skipCUDAIf,
    skipCUDAIfRocm,
    skipMeta,
)
from torch.testing._internal.common_dtype import floating_types_and_half
from torch.testing._internal.common_utils import (
    decorateIf,
    freeze_rng_state,
    gradcheck,
    instantiate_parametrized_tests,
    IS_FBCODE,
    IS_WINDOWS,
    markDynamoStrictTest,
    NestedTensorTestCase,
    parametrize,
    run_tests,
    skipIfSlowGradcheckEnv,
    skipIfTorchDynamo,
    subtest,
    TEST_WITH_ROCM,
    xfailIfTorchDynamo,
)
from torch.testing._internal.opinfo.definitions.nested import njt_op_db
from torch.utils._pytree import tree_flatten
from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts


# Tests are ported from pytorch/nestedtensor.
# This makes porting as_nested_tensor easier in the future.


def _iter_constructors():
    # yield as_nested_tensor
    yield torch.nested.nested_tensor


# Returns True if the function recompiles between inputs1 and inputs2 with the
# specified dynamic setting.
def _recompiles_for_inputs(fn, inputs1, inputs2, dynamic=True):
    compile_count = [0]

    def counter(gm, example_inputs):
        compile_count[0] += 1
        return gm

    compiled_f = torch.compile(fn, fullgraph=True, backend=counter, dynamic=dynamic)
    compiled_f(*inputs1)
    compiled_f(*inputs2)
    return compile_count[0] > 1
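
# Illustrative sketch (not exercised by the tests in this file): one plausible way to use
# _recompiles_for_inputs to check whether torch.compile guards dynamically on the ragged
# structure of jagged-layout nested tensors. The _example_* name, the shapes, and the
# expectation of no recompile under dynamic=True are assumptions for illustration only.
def _example_recompile_check():
    def fn(nt):
        # simple pointwise computation on a nested tensor
        return nt.sin() + 1

    # Two jagged NTs with different ragged structures but the same trailing dim.
    nt1 = torch.nested.nested_tensor(
        [torch.randn(2, 5), torch.randn(3, 5)], layout=torch.jagged
    )
    nt2 = torch.nested.nested_tensor(
        [torch.randn(4, 5), torch.randn(6, 5)], layout=torch.jagged
    )
    # With dynamic=True we would expect a single dynamic graph, i.e. False here.
    return _recompiles_for_inputs(fn, (nt1,), (nt2,), dynamic=True)
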

# Helper function to generate a pair of random nested tensors:
# one is contiguous, the other is not, but they have the same entries.
# Each output nested tensor consists of
# * `len(ragged_sizes)` matrices
# * matrices[i].shape == (20, ragged_sizes[i])


def random_nt_noncontiguous_pair(ragged_sizes, device="cpu", dtype=torch.float16):
    xs = []
    for size in ragged_sizes:
        xs.append(torch.randn((size, 20), device=device, dtype=dtype))
    # contiguous nested tensor
    ys = []
    for x in xs:
        ys.append(x.transpose(-1, -2))
    nt_contiguous = torch.nested.nested_tensor(ys)
    # noncontiguous nested tensor
    n = len(ragged_sizes)
    nt_noncontiguous = torch.nested.nested_tensor(xs).transpose(-1, -2)
    return nt_contiguous, nt_noncontiguous


# Helper function to pad a noncontiguous nested tensor;
# can be replaced once to_padded_tensor supports noncontiguous memory


def noncontiguous_to_padded_tensor(input, shape=None):
    tensors = input.unbind()
    ntensors = len(tensors)
    assert ntensors > 0
    if shape is None:
        shape = []
        for size in tensors[0].shape:
            shape.append(size)
        for i in range(1, ntensors):
            new_shape = tensors[i].shape
            for j in range(len(shape)):
                shape[j] = max(shape[j], new_shape[j])
        shape = [ntensors] + shape
    result = tensors[0].new_zeros(shape)
    for itensor in range(ntensors):
        tensor = tensors[itensor]
        view = result[itensor]
        for idim in range(tensor.dim()):
            view = view.narrow(idim, 0, tensor.size(idim))
        view.copy_(tensor)
    return result


# Helper function to generate a random nested tensor


def random_nt(
    device,
    dtype,
    num_tensors,
    max_dims,
    min_dims=None,
    layout=torch.strided,
    require_non_empty=True,
):
    if min_dims is None:
        min_dims = tuple([0] * len(max_dims))

    assert len(max_dims) == len(min_dims)
    for min_dim, max_dim in zip(min_dims, max_dims):
        assert max_dim > min_dim, "random_nt: max_dim must be greater than min_dim"
        assert min_dim >= 0, "random_nt: min_dim must be non-negative"
        if require_non_empty:
            assert not (
                min_dim == 0 and max_dim == 1
            ), "random_nt: zero cannot be the only possible value if require_non_empty is True"

    if require_non_empty:
        # Select a random idx that will be required to be non-empty
        non_zero_idx = torch.randint(low=0, high=num_tensors, size=(1,)).item()

    ts1 = []
    for i, _ in enumerate(range(num_tensors)):
        tensor_dims = []
        for min_dim, max_dim in zip(min_dims, max_dims):
            new_min_dim = min_dim
            if require_non_empty and i == non_zero_idx and min_dim == 0:
                new_min_dim = 1
            tensor_dims.append(
                torch.randint(low=new_min_dim, high=max_dim, size=(1,)).item()
            )
        t1 = torch.randn(tensor_dims, device=device, dtype=dtype)
        ts1.append(t1)

    return torch.nested.nested_tensor(ts1, device=device, dtype=dtype, layout=layout)
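
# Illustrative sketch (not part of the original test suite): how random_nt is typically
# used by the tests below. The _example_* name and the specific sizes are arbitrary
# example values, not anything required by the helper itself.
def _example_random_nt_usage():
    # 4 components; each dim is sampled from [0, 4), with one component forced non-empty
    # because require_non_empty defaults to True.
    nt = random_nt(torch.device("cpu"), torch.float32, 4, (4, 4))
    assert nt.is_nested
    # Components can be recovered (and compared against dense references) via unbind().
    return [t.shape for t in nt.unbind()]
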

# Alternate approach to generating a random NT.
# dims should be something like [5, None, 10], with None indicating that a
# random ragged structure should be used
def random_nt_from_dims(
    dims, device=None, dtype=None, layout=torch.strided, requires_grad=False
):
    sizes = [
        [
            d if d is not None else torch.randint(2, 10, size=(1,)).item()
            for d in dims[1:]
        ]
        for d in range(dims[0])
    ]
    return torch.nested.nested_tensor(
        [torch.randn(*size) for size in sizes],
        device=device,
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
    )


# Creates an NT matching another NT's number of components and
# shape / ragged structure for all dims specified to be -1.
def random_nt_from_similar(other, dims=None):
    if dims is None:
        return torch.randn_like(other)
    assert len(dims) == other.dim()
    assert dims[0] == -1 or dims[0] == other.size(0)

    ret_sizes = []
    for t in other.unbind():
        other_size = t.shape
        ret_size = []
        for i, d in enumerate(dims[1:]):
            if d == -1:
                ret_size.append(other_size[i])
            else:
                ret_size.append(d)
        ret_sizes.append(ret_size)

    return torch.nested.nested_tensor(
        [torch.randn(*size) for size in ret_sizes], device=other.device
    )
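
# Illustrative sketch (not part of the original test suite): random_nt_from_dims builds an
# NT with a randomly sized ragged dim wherever None is given, and random_nt_from_similar
# builds another NT that reuses that ragged structure wherever a dim is given as -1.
# The _example_* name and the concrete dims are example values only.
def _example_matching_ragged_structure():
    x = random_nt_from_dims([5, None, 10])
    # Same batch size and ragged structure as x, but trailing dim 8 instead of 10.
    y = random_nt_from_similar(x, dims=[-1, -1, 8])
    # Per-component ragged lengths line up between the two NTs.
    return [(a.shape[0], b.shape[0]) for a, b in zip(x.unbind(), y.unbind())]
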
"torch.jagged" -> "jagged" 232*da0073e9SAndroid Build Coastguard Worker return layout.__repr__().split(".")[-1] 233*da0073e9SAndroid Build Coastguard Worker 234*da0073e9SAndroid Build Coastguard Worker 235*da0073e9SAndroid Build Coastguard Workerdef get_op_name(layout): 236*da0073e9SAndroid Build Coastguard Worker # e.g. "<OpOverload(op='aten.sum', overload='dim_IntList')>" -> "sum" 237*da0073e9SAndroid Build Coastguard Worker return layout.__name__.split(".")[0].split("_")[-1] 238*da0073e9SAndroid Build Coastguard Worker 239*da0073e9SAndroid Build Coastguard Worker 240*da0073e9SAndroid Build Coastguard Worker# Helper function for test_dummy_mha_with_nt 241*da0073e9SAndroid Build Coastguard Worker@torch.fx.wrap 242*da0073e9SAndroid Build Coastguard Workerdef convert_dense_to_nested_tensor_legacy(values): 243*da0073e9SAndroid Build Coastguard Worker offsets = torch.arange( 244*da0073e9SAndroid Build Coastguard Worker 0, values.shape[0] * values.shape[1] + 1, values.shape[1], device=values.device 245*da0073e9SAndroid Build Coastguard Worker ) 246*da0073e9SAndroid Build Coastguard Worker metadata_cache = {"max_seqlen": values.shape[1], "min_seqlen": 1} 247*da0073e9SAndroid Build Coastguard Worker nt = ViewNestedFromBuffer.apply( 248*da0073e9SAndroid Build Coastguard Worker values.view(-1, values.shape[-1]), offsets, metadata_cache 249*da0073e9SAndroid Build Coastguard Worker ) 250*da0073e9SAndroid Build Coastguard Worker return nt 251*da0073e9SAndroid Build Coastguard Worker 252*da0073e9SAndroid Build Coastguard Worker 253*da0073e9SAndroid Build Coastguard Worker# Helper function for test_dummy_mha_with_nt 254*da0073e9SAndroid Build Coastguard Worker@torch.fx.wrap 255*da0073e9SAndroid Build Coastguard Workerdef convert_jagged_to_nested_tensor_legacy( 256*da0073e9SAndroid Build Coastguard Worker values: torch.Tensor, offsets: torch.Tensor, max_length: int 257*da0073e9SAndroid Build Coastguard Worker) -> torch.Tensor: 258*da0073e9SAndroid Build Coastguard Worker metadata_cache = {"max_seqlen": max_length, "min_seqlen": 1} 259*da0073e9SAndroid Build Coastguard Worker nt = ViewNestedFromBuffer.apply(values, offsets, metadata_cache) 260*da0073e9SAndroid Build Coastguard Worker return nt 261*da0073e9SAndroid Build Coastguard Worker 262*da0073e9SAndroid Build Coastguard Worker 263*da0073e9SAndroid Build Coastguard Worker# Helper function for test_dummy_mha_with_nt 264*da0073e9SAndroid Build Coastguard Worker@torch.fx.wrap 265*da0073e9SAndroid Build Coastguard Workerdef convert_nt_to_jagged_legacy(nt): 266*da0073e9SAndroid Build Coastguard Worker return buffer_from_jagged(nt) 267*da0073e9SAndroid Build Coastguard Worker 268*da0073e9SAndroid Build Coastguard Worker 269*da0073e9SAndroid Build Coastguard Worker# Helper function for test_dummy_mha_with_nt 270*da0073e9SAndroid Build Coastguard Worker@torch.fx.wrap 271*da0073e9SAndroid Build Coastguard Workerdef convert_dense_to_nested_tensor(values): 272*da0073e9SAndroid Build Coastguard Worker nt = torch.nested.as_nested_tensor(values, layout=torch.jagged) 273*da0073e9SAndroid Build Coastguard Worker return nt 274*da0073e9SAndroid Build Coastguard Worker 275*da0073e9SAndroid Build Coastguard Worker 276*da0073e9SAndroid Build Coastguard Worker# Helper function for test_dummy_mha_with_nt 277*da0073e9SAndroid Build Coastguard Worker@torch.fx.wrap 278*da0073e9SAndroid Build Coastguard Workerdef convert_jagged_to_nested_tensor( 279*da0073e9SAndroid Build Coastguard Worker values: torch.Tensor, offsets: torch.Tensor, max_length: int 280*da0073e9SAndroid Build 

@markDynamoStrictTest
class TestNestedTensor(NestedTensorTestCase):
    @parametrize("batch_size", [2, 4])
    @parametrize("max_seq_len", [3, 5])
    @parametrize("vocab_size", [10, 20])
    def test_2d_nested_tensor(self, batch_size, max_seq_len, vocab_size):
        data = []
        nested_tensor_ref_list = []
        for _ in range(batch_size):
            if max_seq_len == 0:
                length = 0
            else:
                length = np.random.randint(low=1, high=max_seq_len)
            row = list(np.random.randint(low=0, high=vocab_size, size=(length,)))
            data.append(row)
            nested_tensor_ref_list.append(torch.Tensor(row))
        nested_tensor = torch.nested.nested_tensor(data, dtype=torch.int64)
        nested_tensor_list = nested_tensor.unbind()
        for id in range(batch_size):
            self.assertEqual(
                nested_tensor_list[id], nested_tensor_ref_list[id].type(torch.int64)
            )

    @parametrize("batch_size", [2, 4])
    @parametrize("max_seq_len", [3, 5])
    @parametrize("vocab_size", [10, 20])
    def test_3d_nested_tensor(self, batch_size, max_seq_len, vocab_size):
        data = []
        nested_tensor_ref_list = []
        for _ in range(batch_size):
            if max_seq_len == 0:
                length = 0
            else:
                length = np.random.randint(low=1, high=max_seq_len)
            row = list(np.random.randint(low=0, high=vocab_size, size=(length,)))
            row = [list(item * np.arange(max_seq_len)) for item in row]
            data.append(row)
            nested_tensor_ref_list.append(torch.Tensor(row))
        nested_tensor = torch.nested.nested_tensor(data, dtype=torch.int64)
        nested_tensor_list = nested_tensor.unbind()
        for id in range(batch_size):
            self.assertEqual(
                nested_tensor_list[id], nested_tensor_ref_list[id].type(torch.int64)
            )

    @parametrize("batch_size", [2, 4])
    @parametrize("max_seq_len", [3, 5])
    @parametrize("vocab_size", [10, 20])
    def test_3d_nested_tensor_float(self, batch_size, max_seq_len, vocab_size):
        data = []
        nested_tensor_ref_list = []
        for _ in range(batch_size):
            if max_seq_len == 0:
                length = 0
            else:
                length = np.random.randint(low=1, high=max_seq_len)
            row = list(
                np.random.randint(low=0, high=vocab_size, size=(length,)).astype(float)
            )
            row = [list(item * np.arange(max_seq_len)) for item in row]
            data.append(row)
            nested_tensor_ref_list.append(torch.Tensor(row))
        nested_tensor = torch.nested.nested_tensor(data, dtype=torch.float)
        nested_tensor_list = nested_tensor.unbind()
        for id in range(batch_size):
            self.assertEqual(
                nested_tensor_list[id], nested_tensor_ref_list[id].type(torch.float)
            )

    @torch.inference_mode()
    def _test_unbind_case(self, a, b):
        nt = torch.nested.nested_tensor([a, b])
        a1, b1 = nt.unbind()
        self.assertTrue(a is not a1)
        self.assertTrue(b is not b1)

        nt = torch.nested.nested_tensor([a, b], dtype=a.dtype)
        a1, b1 = nt.unbind(0)
        self.assertEqual(a, a1)
        self.assertEqual(b, b1)

        a = torch.randn((2, 3)).add_(1)
        nt = torch.nested.nested_tensor([a])
        self.assertEqual(a, nt.unbind(0)[0])

    @torch.inference_mode()
    def test_unbind_0(self):
        self._test_unbind_case(torch.tensor([1, 2]), torch.tensor([7, 8]))

    @torch.inference_mode()
    def test_unbind_1(self):
        self._test_unbind_case(torch.tensor([1]), torch.tensor([7]))

    @torch.inference_mode()
    def test_unbind_3(self):
        self._test_unbind_case(torch.tensor([1.0]), torch.tensor([]))

    @torch.inference_mode()
    def test_unbind_4(self):
        self._test_unbind_case(torch.tensor([]), torch.tensor([]))

    @torch.inference_mode()
    def test_unbind_dim(self):
        def _test_fn(unbind_fn):
            a = torch.rand(3, 2)
            b = torch.rand(2, 3)
            nt = torch.nested.nested_tensor([a, b])
            self.assertRaises(RuntimeError, lambda: unbind_fn(nt, 1))

        # Both of these tests are necessary, because we're using
        # torch_function.
        _test_fn(lambda x, dim: x.unbind(dim))
        # TODO: Re-enable this once using torch_dispatch
        # _test_fn(lambda x, dim: torch.unbind(x, dim))

    @torch.inference_mode()
    def test_nested_tensor(self):
        self.assertRaises(
            TypeError, lambda: torch.nested.nested_tensor(torch.tensor([3.0]))
        )
        self.assertRaises(TypeError, lambda: torch.nested.nested_tensor(4.0))

    @torch.inference_mode()
    def test_nested_tensor_matching_dim(self):
        self.assertRaisesRegex(
            RuntimeError,
            "Found dimension 1 for Tensor at index 1 and dimension 0 for Tensor at index 0.",
            lambda: torch.nested.nested_tensor([torch.tensor(1.0), torch.tensor([])]),
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Found dimension 1 for Tensor at index 2 and dimension 0 for Tensor at index 1.",
            lambda: torch.nested.nested_tensor(
                [torch.tensor(1.0), torch.tensor(2.0), torch.tensor([])]
            ),
        )

    @torch.inference_mode()
    def test_default_nested_tensor(self):
        self.assertRaises(TypeError, lambda: torch.nested.nested_tensor())
        default_nested_tensor = torch.nested.nested_tensor([])
        default_tensor = torch.tensor([])
        # self.assertEqual(default_nested_tensor.nested_dim(), 1)
        # self.assertEqual(default_nested_tensor.nested_size(), ())
        self.assertEqual(default_nested_tensor.dim(), default_tensor.dim())
        self.assertEqual(default_nested_tensor.layout, default_tensor.layout)
        self.assertEqual(default_nested_tensor.device, default_tensor.device)
        self.assertEqual(default_nested_tensor.dtype, default_tensor.dtype)
        self.assertEqual(
            default_nested_tensor.requires_grad, default_tensor.requires_grad
        )
        self.assertIsNone(default_tensor.grad)
        # TODO: Re-enable once we have a performance driven
        # use case and implementation.
        # self.assertEqual(default_nested_tensor.is_pinned(),
        #                  default_tensor.is_pinned())

    @torch.inference_mode()
    def test_dim(self):
        for constructor in _iter_constructors():
            a1 = constructor([])
            self.assertEqual(a1.dim(), 1)
            a1 = constructor([torch.tensor(3.0)])
            self.assertEqual(a1.dim(), 1)
            a1 = constructor([torch.tensor([1, 2, 3, 4])])
            self.assertEqual(a1.dim(), 2)

    @unittest.skipIf(IS_FBCODE, "numel is not virtual in fbcode.")
    @torch.inference_mode()
    def test_numel(self):
        for constructor in _iter_constructors():
            a1 = constructor([])
            self.assertEqual(a1.numel(), 0)
            a1 = constructor([torch.tensor(3.0), torch.tensor(4.0)])
            self.assertEqual(a1.numel(), 2)
            a1 = constructor([torch.randn(2, 2, 2)])
            self.assertEqual(a1.numel(), 8)
            a1 = constructor([torch.randn([1, 2, 3]), torch.randn(3, 2, 1)])
            self.assertEqual(a1.numel(), 12)
            a1 = constructor([torch.randn([1, 1, 3]), torch.randn(3, 2, 4)])
            self.assertEqual(a1.numel(), 27)
            a1 = constructor([torch.randn([5, 5, 5]), torch.randn(6, 6, 6)])
            self.assertEqual(a1.numel(), 341)

            # Interesting edge case
            a1 = constructor([torch.randn([1, 2, 3]), torch.randn(1, 2, 0)])
            self.assertEqual(a1.numel(), 6)

    @torch.inference_mode()
    def test_size(self):
        for constructor in _iter_constructors():
            a1 = constructor([])
            self.assertRaisesRegex(
                RuntimeError,
                "NestedTensorImpl doesn't support sizes",
                lambda: a1.size(),
            )

    def test_size_dim(self):
        a = torch.nested.nested_tensor([])
        self.assertEqual(a.size(0), 0)

        a = torch.nested.nested_tensor([torch.tensor(1)])
        self.assertEqual(a.size(0), 1)

        a = torch.nested.nested_tensor([torch.tensor(1), torch.tensor(2)])
        self.assertEqual(a.size(0), 2)

        a = torch.nested.nested_tensor([torch.rand(1, 2), torch.rand(1, 8)])
        self.assertEqual(a.size(0), 2)
        self.assertEqual(a.size(1), 1)
        self.assertRaisesRegex(
            RuntimeError,
            "Given dimension 2 is irregular and does not have a size",
            lambda: a.size(2),
        )

        a = torch.nested.nested_tensor([torch.rand(3, 4), torch.rand(5, 4)])
        self.assertEqual(a.size(0), 2)
        self.assertRaisesRegex(
            RuntimeError,
            "Given dimension 1 is irregular and does not have a size",
            lambda: a.size(1),
        )
        self.assertEqual(a.size(2), 4)

    @unittest.skipIf(IS_FBCODE, "stride is not virtual in fbcode.")
    @torch.inference_mode()
    def test_stride(self):
        for constructor in _iter_constructors():
            a1 = constructor([])
            self.assertRaisesRegex(
                RuntimeError,
                "NestedTensorImpl doesn't support strides",
                lambda: a1.stride(),
            )

    @unittest.skipIf(IS_FBCODE, "is_contiguous is not virtual in fbcode.")
    @torch.inference_mode()
    def test_is_contiguous(self):
        # Test empty case
        nt_empty = torch.nested.nested_tensor([])
        assert nt_empty.is_contiguous()
        self.assertEqual(nt_empty, nt_empty.contiguous())

        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7))

        # Test contiguous case
        assert nt_contiguous.is_contiguous()
        self.assertEqual(nt_contiguous, nt_contiguous.contiguous())

        # Test non_contiguous case
        assert not nt_noncontiguous.is_contiguous()
        self.assertEqual(nt_contiguous, nt_noncontiguous.contiguous())

        # Test querying by memory_format
        self.assertTrue(
            nt_contiguous.is_contiguous(memory_format=torch.contiguous_format)
        )
        self.assertTrue(
            not nt_noncontiguous.is_contiguous(memory_format=torch.contiguous_format)
        )

    @torch.inference_mode()
    def test_repr_string(self):
        a = torch.nested.nested_tensor([])
        expected = "nested_tensor([\n\n])"
        self.assertEqual(str(a), expected)
        self.assertEqual(repr(a), expected)

        a = torch.nested.nested_tensor([torch.tensor(1.0)])
        expected = "nested_tensor([\n  tensor(1.)\n])"
        self.assertEqual(str(a), expected)
        self.assertEqual(repr(a), expected)

        a = torch.nested.nested_tensor([torch.tensor([[1, 2]]), torch.tensor([[4, 5]])])
        expected = "nested_tensor([\n  tensor([[1, 2]]),\n  tensor([[4, 5]])\n])"
        self.assertEqual(str(a), expected)
        self.assertEqual(repr(a), expected)

    def test_to_padded_tensor_on_empty_tensor(self):
        nt = torch.nested.nested_tensor([])
        empty = torch.nested.to_padded_tensor(nt, 4)
        self.assertEqual(empty, torch.tensor([]))

    def test_nested_namespace(self):
        nt = torch.nested.nested_tensor([torch.randn(2, 3), torch.randn(4, 5)])
        result = nt.to_padded_tensor(4)
        nested_namespace_result = torch.nested.to_padded_tensor(nt, 4)
        self.assertEqual(result, nested_namespace_result)

    def test_to(self):
        ntensors = 4
        nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4))

        def test_copy_behavior(t, non_blocking=False):
            self.assertIs(t, t.to(t, non_blocking=non_blocking))
            self.assertIs(t, t.to(t.dtype, non_blocking=non_blocking))
            self.assertIs(t, t.to(torch.empty_like(t), non_blocking=non_blocking))
            self.assertIsNot(t, t.to(t, non_blocking=non_blocking, copy=True))
            self.assertIsNot(t, t.to(t.dtype, non_blocking=non_blocking, copy=True))
            self.assertIsNot(
                t, t.to(torch.empty_like(t), non_blocking=non_blocking, copy=True)
            )

            devices = [t.device]
            if t.device.type == "cuda":
                if t.device.index == -1:
                    devices.append(f"cuda:{torch.cuda.current_device()}")
                elif t.device.index == torch.cuda.current_device():
                    devices.append("cuda")
            for device in devices:
                self.assertIs(t, t.to(device, non_blocking=non_blocking))
                self.assertIs(t, t.to(device, t.dtype, non_blocking=non_blocking))
                self.assertIsNot(t, t.to(device, non_blocking=non_blocking, copy=True))
                self.assertIsNot(
                    t, t.to(device, t.dtype, non_blocking=non_blocking, copy=True)
                )

        test_copy_behavior(nt)
        self.assertEqual(nt.device, nt.to("cpu").device)
        self.assertEqual(nt.device, nt.to("cpu", dtype=torch.float32).device)
        self.assertIs(torch.float32, nt.to("cpu", dtype=torch.float32).dtype)
        self.assertEqual(nt.device, nt.to(torch.float32).device)
        self.assertIs(torch.float32, nt.to(dtype=torch.float32).dtype)

        def test_data_ptr(getter):
            self.assertEqual(getter(nt), getter(nt.to("cpu")))
            self.assertEqual(
                getter(nt), getter(nt.to(dtype=nt.dtype, device=nt.device, copy=False))
            )
            self.assertEqual(getter(nt), getter(nt.to("cpu", copy=False)))
            self.assertNotEqual(getter(nt), getter(nt.to("cpu", copy=True)))

        test_data_ptr(lambda nt: nt.data_ptr())

        if torch.cuda.is_available():
            for non_blocking in [True, False]:
                for cuda in [
                    "cuda",
                    "cuda:0" if torch.cuda.device_count() == 1 else "cuda:1",
                ]:
                    nt2 = random_nt(cuda, torch.float32, ntensors, (4, 4))
                    test_copy_behavior(nt2, non_blocking)
                    self.assertEqual(
                        nt2.device, nt2.to(cuda, non_blocking=non_blocking).device
                    )
                    self.assertEqual(
                        nt.device, nt2.to("cpu", non_blocking=non_blocking).device
                    )
                    self.assertEqual(
                        nt2.device, nt.to(cuda, non_blocking=non_blocking).device
                    )
                    self.assertIs(
                        torch.int32,
                        nt2.to(
                            "cpu", dtype=torch.int32, non_blocking=non_blocking
                        ).dtype,
                    )
                    self.assertEqual(
                        nt.device,
                        nt2.to(
                            "cpu", dtype=torch.int32, non_blocking=non_blocking
                        ).device,
                    )
                    self.assertIs(torch.int32, nt2.to(dtype=torch.int32).dtype)
                    self.assertEqual(nt2.device, nt2.to(dtype=torch.int32).device)

    def test_copy_(self):
        ntensors = 4
        nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4))
        nt_copy = torch.empty_like(nt)
        nt_copy.copy_(nt)

        for nt_ub, nt_copy_ub in zip(nt.unbind(), nt_copy):
            self.assertEqual(nt_ub, nt_copy_ub)

        nt_error = torch.nested.nested_tensor([torch.tensor([0, 0])])
        self.assertRaisesRegex(
            RuntimeError,
            "copy_ only supports tensors that are the same size for Nested implementations",
            lambda: nt_error.copy_(nt),
        )

        if torch.cuda.is_available():
            nt = random_nt(torch.device("cuda"), torch.float32, ntensors, (4, 4))
            nt_copy = torch.empty_like(nt, device=torch.device("cpu"))
            nt_copy.copy_(nt, non_blocking=True)
            torch.cuda.current_stream(torch.cuda.current_device()).synchronize()
            for nt_ub, nt_copy_ub in zip(nt.unbind(), nt_copy):
                self.assertEqual(nt_ub, nt_copy_ub)

            nt_copy = torch.empty_like(nt, device=torch.device("cpu"))
            nt_copy.copy_(nt, non_blocking=False)
            for nt_ub, nt_copy_ub in zip(nt.unbind(), nt_copy):
                self.assertEqual(nt_ub, nt_copy_ub)

    def test_fill_(self):
        ntensors = 4
        nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4))
        nt.fill_(10.0)
        for nt_ub in nt.unbind():
            t = torch.empty_like(nt_ub)
            t.fill_(10.0)
            self.assertEqual(nt_ub, t)

        fill_tensor = torch.tensor([11.0])
        self.assertRaisesRegex(
            RuntimeError,
            "fill_ only supports 0-dimension value tensor",
            lambda: nt.fill_(fill_tensor),
        )

        nt.fill_(fill_tensor[0])
        for nt_ub in nt.unbind():
            t = torch.empty_like(nt_ub)
            t.fill_(11.0)
            self.assertEqual(nt_ub, t)

    def test_zero_(self):
        ntensors = 4
        nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4))
        nt.zero_()
        for nt_ub in nt.unbind():
            t = torch.empty_like(nt_ub)
            t.fill_(0.0)
            self.assertEqual(nt_ub, t)

    @parametrize(
        "func",
        [torch.ones_like, torch.zeros_like, torch.randn_like],
        name_fn=lambda f: f.__name__,
    )
    def test_like_functions(self, func):
        ntensors = 4
        nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4))
        torch.manual_seed(1)
        nt_like = func(nt)

        torch.manual_seed(1)
        for nt_ub in nt_like.unbind():
            t_like = func(nt_ub)
            self.assertEqual(nt_ub, t_like)

    def test_cat(self):
        # dim=0 success case
        # No constraints on ragged structures matching.
        x = random_nt_from_dims([5, None, 10])
        y = random_nt_from_dims([3, 4, None])
        output = torch.cat([x, y], dim=0)
        for out_component, xy_component in zip(
            output.unbind(), itertools.chain(x.unbind(), y.unbind())
        ):
            self.assertEqual(out_component, xy_component)

        # dim=-1 success case
        # shape (B, *, D)
        x = random_nt_from_dims([5, None, 10])
        # shape (B, *, D'); same structure as x but dim=-1 differs
        y = random_nt_from_similar(x, dims=[-1, -1, 8])
        # should be shape (B, *, D + D') when supported
        output = torch.cat([x, y], dim=-1)
        for out_component, x_component, y_component in zip(
            output.unbind(), x.unbind(), y.unbind()
        ):
            self.assertEqual(
                out_component, torch.cat([x_component, y_component], dim=-1)
            )

        # dim between 0 and -1 success case
        x = random_nt_from_dims([5, None, 2, 3])
        # same structure as x but dim=2 differs
        y = random_nt_from_similar(x, dims=[-1, -1, 4, -1])
        output = torch.cat([x, y], dim=2)
        for out_component, x_component, y_component in zip(
            output.unbind(), x.unbind(), y.unbind()
        ):
            self.assertEqual(
                out_component, torch.cat([x_component, y_component], dim=1)
            )

        # error case: mixed NT / dense inputs
        # error case: mixed NT / dense inputs
        x = random_nt_from_dims([5, None, 2])
        y = torch.randn(5, 3, 2)
        with self.assertRaisesRegex(
            RuntimeError, "expected each tensor in given list to be nested"
        ):
            torch.cat([x, y], dim=-1)

        # error case: NTs with different dims
        x = random_nt_from_dims([5, None, 2])
        y = random_nt_from_dims([5, None, 2, 3])
        with self.assertRaisesRegex(
            RuntimeError,
            "expected all nested tensors to have matching ragged structures outside of the concatenated dim",
        ):
            torch.cat([x, y], dim=-1)

        # error case: non-contiguous NT
        x, y = random_nt_noncontiguous_pair((2, 3, 4), dtype=torch.float32)
        # transpose to put ragged dim next to batch dim
        x, y = x.transpose(-2, -1), y.transpose(-2, -1)
        with self.assertRaisesRegex(
            RuntimeError, "only contiguous nested tensors are supported"
        ):
            torch.cat([x, y], dim=-1)

        # error case: multiple ragged dims in inputs
        x = random_nt_from_dims([5, None, None, 2])
        y = random_nt_from_similar(x)
        with self.assertRaisesRegex(
            RuntimeError,
            "only nested tensors with a single ragged dim next to the batch dim are supported",
        ):
            torch.cat([x, y], dim=-1)

        # error case: ragged dim not next to batch dim
        x = random_nt_from_dims([5, 2, None])
        y = random_nt_from_similar(x)
        with self.assertRaisesRegex(
            RuntimeError,
            "only nested tensors with a single ragged dim next to the batch dim are supported",
        ):
            torch.cat([x, y], dim=1)

        # error case: NTs with different batch sizes
        x = random_nt_from_dims([5, None, 2])
        y = random_nt_from_dims([3, None, 2])
        with self.assertRaisesRegex(
            RuntimeError,
            "expected all nested tensors to have matching ragged structures outside of the concatenated dim",
        ):
            torch.cat([x, y], dim=-1)

        # error case: NTs with different ragged structures
        x = torch.nested.nested_tensor(
            [
                torch.randn(2, 6),
                torch.randn(4, 6),
                torch.randn(5, 6),
            ]
        )
        y = torch.nested.nested_tensor(
            [
                torch.randn(5, 6),
                torch.randn(4, 6),
                torch.randn(2, 6),
            ]
        )
        with self.assertRaisesRegex(
            RuntimeError,
            "expected all nested tensors to have matching ragged structures outside of the concatenated dim",
        ):
            torch.cat([x, y], dim=-1)


@markDynamoStrictTest
class TestNestedTensorDeviceType(NestedTensorTestCase):
    # Helper function to generate a pair of random nested tensors
    # whose corresponding components have the same shapes
    def random_nt_pair(self, device, dtype, num_tensors, max_dims):
        ts1 = []
        ts2 = []
        for _ in range(num_tensors):
            tensor_dims = tuple(
                [
                    torch.randint(low=0, high=max_dim, size=(1,)).item()
                    for max_dim in max_dims
                ]
            )
            t1 = torch.randn(tensor_dims, device=device, dtype=dtype)
            t2 = torch.randn(tensor_dims, device=device, dtype=dtype)
            ts1.append(t1)
            ts2.append(t2)
        return (
            torch.nested.nested_tensor(ts1, device=device, dtype=dtype),
            torch.nested.nested_tensor(ts2, device=device, dtype=dtype),
        )
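    # The tests in this class take (device, dtype) parameters supplied by the
    # device-type test framework; decorators such as @dtypes / @dtypesIfCUDA and
    # @skipMeta below control which combinations each test actually runs with.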
    @dtypes(*floating_types_and_half())
    def test_detach(self, device, dtype):
        a = torch.randn(2, 4, device=device, dtype=dtype, requires_grad=False)
        b = torch.randn(5, 4, device=device, dtype=dtype, requires_grad=False)
        x = torch.nested.nested_tensor([a, b], requires_grad=True)

        x_detach = x.detach()

        z = x_detach * 4
        self.assertFalse(x_detach.requires_grad)
        self.assertFalse(z.requires_grad)

        a = torch.randn(2, 4, device=device, dtype=dtype, requires_grad=True)
        b = torch.randn(5, 4, device=device, dtype=dtype, requires_grad=True)
        x = torch.nested.as_nested_tensor([a, b])

        y = x * 2
        y = y.detach()
        self.assertFalse(y.requires_grad)
        self.assertIsNone(y.grad_fn)
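        # The rest of this test adds the detached tensor back to x and backprops
        # through the padded sum; since y is detached it is treated as a constant,
        # so a.grad and b.grad come out as all-ones below.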
        z = x + y
        torch.nested.to_padded_tensor(z, 0).sum().backward()
        # This is an incorrect gradient, but we assume that's what the user
        # wanted. detach() is an advanced option.
        self.assertEqual(a.grad, torch.ones(2, 4, device=device, dtype=dtype))
        self.assertEqual(b.grad, torch.ones(5, 4, device=device, dtype=dtype))

    @dtypes(torch.float, torch.float16, torch.double)
    def test_unbind_noncontiguous(self, device, dtype):
        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
            (2, 3, 6, 7), device, dtype
        )
        ub_contiguous = nt_contiguous.unbind()
        ub_noncontiguous = nt_noncontiguous.unbind()
        self.assertEqual(len(ub_contiguous), len(ub_noncontiguous))
        n = len(ub_contiguous)
        for i in range(n):
            self.assertEqual(ub_contiguous[i], ub_noncontiguous[i])

    @dtypes(torch.float)
    @skipMeta
    def test_to_then_from_padded_tensor_no_transform0213(self, device, dtype):
        t = torch.randn(4, 4, 4, device=device, dtype=dtype)
        ts = list(torch.unbind(t))
        ts[0] = ts[0][:-1]
        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
        padded = torch.nested.to_padded_tensor(nt, 0)

        nt_to = torch._nested_from_padded_and_nested_example(padded, nt)

        for t1, t2 in zip(nt.unbind(), nt_to.unbind()):
            self.assertEqual(t1, t2)
        self.assertEqual(nt.device, nt_to.device)

    @dtypes(torch.float)
    @dtypesIfCUDA(torch.float, torch.half)
    @skipMeta
    @torch.inference_mode()
    def test_layer_norm(self, device, dtype):
        def _test(size):
            # Simple shapes test
            t0 = torch.randn(2, size, device=device, dtype=dtype, requires_grad=False)
            t1 = torch.randn(2, size, device=device, dtype=dtype, requires_grad=False)
            ts = [t0, t1, t0, t1]
            nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
            layer_norm = torch.nn.LayerNorm(size, device=device, dtype=dtype)
            nt_result = layer_norm(nt)
            for nt_subresult, t in zip(nt_result.unbind(), ts):
                t_result = layer_norm(t.reshape(1, -1, size).squeeze(0))
                self.assertEqual(nt_subresult, t_result)

            # More complex nt test with different lengths for each tensor
            t0 = torch.randn(4, size, device=device, dtype=dtype, requires_grad=False)
            t1 = torch.randn(10, size, device=device, dtype=dtype, requires_grad=False)
            t2 = torch.randn(7, size, device=device, dtype=dtype, requires_grad=False)
            ts = [t0, t1, t2, t0, t2]
            nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
            layer_norm = torch.nn.LayerNorm(size, device=device, dtype=dtype)
            nt_result = layer_norm(nt)
            for nt_subresult, t in zip(nt_result.unbind(), ts):
                t_result = layer_norm(t.reshape(1, -1, size).squeeze(0))
                self.assertEqual(nt_subresult, t_result)

            if size <= 128:
                # Test with multidimensional tensors after irregular dim
                # (run only with smaller dimensions to ensure fast execution)
                t0 = torch.randn(
                    4, size, size, 4, device=device, dtype=dtype, requires_grad=False
                )
                t1 = torch.randn(
                    10, size, size, 4, device=device, dtype=dtype, requires_grad=False
                )
                t2 = torch.randn(
                    7, size, size, 4, device=device, dtype=dtype, requires_grad=False
                )
                ts = [t0, t1, t2, t0, t2]
                nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
                layer_norm = torch.nn.LayerNorm(
                    (size, size, 4), device=device, dtype=dtype
                )
                nt_result = layer_norm(nt)
                for nt_subresult, t in zip(nt_result.unbind(), ts):
                    t_result = layer_norm(t.reshape(1, -1, size, size, 4).squeeze(0))
                    self.assertEqual(nt_subresult, t_result)

                # Test where the normalized shape covers only the trailing (size, 4) dims
                layer_norm = torch.nn.LayerNorm((size, 4), device=device, dtype=dtype)
                nt_result = layer_norm(nt)
                for nt_subresult, t in zip(nt_result.unbind(), ts):
                    t_result = layer_norm(t.reshape(1, -1, size, size, 4).squeeze(0))
                    self.assertEqual(nt_subresult, t_result)

        for size in (1024, 1023, 513, 512, 256, 128, 2, 4, 32):
            _test(size)

    @dtypes(torch.float)
    @dtypesIfCUDA(torch.float, torch.half)
    @skipMeta
    @torch.inference_mode()
    def test_layer_norm_breaking(self, device, dtype):
        size = 128
        t0 = torch.randn(
            4, size, size, 4, device=device, dtype=dtype, requires_grad=False
        )
        t1 = torch.randn(
            10, size, size, 4, device=device, dtype=dtype, requires_grad=False
        )
        t2 = torch.randn(
            7, size, size, 4, device=device, dtype=dtype, requires_grad=False
        )
        ts = [t0, t1, t2, t0, t2]
        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
        layer_norm = torch.nn.LayerNorm((4, size, size, 4), device=device, dtype=dtype)
        self.assertRaisesRegex(
            RuntimeError,
            "normalized_shape extends into irregular dimensions for the nested tensor",
            lambda: layer_norm(nt),
        )
        layer_norm = torch.nn.LayerNorm((size + 1, size, 4), device=device, dtype=dtype)
        self.assertRaisesRegex(
            RuntimeError,
            "The shape at dimension 0",
            lambda: layer_norm(nt),
        )
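    # Summary of the two layer_norm tests above: LayerNorm is applied per component,
    # and normalized_shape may only cover the regular trailing dimensions of a nested
    # tensor; shapes that extend into the irregular (ragged) dimension, or that do not
    # match the trailing sizes, are rejected with a RuntimeError.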
    @parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name)
    def test_embedding(self, device, layout):
        inputs = [
            torch.randint(100, (L,), device=device, dtype=torch.int64)
            for L in torch.randint(5, 50, (8,))
        ]
        x = torch.nested.nested_tensor(
            inputs, device=device, dtype=torch.int64, layout=layout
        )
        emb = torch.nn.Embedding(100, 8, device=device)
        y = emb(x)

        @torch._dynamo.disable
        def check(inputs, y):
            ys = y.unbind()
            for i, inp in enumerate(inputs):
                self.assertEqual(emb(inp), ys[i])

        check(inputs, y)

    @skipMeta
    @torch.inference_mode()
    @dtypes(*floating_types_and_half())
    def test_masked_fill(self, device, dtype):
        # nested tensor masked by a nested boolean mask
        (nt, mask) = self.random_nt_pair(device, dtype, 4, (4, 4))
        mask = torch.nested.nested_tensor([m < 0 for m in mask.unbind()])
        ref = torch.nested.nested_tensor(
            [t.masked_fill(m, 0) for (t, m) in zip(nt.unbind(), mask.unbind())]
        )
        out = nt.masked_fill(mask, 0)
        self.assertEqual(ref, out)

    @dtypes(torch.float, torch.float16)
    def test_to_padded_tensor_simple(self, device, dtype):
        t = torch.randn(4, 4, 4, device=device, dtype=dtype)
        ts = list(torch.unbind(t))
        ts[0] = ts[0][:-1]
        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
        for padding_value in (0, 1):
            padded = torch.nested.to_padded_tensor(nt, padding_value)

            correct_output = t.clone()
            if padding_value == 0:
                correct_output[0][-1] = torch.zeros_like(correct_output[0][-1])
            else:
                correct_output[0][-1] = torch.ones_like(correct_output[0][-1])

            self.assertEqual(padded, correct_output)
            self.assertEqual(padded.device, torch.device(device))
            self.assertEqual(padded.dtype, dtype)
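    # The to_padded_tensor tests below verify that each component is copied into the
    # leading corner of a dense tensor whose trailing sizes are the per-dimension
    # maxima over components (or an explicit output_size), with all remaining entries
    # set to the padding value.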
    @dtypes(torch.float, torch.float16)
    def test_to_padded_tensor_output_size(self, device, dtype):
        t = torch.randn(4, 4, 4, device=device, dtype=dtype)
        output_size = (4, 6, 5)
        ts = list(torch.unbind(t))
        ts[0] = ts[0][:-1]
        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
        for padding_value in (0, 1):
            padded = torch.nested.to_padded_tensor(
                nt, padding_value, output_size=output_size
            )
            correct_output = (
                torch.ones(output_size, device=device, dtype=dtype) * padding_value
            )
            correct_output[:4, :4, :4] = t.clone()
            if padding_value == 0:
                correct_output[0][3] = torch.zeros_like(correct_output[0][3])
            else:
                correct_output[0][3] = torch.ones_like(correct_output[0][3])

            self.assertEqual(padded, correct_output)
            self.assertEqual(padded.device, torch.device(device))
            self.assertEqual(padded.dtype, dtype)

    @dtypes(torch.float, torch.float16, torch.double)
    def test_to_padded_tensor_dim2(self, device, dtype):
        ts = [
            torch.randn(160, device=device, dtype=dtype),
            torch.randn(1240, device=device, dtype=dtype),
            torch.randn(2400, device=device, dtype=dtype),
        ]
        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
        pad = 42
        correct_output = []
        for t in ts:
            next_output = torch.ones_like(ts[2]) * pad
            correct_output.append(next_output)
            next_output[: t.size(0)].copy_(t)
        correct_output = torch.stack(correct_output)
        padded = torch.nested.to_padded_tensor(nt, pad)
        self.assertEqual(padded, correct_output)

    @dtypes(torch.float, torch.float16, torch.double)
    def test_to_padded_tensor_dim3(self, device, dtype):
        ts = [
            torch.randn(16, 21, device=device, dtype=dtype),
            torch.randn(24, 32, device=device, dtype=dtype),
            torch.randn(40, 53, device=device, dtype=dtype),
        ]
        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
        pad = 42
        correct_output = []
        for t in ts:
            next_output = torch.ones_like(ts[2]) * pad
            correct_output.append(next_output)
            next_output[: t.size(0), : t.size(1)].copy_(t)
        correct_output = torch.stack(correct_output)
        padded = torch.nested.to_padded_tensor(nt, pad)
        self.assertEqual(padded, correct_output)

    @dtypes(torch.float, torch.float16, torch.double)
    def test_to_padded_tensor_dim4(self, device, dtype):
        ts = [
            torch.randn(16, 21, 13, device=device, dtype=dtype),
            torch.randn(24, 32, 14, device=device, dtype=dtype),
            torch.randn(40, 53, 16, device=device, dtype=dtype),
        ]
        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
        pad = 42
        correct_output = []
        for t in ts:
            next_output = torch.ones_like(ts[2]) * pad
            correct_output.append(next_output)
            next_output[: t.size(0), : t.size(1), : t.size(2)].copy_(t)
        correct_output = torch.stack(correct_output)
        padded = torch.nested.to_padded_tensor(nt, pad)
        self.assertEqual(padded, correct_output)

    # TODO: test noncontiguous to_padded_tensor
    # For now this tests the functionality of noncontiguous_to_padded_tensor
    # and the error message of to_padded_tensor
    # since to_padded_tensor does not support noncontiguous buffer yet
    @dtypes(torch.float, torch.float16, torch.double)
    @torch.inference_mode()
    def test_to_padded_tensor_noncontiguous(self, device, dtype):
        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
            (2, 3, 6, 7), device, dtype
        )
        # test noncontiguous_to_padded_tensor functionality
        self.assertEqual(
            torch.nested.to_padded_tensor(nt_contiguous, 0.0),
            noncontiguous_to_padded_tensor(nt_noncontiguous),
        )
        # test to_padded_tensor error message
        self.assertRaisesRegex(
            RuntimeError,
            r"for now to_padded_tensor only supports contiguous nested tensor",
            lambda: torch.nested.to_padded_tensor(nt_noncontiguous, 0.0),
        )

    @skipMeta
    def test_device_checks(self, device):
        nt = torch.nested.nested_tensor([], device=device)
        is_cuda = "cuda" in str(device)
        self.assertEqual(nt.is_cuda, is_cuda)

    @dtypes(torch.float, torch.float16, torch.double)
    def test_nested_tensor_indexing(self, device, dtype):
        # edge case: empty nested tensor
        nt0 = torch.nested.nested_tensor([])
        self.assertRaises(IndexError, lambda: nt0[0])
        # normal case
        x0 = torch.randn((2, 5), device=device, dtype=dtype)
        x1 = torch.randn((3, 4), device=device, dtype=dtype)
        nt = torch.nested.nested_tensor([x0, x1])
        # single index: only integer indexing of the batch dimension is supported
        self.assertEqual(nt[0], x0)
        self.assertEqual(nt[-1], x1)
        self.assertRaises(IndexError, lambda: nt[2])
        self.assertRaises(IndexError, lambda: nt[-3])
        self.assertRaises(NotImplementedError, lambda: nt[:])
        self.assertEqual(nt[...], nt)
        # tuple of indices: only integer indexing of the batch dimension is supported,
        # plus all possible indexing in the original tensor dimensions
        self.assertEqual(nt[0, 0, 0], x0[0, 0])
        self.assertEqual(nt[0, 1, :], x0[1, :])
        self.assertEqual(nt[1, ...], x1)
        self.assertRaises(IndexError, lambda: nt[1, 4, 2])
        self.assertRaises(NotImplementedError, lambda: nt[:, 1, 1])
        # test select on non-batch dimensions
        self.assertEqual(nt.select(1, 0)[0], x0.select(0, 0))
        self.assertEqual(nt.select(1, 0)[1], x1.select(0, 0))
        self.assertRaises(IndexError, lambda: nt.select(1, 3))
        self.assertEqual(nt.select(2, 0)[0], x0.select(1, 0))
        self.assertEqual(nt.select(2, 0)[1], x1.select(1, 0))
        self.assertRaises(IndexError, lambda: nt.select(2, 5))
        # make sure indexing returns a view
        nt[0].fill_(100.0)
        answer = torch.tensor(100.0, device=device, dtype=dtype).expand((2, 5))
        self.assertEqual(nt[0], answer)
        nt[1, 1, :].fill_(200.0)
        answer = torch.tensor(200.0, device=device, dtype=dtype).expand(4)
        self.assertEqual(nt[1, 1, :], answer)

        # Test that indexing works when requires_grad_(True)
        # previously this was failing because the backward kernel for select.int uses .sizes()
        nt = torch.nested.nested_tensor([x0, x1]).requires_grad_(True)
        self.assertEqual(nt[0], x0)
        self.assertEqual(nt[-1], x1)
        grad_x0 = torch.randn((2, 5), device=device, dtype=dtype)
        nt[0].backward(grad_x0)
        expected_grad = torch.nested.nested_tensor(
            [grad_x0, torch.zeros((3, 4), device=device, dtype=dtype)]
        )
        self.assertEqual(nt.grad, expected_grad)

    @parametrize(
        "func",
        [
            subtest(torch.nn.functional.relu, name="relu"),
            subtest(torch.nn.functional.relu_, name="relu_"),
            subtest(torch.nn.functional.gelu, name="gelu"),
            subtest(torch._C._nn.gelu_, name="gelu_"),
            subtest(torch.tanh, name="tanh"),
            subtest(torch.tanh_, name="tanh_"),
            subtest(torch.neg, name="neg"),
            subtest(torch.nn.functional.silu, name="silu"),
            subtest(partial(torch.nn.functional.silu, inplace=True), name="silu_"),
            subtest(torch.abs, name="abs"),
            subtest(torch.abs_, name="abs_"),
            subtest(torch.sgn, name="sgn"),
            subtest(torch.logical_not, name="logical_not"),
            subtest(torch.sin, name="sin"),
            subtest(torch.cos, name="cos"),
        ],
    )
    def test_activations(self, device, func):
        nt, nt_noncontiguous = random_nt_noncontiguous_pair(
            (2, 3, 6, 7), device=device, dtype=torch.float32
        )
        nested_result = func(nt)
        self.assertTrue(nested_result.is_nested)
        for t, t_res in zip(nt.unbind(), nested_result.unbind()):
            self.assertEqual(func(t), t_res)
        self.assertRaisesRegex(
            RuntimeError,
            "NestedTensor must be contiguous to get buffer.",
            lambda: func(nt_noncontiguous),
        )

    @parametrize("func", [subtest(torch.ge, name="ge"), subtest(torch.eq, name="eq")])
    def test_binary_ops_with_scalar(self, device, func):
        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
            (2, 3, 6, 7), device=device, dtype=torch.float32
        )
        scalar = 0.0

        # should work regardless of contiguity
        for nt in (nt_contiguous, nt_noncontiguous):
            nested_result = func(nt, scalar)
            self.assertTrue(nested_result.is_nested)
            for t, t_res in zip(nt.unbind(), nested_result.unbind()):
                self.assertEqual(func(t, scalar), t_res)

    @dtypes(*floating_types_and_half())
    def test_nested_tensor_chunk(self, device, dtype):
        # Transformer use case
        a = torch.randn(3, 3 * 4, device=device, dtype=dtype)
        b = torch.randn(2, 3 * 4, device=device, dtype=dtype)
        c = torch.randn(1, 3 * 4, device=device, dtype=dtype)
        a_chunks = a.chunk(3, dim=-1)
        b_chunks = b.chunk(3, dim=-1)
        c_chunks = c.chunk(3, dim=-1)

        a_nt = [a_chunks[0], b_chunks[0], c_chunks[0]]
        b_nt = [a_chunks[1], b_chunks[1], c_chunks[1]]
        c_nt = [a_chunks[2], b_chunks[2], c_chunks[2]]

        nt = torch.nested.nested_tensor([a, b, c])
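        # chunk along the trailing dim should match chunking each component
        # individually; the resulting chunks are reported as non-contiguous
        # (asserted below).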
        chunked = nt.chunk(3, dim=-1)

        self.assertEqual(chunked[0], torch.nested.nested_tensor(a_nt))
        self.assertEqual(chunked[1], torch.nested.nested_tensor(b_nt))
        self.assertEqual(chunked[2], torch.nested.nested_tensor(c_nt))

        for chunk in chunked:
            self.assertFalse(chunk.is_contiguous())

        # Failure chunking on the ragged or batch dimensions
        self.assertRaisesRegex(
            RuntimeError,
            "Chunk for nested tensors is currently only supported for the last dimension.",
            lambda: torch.chunk(nt, 5, dim=1),
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Chunk for nested tensors is currently only supported for the last dimension.",
            lambda: torch.chunk(nt, 5, dim=0),
        )

        # Failure on non-contiguous nt
        _, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype)
        self.assertRaisesRegex(
            RuntimeError,
            "chunk expects `self` to be contiguous.",
            lambda: torch.chunk(nt_noncontiguous, 5, dim=-1),
        )

        # Failure when the trailing dimension is not divisible by n_chunks
        self.assertRaisesRegex(
            RuntimeError,
            "Chunk for nested tensors is only supported for "
            "nested tensors with trailing dimension divisible by chunks.",
            lambda: torch.chunk(nt, 5, dim=-1),
        )

        # Failure when calling backward on a chunk
        a = torch.randn(3, 3 * 4, device=device, dtype=dtype, requires_grad=True)
        b = torch.randn(2, 3 * 4, device=device, dtype=dtype, requires_grad=True)
        nt_grad = torch.nested.as_nested_tensor([a, b])
        chunked = torch.chunk(nt_grad, 2, dim=-1)
        self.assertRaisesRegex(
            RuntimeError,
            "Nested Strided Tensor doesn't support chunk backward.",
            lambda: chunked[0].backward(chunked[0].clone()),
        )

    @dtypes(*floating_types_and_half())
    def test_nested_tensor_split_with_sizes(self, device, dtype):
        a = torch.randn(3, 20, device=device, dtype=dtype)
        b = torch.randn(2, 20, device=device, dtype=dtype)
        c = torch.randn(1, 20, device=device, dtype=dtype)

        split_sizes = [4, 6, 10]
        a_splits = a.split_with_sizes(split_sizes, dim=-1)
        b_splits = b.split_with_sizes(split_sizes, dim=-1)
        c_splits = c.split_with_sizes(split_sizes, dim=-1)

        nt = torch.nested.nested_tensor([a, b, c])
        nt_splits = nt.split_with_sizes(split_sizes, dim=-1)

        for i, nt_split in enumerate(nt_splits):
            self.assertEqual(
                nt_split,
                torch.nested.nested_tensor([a_splits[i], b_splits[i], c_splits[i]]),
            )
            dense_strides = torch.stack(
                [
                    torch.tensor(a_splits[i].stride()),
                    torch.tensor(b_splits[i].stride()),
                    torch.tensor(c_splits[i].stride()),
                ]
            )
            self.assertEqual(nt_split._nested_tensor_strides(), dense_strides)
            self.assertFalse(nt_split.is_contiguous())

        # Failure calling on ragged dimensions
        self.assertRaisesRegex(
            RuntimeError,
            "split_with_sizes for nested tensors is currently only supported for the last dimension.",
            lambda: torch.split_with_sizes(nt, split_sizes, dim=1),
        )

        # Failure calling on non-last dimension
        self.assertRaisesRegex(
            RuntimeError,
            "split_with_sizes for nested tensors is currently only supported for the last dimension.",
            lambda: torch.split_with_sizes(nt, split_sizes, dim=0),
        )
        # Failure on non-contiguous nt
        _, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype)
        self.assertRaisesRegex(
            RuntimeError,
            "split_with_sizes expects `self` to be contiguous.",
            lambda: torch.split_with_sizes(nt_noncontiguous, split_sizes, dim=-1),
        )

        # Failure when calling with split_sizes that don't cover the full dim size
        bad_split_sizes = [4, 6, 9]  # don't add up to 20
        self.assertRaisesRegex(
            RuntimeError,
            "split_with_sizes expects split_sizes to sum exactly to 20",
            lambda: torch.split_with_sizes(nt, bad_split_sizes, dim=-1),
        )

    @dtypes(torch.float, torch.float16, torch.double)
    @torch.inference_mode()
    def test_nested_tensor_indexing_noncontiguous(self, device, dtype):
        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
            (2, 3, 6, 7), device, dtype
        )
        self.assertEqual(nt_contiguous.size(0), nt_noncontiguous.size(0))
        n = nt_contiguous.size(0)
        for i in range(n):
            self.assertEqual(nt_contiguous[i], nt_noncontiguous[i])

    @dtypes(torch.float, torch.float16)
    @skipMeta
    @torch.inference_mode()
    @parametrize("transpose", [True, False])
    def test_nested_tensor_add(self, device, dtype, transpose):
        if transpose:
            a = torch.randn(2, 2, 2, device=device, dtype=dtype)
            b = torch.rand(2, 2, 2, device=device, dtype=dtype)
            c = a.transpose(-1, -2).contiguous()
            d = b.transpose(-1, -2).contiguous()
            nt1 = torch.nested.nested_tensor([a, b, a, b])
            nt2 = torch.nested.nested_tensor([c, d, c, d]).transpose(-1, -2)
        else:
            (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
        ref = torch.nested.nested_tensor(
            [t1 + t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]
        )
        out = nt1 + nt2
        self.assertEqual(ref, out)

    @dtypes(torch.float, torch.float16)
    @skipMeta
    @torch.inference_mode()
    @parametrize("transpose", [True, False])
    def test_nested_tensor_sub(self, device, dtype, transpose):
        if transpose:
            a = torch.randn(2, 2, 2, device=device, dtype=dtype)
            b = torch.rand(2, 2, 2, device=device, dtype=dtype)
            c = a.transpose(-1, -2).contiguous()
            d = b.transpose(-1, -2).contiguous()
            nt1 = torch.nested.nested_tensor([a, b, a, b])
            nt2 = torch.nested.nested_tensor([c, d, c, d]).transpose(-1, -2)
        else:
            (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
        ref = torch.nested.nested_tensor(
            [t1 - t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]
        )
        out = nt1 - nt2
        self.assertEqual(ref, out)

    @onlyCUDA
    @dtypes(torch.float, torch.float16)
    @torch.inference_mode()
    @parametrize("embedding_dim", [8, 128, 256, 384])
    def test_nested_tensor_dense_elementwise(self, device, dtype, embedding_dim):
        def _test_add_mul(nt, t):
            ref_add = torch.nested.nested_tensor(
                [t1 + t2 for (t1, t2) in zip(nt.unbind(), t.unbind())]
            )
            ref_mul = torch.nested.nested_tensor(
                [t1 * t2 for (t1, t2) in zip(nt.unbind(), t.unbind())]
            )
            self.assertEqual(nt.add(t), ref_add)
            self.assertEqual(nt.mul(t), ref_mul)

        batch_size = 32
        seq_lens = torch.randint(low=0, high=10, size=(batch_size,))

        # [B, *, D], [B, 1, D] case
        ts = [torch.randn((seq_len, embedding_dim)) for seq_len in seq_lens]
        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
        t = torch.randn((batch_size, 1, embedding_dim), device=device, dtype=dtype)
        _test_add_mul(nt, t)

        # [B, *], [B, 1] case
        ts = [torch.randn(seq_len) for seq_len in seq_lens]
        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
        t = torch.randn((batch_size, 1), device=device, dtype=dtype)
        _test_add_mul(nt, t)

    @dtypes(torch.float, torch.float16)
    @skipMeta
    @torch.inference_mode()
    def test_nested_tensor_mul(self, device, dtype):
        # nested tensor * nested tensor
        (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
        ref = torch.nested.nested_tensor(
            [t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]
        )
        out = nt1 * nt2
        self.assertEqual(ref, out)
        # nested tensor * scalar
        number = 10.0
        scalar = torch.tensor(number).to(dtype).to(device)
        ref = torch.nested.nested_tensor([t * number for t in nt1.unbind()])
        out_number0 = nt1 * number
        out_number1 = number * nt1
        out_scalar0 = nt1 * scalar
        out_scalar1 = scalar * nt1
        self.assertEqual(out_number0, ref)
        self.assertEqual(out_number1, ref)
        self.assertEqual(out_scalar0, ref)
        self.assertEqual(out_scalar1, ref)
        # error case: numel == 1 but dim > 0
        vector = torch.tensor([number]).to(dtype).to(device)
        self.assertRaisesRegex(
            RuntimeError,
            "Expected both self and other to be nested, but got a nested self and non-nested other",
            lambda: nt1.mul(vector),
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Expected both self and other to be nested, but got a non-nested self and nested other",
            lambda: vector.mul(nt1),
        )
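    # As exercised above, elementwise mul supports nested * nested (with matching
    # component shapes) and nested * scalar (a Python number or 0-dim tensor, in
    # either order); a 1-element tensor with dim > 0 is treated as a dense tensor
    # and is rejected rather than broadcast.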
    @dtypes(torch.float, torch.float16)
    @skipMeta
    @torch.inference_mode()
    def test_nested_tensor_div(self, device, dtype):
        nt, nt2 = self.random_nt_pair(device, dtype, 4, (4, 4))
        scale = 4.0
        ref = torch.nested.nested_tensor([t / scale for t in nt.unbind()])
        out = nt / 4.0
        self.assertEqual(ref, out)
        ref_transposed = ref.transpose(1, 2)
        out = nt.transpose(1, 2) / 4.0
        self.assertEqual(ref_transposed, out)

        ref = torch.nested.nested_tensor(
            [t / t2 for (t, t2) in zip(nt.unbind(), nt2.unbind())]
        )
        out = nt / nt2
        self.assertEqual(ref, out)

        out = nt.transpose(1, 2) / nt2.transpose(1, 2)
        self.assertEqual(ref.transpose(1, 2), out)

        nt_transpose_copy = torch.nested.nested_tensor(
            [t.transpose(0, 1) for t in nt.unbind()]
        )

        self.assertRaisesRegex(
            RuntimeError,
            "div requires strides to match when given NestedTensors",
            lambda: nt_transpose_copy.transpose(1, 2) / nt2,
        )

        nt = torch.nested.nested_tensor(
            [torch.randn(i, 4) for i in [3, 4, 5]], device=device, dtype=dtype
        )
        nt_chunks = nt.chunk(2, -1)
        self.assertRaisesRegex(
            RuntimeError,
            "div requires offsets to match when given NestedTensors",
            lambda: nt_chunks[0] / nt_chunks[1],
        )

    @dtypes(torch.float, torch.float16)
    @skipMeta
    @torch.inference_mode()
    def test_nested_tensor_add_in_place(self, device, dtype):
        (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
        ref = torch.nested.nested_tensor(
            [t1 + t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]
        )
        nt1 += nt2
        self.assertEqual(ref, nt1)

    @dtypes(torch.float, torch.float16)
    @skipMeta
    @torch.inference_mode()
    def test_nested_tensor_mul_in_place(self, device, dtype):
        # nested tensor * nested tensor
        (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
        ref = torch.nested.nested_tensor(
            [t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]
        )
        nt1 *= nt2
        self.assertEqual(ref, nt1)
        # nested tensor * scalar
        number = 10.0
        scalar = torch.tensor(number).to(dtype).to(device)
        ref = torch.nested.nested_tensor([t * number for t in nt1.unbind()])
        out_number = nt1.clone()
        out_number *= number
        out_scalar = nt1.clone()
        out_scalar *= scalar
        self.assertEqual(out_number, ref)
        self.assertEqual(out_scalar, ref)
        self.assertRaisesRegex(
            RuntimeError,
            r"output with shape \[.*\] doesn't match the broadcast shape \[.*\]",
            lambda: scalar.mul_(nt1),
        )
        # error case: numel == 1 but dim > 0
        vector = torch.tensor([number]).to(dtype).to(device)
        self.assertRaisesRegex(
            RuntimeError,
            "Expected both self and other to be nested, but got a nested self and non-nested other",
            lambda: nt1.mul_(vector),
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Expected both self and other to be nested, but got a non-nested self and nested other",
            lambda: vector.mul_(nt1),
        )

    @onlyCPU
    @skipMeta
    @dtypes(torch.float)
    def test_nested_tensor_sum_dim(self, device, dtype):
        params = ((2, (1, 1)), ((4), (4, 4)), (10, (3, 5, 7)))

        def test_sum(device, dtype, ntensors, max_sizes, dim, keepdim=True):
            nt = random_nt(device, dtype, ntensors, max_sizes, require_non_empty=False)
            nt2 = nt.clone()
            ub2 = nt2.unbind()
            nt.requires_grad_(True)
            [t.requires_grad_(True) for t in ub2]
            nt_sum = nt.sum(dim=dim, keepdim=keepdim)
            ub2_sum = [t.sum(-1, keepdim=keepdim) for t in ub2]
            self.assertEqual(nt_sum, torch.nested.nested_tensor(ub2_sum))

            # test backward
            # generate gradient tensor that has the same size as the output
            size = nt_sum._nested_tensor_size()
            gt2 = []
            for i in range(ntensors):
                gt2.append(torch.randn(size[i].tolist(), device=device, dtype=dtype))
            gt = torch.nested.nested_tensor(gt2).clone()
            nt_sum.backward(gt)
            for t2, g2 in zip(ub2_sum, gt2):
                t2.backward(g2)
            self.assertEqual(nt.grad, torch.nested.nested_tensor([t.grad for t in ub2]))
            return

        for ntensors, max_sizes in params:
            test_sum(device, dtype, ntensors, max_sizes, len(max_sizes))

        # Test error inputs
        with self.assertRaisesRegex(
            RuntimeError, "NestedTensor can only be reduced across the last"
        ):
            torch.nested.nested_tensor(
                [torch.tensor([3, 4, 5]), torch.tensor([1, 2])]
            ).sum(0, keepdim=True)

        with self.assertRaisesRegex(
            RuntimeError, "NestedTensor only allows reduction of a single"
        ):
            torch.nested.nested_tensor(
                [torch.tensor([[3, 4, 5]]), torch.tensor([[1, 2]])]
            ).sum([0, 1], keepdim=True)

        with self.assertRaisesRegex(
            RuntimeError, "NestedTensor always requires keepdim=True for now."
        ):
            torch.nested.nested_tensor(
                [torch.tensor([3, 4, 5]), torch.tensor([1, 2])]
            ).sum(-1)

    @dtypes(torch.float, torch.float16)
    def test_contiguous(self, device, dtype):
        # Since we don't have access to the buffer in Python, it is harder to show
        # exactly what we are testing for. When we call chunk on a consistent dim of
        # an NT with chunk_size > 1, the resulting tensors are views of the original
        # NT whose numel is now less than the size of the buffer. Clone was
        # previously creating a new NT with a buffer that was the same size as the
        # original.
        nt_contiguous = torch.nested.nested_tensor(
            [
                torch.randn(2, 20, device=device, dtype=dtype),
                torch.randn(4, 20, device=device, dtype=dtype),
            ]
        )
        # Split up the last dimension which has a consistent size of 20 into 5 chunks
        chunks = nt_contiguous.chunk(5, dim=-1)

        # Check that the chunks are non-contiguous views and become contiguous
        # after calling contiguous()
        for chunk in chunks:
            self.assertFalse(chunk.is_contiguous())
            self.assertTrue(chunk.contiguous().is_contiguous())

    @dtypes(torch.float, torch.float16)
    @skipMeta
    def test_clone(self, device, dtype):
        nt1 = random_nt(device, dtype, 4, (4, 4), (1, 1))
        nt2 = nt1.clone()
        # Verify the values match
        self.assertEqual(nt1, nt2)
        # Verify modifying nt2 doesn't affect nt1
        nt2.mul_(nt1)
        ub1 = nt1.unbind()
        ub2 = nt2.unbind()
        for i in range(len(ub1)):
            self.assertNotEqual(ub1[i], ub2[i])

        nt1.clone(memory_format=torch.preserve_format)
        msg = "Nested tensor clone supports Preserve and Contiguous memory formats, called clone with memory format: ChannelsLast"
        with self.assertRaisesRegex(RuntimeError, msg):
            nt1.clone(memory_format=torch.channels_last)

    # cannot test torch.float16 because: RuntimeError: "bernoulli_scalar_cpu_" not implemented for 'Half'
    @decorateIf(xfailIfTorchDynamo, lambda params: params["layout"] == torch.jagged)
    @dtypes(torch.float, torch.double)
    @parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name)
    def test_dropout(self, device, dtype, layout):
        # edge case: empty nested tensor
        # TODO: support empty NT in jagged layout
        if layout == torch.strided:
            nt0 = torch.nested.nested_tensor([], layout=layout)
            y = torch.nn.functional.dropout(nt0, 0.5)
            self.assertEqual(nt0, y)
        # normal nested tensor
        ntensors = 4
        if layout == torch.jagged:
            nt = random_nt(device, dtype, ntensors, (4, 4), (0, 3), layout=layout)
        else:
            nt = random_nt(device, dtype, ntensors, (4, 4), layout=layout)
        # edge case: invalid dropout
        self.assertRaises(ValueError, lambda: torch.nn.Dropout(-0.1))
        self.assertRaises(ValueError, lambda: torch.nn.Dropout(1.1))
        self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, -0.1))
        self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, 1.1))
        # edge case: no dropout
        dropouter = torch.nn.Dropout(0.0)
        y0 = dropouter(nt)
        y1 = torch.nn.functional.dropout(nt, 0.0)
        self.assertEqual(nt, y0)
        self.assertEqual(nt, y1)
        # edge case: all dropout
        dropouter = torch.nn.Dropout(1.0)
        y0 = dropouter(nt)
        y1 = torch.nn.functional.dropout(nt, 1.0)
        nt0 = torch.zeros_like(nt)
        self.assertEqual(nt0, y0)
        self.assertEqual(nt0, y1)
        # normal case: normal dropout
        p = 0.2
        y = torch.nn.functional.dropout(nt, p)
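        # Build the expected output by hand: entries zeroed by dropout stay zero,
        # and surviving entries are rescaled by 1 / (1 - p).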
        expect = nt.clone()
        if layout == torch.jagged:
            expect = torch.where(y == 0.0, y, nt)
            expect /= 1.0 - p
            self.assertEqual(y, expect)
        else:
            expect = nt.clone()
            for i in range(ntensors):
                actual_tensor = y[i].view(-1)
                expect_tensor = expect[i].view(-1)
                for j in range(actual_tensor.shape[0]):
                    if actual_tensor[j].item() == 0.0:
                        expect_tensor[j] = 0.0
                    else:
                        expect_tensor[j] /= 1.0 - p
            self.assertEqual(y, expect)
        with freeze_rng_state():
            dropouter = torch.nn.Dropout(p)
            y0 = dropouter(nt)
        with freeze_rng_state():
            y1 = torch.nn.functional.dropout(nt, p)
        self.assertEqual(y0, y1)

    @dtypes(torch.float, torch.double)
    def test_dropout_noncontiguous(self, device, dtype):
        ntensors = 4
        nt0 = random_nt(device, dtype, ntensors, (4, 4))
        nt1 = nt0.transpose(-1, -2)
        p = 0.3
        with freeze_rng_state():
            dropouter = torch.nn.Dropout(p)
            y0 = dropouter(nt0)
        with freeze_rng_state():
            y1 = torch.nn.functional.dropout(nt1, p).transpose(-1, -2)
        self.assertEqual(y0, y1)

    # cannot test torch.float16 because: RuntimeError: "softmax_kernel_impl" not implemented for 'Half'
    @dtypes(torch.float, torch.double)
    def test_softmax(self, device, dtype):
        # normal nested tensor
        ntensors = 4
        nt = random_nt(device, dtype, ntensors, (4, 4))
        # error case: softmax across nested dimension
        self.assertRaisesRegex(
            RuntimeError,
            "Cannot apply softmax across nested dimension 0",
            lambda: torch.nn.functional.softmax(nt, 0),
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Cannot apply softmax across nested dimension 0",
            lambda: torch.nn.functional.softmax(nt, -3),
        )
        # error case: dimension out of range
        self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt, 3))
        self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt, -4))
        # normal case: should equal softmax over the tensor padded with -inf
        softmaxer = torch.nn.Softmax(1)
        y0 = softmaxer(nt)
        y1 = torch.nn.functional.softmax(nt, 1)
        self.assertEqual(y0, y1)
        pt = torch.nested.to_padded_tensor(nt, float("-inf"))
        # if an entire slice is padded, then softmax will return 0.0 / 0.0 = nan
        # however, physically speaking that should be 0.0
        expect = torch.nn.functional.softmax(pt, 1).nan_to_num_(0.0)
        self.assertEqual(torch.nested.to_padded_tensor(y0, 0.0), expect)
        # edge case: empty nested tensor
        nt0 = torch.nested.nested_tensor([])
        y = torch.nn.functional.softmax(nt0, 1)
        self.assertEqual(nt0, y)
        # edge case: nesting scalars
        nt1 = torch.nested.nested_tensor([torch.tensor(0.0), torch.tensor(1.0)])
        self.assertRaises(RuntimeError, lambda: torch.nn.functional.softmax(nt1, 0))
        self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt1, 1))

    @dtypes(torch.float, torch.double)
    @torch.inference_mode()
    def test_softmax_noncontiguous(self, device, dtype):
        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
            (2, 3, 6, 7), device, dtype
        )
        self.assertEqual(
            torch.nn.functional.softmax(nt_contiguous, -1),
            torch.nn.functional.softmax(nt_noncontiguous, -1),
        )

    def _test_bmm(self, device, dtype):
        # error case: not 3D tensors
        nt0 = torch.nested.nested_tensor([], device=device, dtype=dtype)
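        # nt0 and nt1 below are only 1D / 2D overall, so using either of them as
        # a bmm operand should be rejected with a "must be a 3D tensor" error.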
        nt1 = torch.nested.nested_tensor(
            [torch.randn(2), torch.randn(3)], device=device, dtype=dtype
        )
        nt2 = torch.nested.nested_tensor(
            [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
        )
        self.assertRaisesRegex(
            RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt0)
        )
        self.assertRaisesRegex(
            RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt1)
        )
        self.assertRaisesRegex(
            RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt2)
        )
        self.assertRaisesRegex(
            RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt0)
        )
        self.assertRaisesRegex(
            RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt1)
        )
        self.assertRaisesRegex(
            RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt2)
        )
        self.assertRaisesRegex(
            RuntimeError, "batch2 must be a 3D tensor", lambda: nt2.bmm(nt0)
        )
        self.assertRaisesRegex(
            RuntimeError, "batch2 must be a 3D tensor", lambda: nt2.bmm(nt1)
        )
        # error case: incompatible batch size
        nt0 = torch.nested.nested_tensor(
            [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
        )
        nt1 = torch.nested.nested_tensor(
            [torch.randn((4, 6)), torch.randn((4, 5)), torch.randn((4, 7))],
            device=device,
            dtype=dtype,
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Expected size for the 1st dimension of batch2 tensor to be: 2 but got: 3.",
            lambda: nt0.bmm(nt1),
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Expected size for the 1st dimension of batch2 tensor to be: 3 but got: 2.",
            lambda: nt1.bmm(nt0),
        )
        # error case: underlying matrices cannot be multiplied
        nt0 = torch.nested.nested_tensor(
            [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"0-th nested matrices in batch cannot be multiplied \(2x4 and 2x4\)",
            lambda: nt0.bmm(nt0),
        )
        # normal nested tensor
        nt0 = torch.nested.nested_tensor(
            [torch.randn((2, 4)), torch.randn((3, 7))], device=device, dtype=dtype
        )
        nt1 = torch.nested.nested_tensor(
            [torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype
        )
        actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
        expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(
            torch.nested.to_padded_tensor(nt1, 0.0)
        )
        if dtype == torch.float16:
            self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
        else:
            self.assertEqual(actual, expect)

        # nested tensor bmm normal tensor
        nt0 = torch.nested.nested_tensor(
            [torch.randn((2, 7)), torch.randn((3, 7))], device=device, dtype=dtype
        )
        nt1 = torch.rand(2, 7, 5, dtype=dtype, device=device)
        actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
        expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(nt1)
        if dtype == torch.float16:
            self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
        else:
            self.assertEqual(actual, expect)

        # nested tensor bmm normal tensor with non-contiguous view
        nt1 = torch.rand(2, 5, 7, dtype=dtype, device=device)
        nt1 = nt1.transpose(1, 2)
        actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
        expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(nt1)
        if dtype == torch.float16:
            self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
        else:
            self.assertEqual(actual, expect)

        # normal tensor bmm nested tensor
        nt0 = torch.rand(2, 5, 7, dtype=dtype, device=device)
        nt1 = torch.nested.nested_tensor(
            [torch.randn((7, 6)), torch.randn((7, 5))], device=device, dtype=dtype
        )
        actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
        expect = nt0.bmm(torch.nested.to_padded_tensor(nt1, 0.0))
        if dtype == torch.float16:
            self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
        else:
            self.assertEqual(actual, expect)

        # test tensorcore path
        nt0 = torch.nested.nested_tensor(
            [torch.randn((2, 8)), torch.randn((3, 16))], device=device, dtype=dtype
        )
        nt1 = torch.nested.nested_tensor(
            [torch.randn((8, 8)), torch.randn((16, 8))], device=device, dtype=dtype
        )
        actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
        expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(
            torch.nested.to_padded_tensor(nt1, 0.0)
        )
        if dtype == torch.float16:
            self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
        else:
            self.assertEqual(actual, expect)

    @onlyCUDA
    @dtypes(torch.float, torch.double, torch.float16)
    def test_bmm_cuda(self, device, dtype):
        self._test_bmm(device, dtype)

    @onlyCPU
    # cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
    @dtypes(torch.float, torch.double)
    def test_bmm_cpu(self, device, dtype):
        self._test_bmm(device, dtype)

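    # Illustrative sketch (not part of the original test suite): the bmm checks
    # above validate nested bmm against a dense reference obtained by zero
    # padding. The helper name is hypothetical and never called by the tests.
    @staticmethod
    def _reference_bmm_via_padding(nt0, nt1):
        # Zero padding keeps the valid region of the padded product identical to
        # the per-component matmuls, since padded rows/columns contribute zeros.
        pt0 = torch.nested.to_padded_tensor(nt0, 0.0)
        pt1 = torch.nested.to_padded_tensor(nt1, 0.0)
        return pt0.bmm(pt1)
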
    # cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
    @dtypes(torch.float, torch.double)
    def test_bmm_noncontiguous(self, device, dtype):
        nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair(
            (2, 3), device, dtype
        )
        nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair(
            (6, 7), device, dtype
        )
        self.assertEqual(
            nt0_contiguous.transpose(-1, -2).bmm(nt1_contiguous),
            nt0_noncontiguous.transpose(-1, -2).bmm(nt1_noncontiguous),
        )

    @dtypes(torch.float, torch.double)
    def test_matmul_with_bmm_path(self, device, dtype):
        def unbind_rebind_matmul(nt1, nt2):
            t1s = nt1.unbind()
            t2s = nt2.unbind()
            out_ts = [t1.matmul(t2) for t1, t2 in zip(t1s, t2s)]
            return torch.nested.nested_tensor(out_ts)

        # [N, n_head, *, head_dim], [N, n_head, head_dim, *]
        Ns = [1, 2, 5]
        n_heads = np.random.randint(2, 5)
        head_dim = 3
        t1s = []
        t2s = []
        for N in Ns:
            for _ in range(N):
                seq_len1 = np.random.randint(2, 5)
                seq_len2 = np.random.randint(2, 5)
                t1s.append(torch.randn(n_heads, seq_len1, head_dim))
                t2s.append(torch.randn(n_heads, head_dim, seq_len2))
            nt1 = torch.nested.nested_tensor(t1s, device=device, dtype=dtype)
            nt2 = torch.nested.nested_tensor(t2s, device=device, dtype=dtype)
            self.assertEqual(torch.matmul(nt1, nt2), unbind_rebind_matmul(nt1, nt2))

        # test with noncontiguous
        t3s = []
        t4s = []
        for _ in range(N):
            seq_len = np.random.randint(2, 5)
            t3s.append(torch.randn(seq_len, n_heads, head_dim))
            t4s.append(torch.randn(seq_len, n_heads, head_dim))
        nt3 = torch.nested.nested_tensor(t3s, device=device, dtype=dtype).transpose(
            1, 2
        )
        nt4 = (
            torch.nested.nested_tensor(t4s, device=device, dtype=dtype)
            .transpose(1, 2)
            .transpose(2, 3)
        )
        self.assertEqual(torch.matmul(nt3, nt4), unbind_rebind_matmul(nt3, nt4))

    # cannot test torch.float16 because: RuntimeError: "bmm" not implemented for 'Half'
    @dtypes(torch.float, torch.double)
    def test_matmul(self, device, dtype):
        # error case: one is nested but the other is not
        nt = torch.nested.nested_tensor(
            [torch.randn(2), torch.randn(3)], device=device, dtype=dtype
        )
        t = torch.randn(4, device=device, dtype=dtype)
        self.assertRaisesRegex(
            RuntimeError,
            "Expected both to be nested, but got a nested self and non-nested other",
            lambda: torch.matmul(nt, t),
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Expected both to be nested, but got a non-nested self and nested other",
            lambda: torch.matmul(t, nt),
        )
        # error case: not 3+D tensors
        nt0 = torch.nested.nested_tensor([], device=device, dtype=dtype)
        nt1 = torch.nested.nested_tensor(
            [torch.randn(2), torch.randn(3)], device=device, dtype=dtype
        )
        nt2 = torch.nested.nested_tensor(
            [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
            lambda: torch.matmul(nt0, nt0),
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
            lambda: torch.matmul(nt0, nt1),
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
            lambda: torch.matmul(nt0, nt2),
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
            lambda: torch.matmul(nt1, nt0),
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
            lambda: torch.matmul(nt1, nt1),
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
            lambda: torch.matmul(nt1, nt2),
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 2nd input has rank: [0-9]+",
            lambda: torch.matmul(nt2, nt0),
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 2nd input has rank: [0-9]+",
            lambda: torch.matmul(nt2, nt1),
        )
        # error case: incompatible batch size
        nt0 = torch.nested.nested_tensor(
            [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
        )
        nt1 = torch.nested.nested_tensor(
            [torch.randn((4, 6)), torch.randn((4, 5)), torch.randn((4, 7))],
            device=device,
            dtype=dtype,
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"matmul: Expected size for the 1st dimension of 2nd input tensor to be: [0-9]+ but got: [0-9]+.",
            lambda: torch.matmul(nt0, nt1),
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"matmul: Expected size for the 1st dimension of 2nd input tensor to be: [0-9]+ but got: [0-9]+.",
            lambda: torch.matmul(nt1, nt0),
        )
        # error case: incompatible (wrong) batch sizes that shouldn't even broadcast?
        nt0 = torch.nested.nested_tensor(
            [torch.randn((2, 2, 4)), torch.randn((2, 3, 4))], device=device, dtype=dtype
        )
        nt1 = torch.nested.nested_tensor(
            [torch.randn((3, 4, 6)), torch.randn((3, 4, 5))], device=device, dtype=dtype
        )
        self.assertRaisesRegex(
            RuntimeError,
            "matmul(): For nested tensors, batch dimensions must have the same sizes,",
            lambda: torch.matmul(nt0, nt1),
        )
        # error case: incompatible batch sizes that should technically broadcast
        nt0 = torch.nested.nested_tensor(
            [torch.randn((2, 2, 4)), torch.randn((1, 3, 4))], device=device, dtype=dtype
        )
        nt1 = torch.nested.nested_tensor(
            [torch.randn((1, 4, 6)), torch.randn((3, 4, 5))], device=device, dtype=dtype
        )
        self.assertRaisesRegex(
            RuntimeError,
            "matmul(): For nested tensors, batch dimensions must have the same sizes,",
            lambda: torch.matmul(nt0, nt1),
        )
        # error case: underlying matrices cannot be multiplied
        nt0 = torch.nested.nested_tensor(
            [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
        )
        self.assertRaisesRegex(
            RuntimeError,
            "matmul(): Nested tensors cannot be matrix multiplied",
            lambda: torch.matmul(nt0, nt0),
        )
        # normal nested tensor: 3D
        nt0 = torch.nested.nested_tensor(
            [torch.randn((2, 4)), torch.randn((3, 7))], device=device, dtype=dtype
        )
        nt1 = torch.nested.nested_tensor(
            [torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype
        )
        actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0)
        expect = torch.matmul(
            torch.nested.to_padded_tensor(nt0, 0.0),
            torch.nested.to_padded_tensor(nt1, 0.0),
        )
        self.assertEqual(actual, expect)
        # normal nested tensor: 4D (with testing for batch_size=1)
        nt0 = torch.nested.nested_tensor(
            [torch.randn((1, 2, 4)), torch.randn((8, 3, 7))], device=device, dtype=dtype
        )
        nt1 = torch.nested.nested_tensor(
            [torch.randn((1, 4, 6)), torch.randn((8, 7, 5))], device=device, dtype=dtype
        )
        actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0)
        expect = torch.matmul(
            torch.nested.to_padded_tensor(nt0, 0.0),
            torch.nested.to_padded_tensor(nt1, 0.0),
        )
        self.assertEqual(actual, expect)
        # normal nested tensor: 5D
        nt0 = torch.nested.nested_tensor(
            [torch.randn((8, 9, 2, 4)), torch.randn((8, 9, 3, 7))],
            device=device,
            dtype=dtype,
        )
        nt1 = torch.nested.nested_tensor(
            [torch.randn((8, 9, 4, 6)), torch.randn((8, 9, 7, 5))],
            device=device,
            dtype=dtype,
        )
        actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0)
        expect = torch.matmul(
            torch.nested.to_padded_tensor(nt0, 0.0),
            torch.nested.to_padded_tensor(nt1, 0.0),
        )
        self.assertEqual(actual, expect)

    # only supported on CUDA for now
    @dtypes(torch.float, torch.double)
    def test_matmul_nt_with_broadcasted_t(self, device, dtype):
        # NT (B, *, C, D) with T (D, E) broadcasting case
        nt = random_nt_from_dims([3, None, 4, 5], device=device, dtype=dtype)
        t = torch.randn(5, 6, device=device, dtype=dtype)
        output = torch.matmul(nt, t)

        # should be equivalent to matmul-ing each component with the dense tensor
        self.assertEqual(nt.size(0), output.size(0))
        for component, out_component in zip(nt, output):
            self.assertEqual(out_component, torch.matmul(component, t))

    # cannot test torch.float16 because: RuntimeError: "bmm" not implemented for 'Half'
    @dtypes(torch.float, torch.double)
    def test_matmul_noncontiguous(self, device, dtype):
        nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair(
            (2, 3), device, dtype
        )
        nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair(
            (6, 7), device, dtype
        )
        self.assertEqual(
            torch.matmul(nt0_contiguous.transpose(-1, -2), nt1_contiguous),
            torch.matmul(nt0_noncontiguous.transpose(-1, -2), nt1_noncontiguous),
        )

    @dtypes(torch.float, torch.double)
    def test_linear(self, device, dtype):
        a = torch.randn(1, 2, device=device, dtype=dtype)
        b = torch.randn(2, 2, device=device, dtype=dtype)
        c = torch.randn(3, 2, device=device, dtype=dtype)
        nt = torch.nested.nested_tensor([a, b, c])
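        # F.linear on a nested tensor applies the same 2D weight (and bias) to the
        # trailing dimension of every component, so the NT must be 3D with a
        # consistent last dimension; the error cases below probe each violation.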

        weight = torch.randn(2, 2, device=device, dtype=dtype)
        bias = torch.randn(2, device=device, dtype=dtype)
        # success case
        torch.functional.F.linear(nt, weight, bias)

        # invalid nested tensor dimension
        msg = r"Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 2. Dense tensor dim: 2"
        nt1 = torch.nested.nested_tensor(
            [
                torch.randn(1, device=device, dtype=dtype),
                torch.randn(2, device=device, dtype=dtype),
            ]
        )
        with self.assertRaisesRegex(RuntimeError, msg):
            torch.functional.F.linear(nt1, weight, bias)

        # invalid weight shape
        msg = r"Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 3. Dense tensor dim: 3"
        weight1 = torch.randn(2, 2, 3, device=device, dtype=dtype)
        with self.assertRaisesRegex(RuntimeError, msg):
            torch.functional.F.linear(nt, weight1, bias)

        # inconsistent last dim of nested tensor
        msg = r"Expected all tensors in nested tensor to have the same trailing dimension, instead last dimension equals:"
        nt2 = torch.nested.nested_tensor(
            [
                torch.randn(1, 2, device=device, dtype=dtype),
                torch.randn(2, 3, device=device, dtype=dtype),
            ]
        )
        with self.assertRaisesRegex(RuntimeError, msg):
            torch.functional.F.linear(nt2, weight, bias)

        # Mismatch of nested tensor last dim and weight dimension
        weight2 = torch.randn(2, 4, device=device, dtype=dtype)
        msg = (
            r"Shape mismatch for NestedTensor Linear: Expected input's \(a nested tensor\) 'last_dim'"
            r" to equal 'weight.size\(1\), but got: last_dim = 2, and weight.size\(1\) = 4"
        )
        with self.assertRaisesRegex(RuntimeError, msg):
            torch.functional.F.linear(nt, weight2, bias)

        # Nested tensor input and nested weight
        nt_weight = nt.clone()
        msg = r"Linear does not support nested weight when input is a nested tensor."
        with self.assertRaisesRegex(RuntimeError, msg):
            torch.functional.F.linear(nt, nt_weight, bias)

    # TODO: test noncontiguous linear
    # For now this tests the error message of linear
    # since linear does not support noncontiguous buffer yet
    @dtypes(torch.float, torch.double)
    def test_linear_noncontiguous(self, device, dtype):
        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
            (2, 3, 6, 7), device, dtype
        )
        weight = torch.randn((8, 5), device=device, dtype=dtype)
        self.assertRaisesRegex(
            RuntimeError,
            r"for now linear only supports contiguous nested tensor",
            lambda: torch.nn.functional.linear(nt_noncontiguous, weight),
        )

    @dtypes(torch.float, torch.float16, torch.double)
    def test_to_padded_tensor_zero_numel_errors(self, device, dtype):
        ts = [torch.ones(1, 0), torch.ones(0, 0)]
        nt = torch.nested.nested_tensor(
            ts, device=device, dtype=dtype, layout=torch.strided
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"at least one constituent tensor should have non-zero numel",
            lambda: torch.nested.to_padded_tensor(nt, 0.0),
        )

    @dtypes(torch.float, torch.float16, torch.double)
    def test_transpose(self, device, dtype):
        nt = random_nt(device, dtype, 4, (4, 4))
        # error case: transpose nested dimension
        self.assertRaisesRegex(
            RuntimeError,
            "Nested tensor dimension 0 cannot be transposed",
            lambda: nt.transpose(0, 1),
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Nested tensor dimension 0 cannot be transposed",
            lambda: nt.transpose(1, -3),
        )
        # error case: dimension out of range
        self.assertRaises(IndexError, lambda: nt.transpose(1, 3))
        self.assertRaises(IndexError, lambda: nt.transpose(-4, -1))
        # normal case
        ntT = nt.transpose(-1, -2)
        ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
        pt = torch.nested.to_padded_tensor(nt, 0.0)
        ptT = pt.transpose(-1, -2)
        self.assertEqual(ptT, ptT_from_ntT)

    @dtypes(torch.float, torch.float16, torch.double)
    def test_squeeze_unsqueeze(self, device, dtype):
        a = torch.arange(6).reshape(2, 3)
        b = torch.arange(15).reshape(5, 3)
        nt = torch.nested.nested_tensor([a, b], device=device, dtype=dtype)
        # error case: squeeze no dimension
        self.assertRaisesRegex(
            RuntimeError,
            "For nested tensors, squeeze without the dim argument",
            lambda: nt.squeeze(),
        )
        # error case: squeeze nested dimension
        self.assertRaisesRegex(
            RuntimeError,
            "For nested tensors, squeezing dimension 0",
            lambda: nt.squeeze(0),
        )
        # error case: dimension out of range
        self.assertRaises(IndexError, lambda: nt.squeeze(3))
        # error case: squeeze nested tensor of singleton tensors
        c = torch.ones(1)
        nt_singleton = torch.nested.nested_tensor([c, c], device=device, dtype=dtype)
        self.assertRaisesRegex(
            RuntimeError,
            "For nested tensors, squeezing a nested tensor of singleton",
            lambda: nt_singleton.squeeze(1),
        )

        # squeezing a dim which does not have size 1 should be a no-op
        nt2 = nt.squeeze(-1)
        self.assertEqual(nt, nt2)

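        # Illustrative note (added for clarity, not part of the original test): with
        # nt.dim() == 3 here, the loop below exercises unsqueeze at dims -2..3 (skipping
        # the batch dim); a negative dim d maps to d + nt.dim() + 1, so unsqueeze(-1)
        # inserts a new size-1 dimension at position 3 and unsqueeze(-2) at position 2.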
        # test cases that should work
        nt_sizes = nt._nested_tensor_size()
        nt_strides = nt._nested_tensor_strides()
        for i in range(-2, 4):
            if i == 0:
                # cannot unsqueeze batch dim
                continue
            nt_unsqueezed = nt.unsqueeze(i)
            # negative dim will correspond to unsqueeze() applied at dim = dim + nt.dim() + 1
            wrapped_i = i + nt.dim() + 1 if i < 0 else i
            # col index into the nt size tensor requires subtracting 1 to ignore the batch dim
            size_idx = wrapped_i - 1
            self.assertEqual(
                nt_unsqueezed._nested_tensor_size()[:, size_idx],
                torch.ones(2, dtype=torch.long),
            )
            unsqueezed_stride = nt_unsqueezed._nested_tensor_strides()[:, size_idx]
            if i == nt.ndim or i == -1:
                self.assertEqual(unsqueezed_stride, torch.ones(2, dtype=torch.long))
            else:
                stride_col_after = nt_strides[:, size_idx]
                size_col_after = nt_sizes[:, size_idx]
                self.assertEqual(unsqueezed_stride, stride_col_after * size_col_after)
            nt_squeezed = nt_unsqueezed.squeeze(i)
            self.assertEqual(nt_squeezed, nt)
            self.assertEqual(nt_squeezed._nested_tensor_size(), nt_sizes)
            self.assertEqual(nt_squeezed._nested_tensor_strides(), nt_strides)

    @dtypes(torch.float, torch.float16, torch.double)
    def test_transpose_inference_mode_interaction(self, device, dtype):
        nt = random_nt(device, dtype, 4, (4, 4))
        # Construct in default mode and transpose while in inference mode
        with torch.inference_mode():
            ntT = nt.transpose(-1, -2)
            ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
            pt = torch.nested.to_padded_tensor(nt, 0.0)
            ptT = pt.transpose(-1, -2)
            self.assertEqual(ptT, ptT_from_ntT)

        # Construct and transpose while in inference mode
        with torch.inference_mode():
            nt = random_nt(device, dtype, 4, (4, 4))
            ntT = nt.transpose(-1, -2)
            ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
            pt = torch.nested.to_padded_tensor(nt, 0.0)
            ptT = pt.transpose(-1, -2)
            self.assertEqual(ptT, ptT_from_ntT)

    @dtypes(torch.float, torch.float16, torch.double)
    def test_view(self, device, dtype):
        nt = random_nt(device, dtype, 4, (4, 4))
        # error case: empty shape
        self.assertRaisesRegex(
            RuntimeError,
            r"shape '\[\]' is invalid for a nested tensor",
            lambda: nt.view(()),
        )
        # error case: empty nested tensor
        nt_empty = torch.nested.nested_tensor([])
        self.assertRaisesRegex(
            RuntimeError,
            "empty nested tensor cannot be reshaped",
            lambda: nt_empty.view(-1),
        )
        # error case: -1 for batch size
        self.assertRaisesRegex(
            RuntimeError,
            r"view: For now nested view cannot change or infer the implicit batch dimension",
            lambda: nt.view(-1, 2, 3),
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"shape '\[.*\]' is invalid for input of size [0-9]+",
            lambda: nt.view(4, 2, 3),
        )
        # normal case
        x0 = torch.randn((2, 20), device=device, dtype=dtype)
        x1 = torch.randn((3, 20), device=device, dtype=dtype)
        nt = torch.nested.nested_tensor([x0, x1])
        pt = torch.nested.to_padded_tensor(nt, 0.0)
        # error case, trying to reshape batch dim to a legit shape
        self.assertRaisesRegex(
            RuntimeError,
            r"For now nested view cannot change or infer the implicit batch dimension",
            lambda: nt.transpose(-1, -2).view(40, -1),
        )
        # inherit only the ragged dimension
        # (2, 20) -> (2, 5, 4)
        # (3, 20) -> (3, 5, 4)
        nt1 = nt.view(2, -1, 5, 4)
        # (2, 3, 20) -> (2, 3, 5, 4) -> (2, 4, 5, 4)
        pt1 = pt.view(2, -1, 5, 4)
        self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1)

        # more than one -1 (even for "old" dims), should fail
        # this attempts to do (2, (2, 3), 5, 4) -> (2, (2, 3), 5, 2, 2)
        # but we ban "inherit old behavior" for >1 dimension
        self.assertRaisesRegex(
            RuntimeError,
            r"only one dimension can be inferred",
            lambda: nt1.view(2, -1, -1, 2, 2),
        )

    @dtypes(torch.float, torch.float16, torch.double)
    def test_view_inference_mode_interaction(self, device, dtype):
        # Construct in default mode and view while in inference mode
        nt = torch.nested.nested_tensor(
            [torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype
        )
        with torch.inference_mode():
            ntT = nt.view(2, -1, 4, 5)
            ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
            pt = torch.nested.to_padded_tensor(nt, 0.0)
            ptT = pt.view(2, -1, 4, 5)
            self.assertEqual(ptT, ptT_from_ntT)
        # Construct and view while in inference mode
        with torch.inference_mode():
            nt = torch.nested.nested_tensor(
                [torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype
            )
            ntT = nt.view(2, -1, 4, 5)
            ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
            pt = torch.nested.to_padded_tensor(nt, 0.0)
            ptT = pt.view(2, -1, 4, 5)
            self.assertEqual(ptT, ptT_from_ntT)

    @dtypes(torch.float, torch.float16, torch.double)
    def test_reshape(self, device, dtype):
        nt = random_nt(device, dtype, 4, (4, 4))
        # error case: empty shape
        self.assertRaisesRegex(
            RuntimeError,
            r"shape '\[\]' is invalid for a nested tensor",
            lambda: nt.reshape(()),
        )
        # error case: empty nested tensor
        nt_empty = torch.nested.nested_tensor([])
        self.assertRaisesRegex(
            RuntimeError,
            "empty nested tensor cannot be reshaped",
            lambda: nt_empty.reshape(-1),
        )
        # error case: -1 for batch size
        self.assertRaisesRegex(
            RuntimeError,
            r"reshape: For now nested reshape cannot change or infer the implicit batch dimension",
            lambda: nt.reshape(-1, 2, 3),
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"shape '\[.*\]' is invalid for input of size [0-9]+",
            lambda: nt.reshape(4, 2, 3),
        )
        # normal case
        x0 = torch.randn((2, 20), device=device, dtype=dtype)
        x1 = torch.randn((3, 20), device=device, dtype=dtype)
        nt = torch.nested.nested_tensor([x0, x1])  # (2, (2, 3), 20)
        pt = torch.nested.to_padded_tensor(nt, 0.0)
        # error case, trying to reshape batch dim to a legit shape
        self.assertRaisesRegex(
            RuntimeError,
            r"reshape: For now nested reshape cannot change or infer the implicit batch dimension",
            lambda: nt.transpose(-1, -2).reshape(40, -1),
        )
        # inherit only the ragged dimension
        # (2, 20) -> (2, 5, 4)
        # (3, 20) -> (3, 5, 4)
        nt1 = nt.reshape(2, -1, 5, 4)
        # (2, 3, 20) -> (2, 3, 5, 4) -> (2, 4, 5, 4)
        pt1 = pt.reshape(2, -1, 5, 4)
        self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1)

        # more than one -1 (even for "old" dims), should fail
        # this attempts to do (2, (2, 3), 5, 4) -> (2, (2, 3), 5, 2, 2)
        # but we ban "inherit old behavior" for >1 dimension
        self.assertRaisesRegex(
            RuntimeError,
            r"only one dimension can be inferred",
            lambda: nt1.reshape(2, -1, -1, 2, 2),
        )

    def test_nested_masked_select(self, device):
        t = torch.randn([3, 3], device=device)
        mask = torch.tensor([False], device=device)

        njt = torch.nested.masked_select(t, mask)
        self.assertEqual(njt.values(), torch.tensor([], device=device))
        self.assertEqual(njt.offsets(), torch.tensor([0, 0, 0, 0], device=device))

        mask = torch.tensor([[False], [False], [True]], device=device)
        njt = torch.nested.masked_select(t, mask)
        self.assertEqual(njt.values(), t[-1], atol=0.1, rtol=0.1)
        self.assertEqual(njt.offsets(), torch.tensor([0, 0, 0, 3], device=device))

        mask = torch.tensor(
            [[False, False, True], [True, False, True], [False, False, True]],
            device=device,
        )
        njt = torch.nested.masked_select(t, mask)
        self.assertEqual(njt.values(), t.masked_select(mask))
        self.assertEqual(njt.offsets(), torch.tensor([0, 1, 3, 4], device=device))

        t = torch.randn([2, 3, 3, 1], device=device)
        mask = torch.tensor(
            [
                [
                    [[True], [False], [True]],
                    [[True], [False], [True]],
                    [[True], [False], [True]],
                ],
                [
                    [[False], [True], [True]],
                    [[False], [True], [True]],
                    [[True], [True], [True]],
                ],
            ],
            device=device,
        )
        njt = torch.nested.masked_select(t, mask)
        self.assertEqual(njt.values(), t.masked_select(mask))
        self.assertEqual(
            njt.offsets(),
            torch.tensor(
                [0, 1, 1, 2, 3, 3, 4, 5, 5, 6, 6, 7, 8, 8, 9, 10, 11, 12, 13],
                device=device,
            ),
        )

    @dtypes(torch.float, torch.float16, torch.double)
    def test_narrow(self, device, dtype):
        nt = random_nt_from_dims([5, None, None, None], device=device, dtype=dtype)

        # narrow on dim=0 from start to end
        bounds = [(0, 5), (0, 3), (1, 2), (1, 5), (2, 4)]
        for start, end in bounds:
            length = end - start
            narrowed = nt.narrow(dim=0, start=start, length=length)
            # ensure output is a view
            self.assertTrue(narrowed._base is nt)
            for nc, c in zip(narrowed.unbind(), nt.unbind()[start:end]):
                self.assertEqual(nc, c)

        # dim != 0 is not supported
        for dim in range(1, nt.dim()):
            with self.assertRaisesRegex(
                RuntimeError, "only dim=0 supported for nested tensors"
            ):
                nt.narrow(dim=dim, start=0, length=1)

        # error case: non-contiguous NT
        _, nt_noncont = random_nt_noncontiguous_pair((2, 3, 4))
        with self.assertRaisesRegex(
            RuntimeError, "only contiguous nested tensors supported"
        ):
            nt_noncont.narrow(dim=0, start=0, length=1)

    @parametrize("input_dim", [3, 4])
    def test_scaled_dot_product_attention(self, device, input_dim):
        def rand_tensor(*shape):
            return torch.randn(shape, device=device)

        E = 8
        if input_dim == 3:
            # Shape: (N, L, E); ragged L
            query = torch.nested.nested_tensor(
                [rand_tensor(2, E), rand_tensor(3, E), rand_tensor(4, E)]
            )

            # Shape: (N, S, E); ragged S
            key = torch.nested.nested_tensor(
                [rand_tensor(3, E), rand_tensor(4, E), rand_tensor(5, E)]
            )
            value = torch.nested.nested_tensor(
                [rand_tensor(3, E), rand_tensor(4, E), rand_tensor(5, E)]
            )
        elif input_dim == 4:
            # In the 4D case both L and S are ragged
            # Shape: (N, N', L, E); ragged N' and L
            query = torch.nested.nested_tensor(
                [rand_tensor(2, 2, E), rand_tensor(3, 3, E), rand_tensor(4, 4, E)]
            )
            # Shape: (N, N', S, E); ragged N' and S
            key = torch.nested.nested_tensor(
                [rand_tensor(2, 3, E), rand_tensor(3, 4, E), rand_tensor(4, 5, E)]
            )
            value = torch.nested.nested_tensor(
                [rand_tensor(2, 3, E), rand_tensor(3, 4, E), rand_tensor(4, 5, E)]
            )
        else:
            self.fail(f"Invalid input_dim {input_dim} encountered in SDP test")

        def rand_mask(size):
            return torch.randint(0, 2, size=size, dtype=torch.bool, device=device)

        # Shape: (N, L, S); ragged L and S matching above
        attn_mask = torch.nested.nested_tensor(
            [rand_mask((2, 3)), rand_mask((3, 4)), rand_mask((4, 5))]
        )

        dropout_p = 0.0  # no dropout for reproducibility

        # Success case: no attn_mask set and is_causal=False.
        actual = torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=None, is_causal=False, dropout_p=dropout_p
        )

        expected_outputs = []
        for q, k, v in zip(query.unbind(), key.unbind(), value.unbind()):
            output = torch.nn.functional.scaled_dot_product_attention(
                q.unsqueeze(0),
                k.unsqueeze(0),
                v.unsqueeze(0),
                attn_mask=None,
                dropout_p=dropout_p,
            )
            expected_outputs.append(output.squeeze(0))
        expected_output_nested = torch.nested.nested_tensor(expected_outputs)
        self.assertEqual(actual, expected_output_nested)

        # Error case: explicit attn_mask set.
        with self.assertRaisesRegex(
            RuntimeError, "not supported when an explicit attn_mask is set"
        ):
            torch.nn.functional.scaled_dot_product_attention(
                query, key, value, attn_mask=attn_mask, dropout_p=dropout_p
            )

        # Error case: is_causal=True.
        with self.assertRaisesRegex(RuntimeError, "not supported when is_causal=True"):
            torch.nn.functional.scaled_dot_product_attention(
                query, key, value, dropout_p=dropout_p, is_causal=True
            )

    @dtypes(torch.float, torch.float16, torch.double)
    def test_empty_like(self, device, dtype):
        ntensors = 4
        nt = random_nt(device, dtype, ntensors, (4, 4))

        # Create empty on same device as original nested tensor
        nt_empty = torch.empty_like(nt)
        assert nt.is_same_size(nt_empty)
        self.assertEqual(nt.dtype, nt_empty.dtype)
        self.assertEqual(nt.device, nt_empty.device)
        self.assertEqual(nt.layout, nt_empty.layout)

        if torch.cuda.is_available():
            if device == "cpu":
                nt_cuda = torch.empty_like(nt, device="cuda")
                self.assertEqual(torch.device("cuda").type, nt_cuda.device.type)
            else:
                nt_cpu = torch.empty_like(nt, device="cpu")
                self.assertEqual(torch.device("cpu").type, nt_cpu.device.type)

        # Check changing dtype of empty_like nested tensor output
        dtype_set = {torch.float, torch.float16, torch.double}
        for other_dtype in dtype_set - {dtype}:
            nt_empty_other_dtype = torch.empty_like(nt, dtype=other_dtype)
            self.assertEqual(nt.dtype, dtype)
            self.assertEqual(nt_empty_other_dtype.dtype, other_dtype)
            self.assertEqual(nt.device, nt_empty.device)
            self.assertEqual(nt.layout, nt_empty.layout)

        # Create tensor for autograd
        nt_empty_req_grad = torch.empty_like(nt, requires_grad=True)
        self.assertEqual(nt_empty_req_grad.requires_grad, True)

        # Test noncontiguous tensor does not fail to copy
        nt_cont, nt_noncont = random_nt_noncontiguous_pair((2, 3, 6, 7))
        nt_empty = torch.empty_like(nt_cont)
        assert nt_cont.is_same_size(nt_empty)
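        # Illustrative note (added, not in the original test): empty_like also accepts a
        # noncontiguous nested tensor; only torch.contiguous_format is accepted as an
        # explicit memory_format below, while the channels_last variants raise.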
        nt_empty_non_contig = torch.empty_like(nt_noncont)
        assert nt_noncont.is_same_size(nt_empty_non_contig)

        # Test the contiguous memory format option
        nt_empty_contig = torch.empty_like(
            nt_cont, memory_format=torch.contiguous_format
        )
        assert nt_cont.is_same_size(nt_empty_contig)
        assert nt_empty_contig.is_contiguous()

        nt_empty_non_contig = torch.empty_like(
            nt_noncont, memory_format=torch.contiguous_format
        )
        assert nt_noncont.is_same_size(nt_empty_non_contig)
        assert nt_empty_non_contig.is_contiguous()

        # Test other memory formats fail
        self.assertRaises(
            RuntimeError,
            lambda: torch.empty_like(nt_cont, memory_format=torch.channels_last),
        )
        self.assertRaises(
            RuntimeError,
            lambda: torch.empty_like(nt_noncont, memory_format=torch.channels_last),
        )
        self.assertRaises(
            RuntimeError,
            lambda: torch.empty_like(nt_cont, memory_format=torch.channels_last_3d),
        )
        self.assertRaises(
            RuntimeError,
            lambda: torch.empty_like(nt_noncont, memory_format=torch.channels_last_3d),
        )


@markDynamoStrictTest
class TestNestedTensorAutograd(NestedTensorTestCase):
    # Note [Gradcheck args check_batched_grad=False] the common_utils testing version of gradcheck
    # includes the default parameters used for testing ops with gradcheck.
    # However, nested tensor does not support the stack op, therefore we turn it off for these tests.
    def _create_leaf_nested_tensor_from_list(self, tensor_device, requires_grad=False):
        return torch.nested.nested_tensor(
            [torch.randn(1, 2), torch.randn(7, 8)],
            requires_grad=requires_grad,
            device=tensor_device,
        )

    def _create_nested_tensor_from_list(self, tensor_device, requires_grad=False):
        return torch.nested.as_nested_tensor(
            [
                torch.randn(1, 2, requires_grad=requires_grad),
                torch.randn(7, 8, requires_grad=requires_grad),
            ],
            device=tensor_device,
        )

    def _create_nested_tensor_from_mask(self, tensor_device, requires_grad=False):
        data = torch.randn(2, 3, 4, requires_grad=requires_grad, device=tensor_device)
        mask = torch.ones_like(data[:, :, 0]).bool()
        return torch._nested_tensor_from_mask(data, mask)

    def test_as_nested_tensor_propagates_gradients(self, device):
        a = torch.arange(3, dtype=torch.float, device=device)
        b = torch.arange(5, dtype=torch.float, device=device)
        nt = torch.nested.as_nested_tensor([a, b])
        # tensors with requires_grad=False are leaves
        self.assertTrue(nt.is_leaf)
        self.assertTrue(not nt.requires_grad)

        a = torch.arange(3, dtype=torch.float, requires_grad=True, device=device)
        b = torch.arange(5, dtype=torch.float, requires_grad=True, device=device)
        nt2 = torch.nested.as_nested_tensor([a, b])
        fake_grad = torch.nested.nested_tensor(
            [torch.ones_like(a), torch.zeros_like(b)], device=device
        )
        nt2.backward(fake_grad)
        self.assertEqual(a.grad, fake_grad[0])
        self.assertEqual(b.grad, fake_grad[1])

    def test_nested_tensor_generates_leaf(self, device):
        a = torch.arange(3, dtype=torch.float, requires_grad=True, device=device)
        b = torch.arange(5, dtype=torch.float, requires_grad=True, device=device)

        nt = torch.nested.nested_tensor([a, b], requires_grad=False)
        self.assertTrue(nt.is_leaf)
        self.assertTrue(not nt.requires_grad)

        nt2 = torch.nested.nested_tensor([a, b], requires_grad=True)
        self.assertTrue(nt2.is_leaf)
        self.assertTrue(nt2.requires_grad)

        fake_grad = torch.nested.nested_tensor(
            [torch.ones_like(a), torch.zeros_like(b)], device=device
        )
        nt2.backward(fake_grad)
        self.assertEqual(nt2.grad, fake_grad)
        self.assertEqual(a.grad, None)
        self.assertEqual(b.grad, None)

    def test_set_requires_grad_from_list(self, device):
        nt = self._create_nested_tensor_from_list(device)
        nt.requires_grad_()
        assert nt.requires_grad

    def test_set_requires_grad_from_mask(self, device):
        nt = self._create_nested_tensor_from_mask(device)
        nt.requires_grad_()
        assert nt.requires_grad

    def test_backward_for_add_op(self, device):
        nt_1 = self._create_nested_tensor_from_mask(device)
        nt_2 = self._create_nested_tensor_from_mask(device)

        nt_1.requires_grad_()
        c = nt_1 + nt_2

        assert nt_1.requires_grad
        assert c.requires_grad
        grad_output = self._create_nested_tensor_from_mask(device)
        c.backward(grad_output)

        # Grad check doesn't work with nested yet.
        # d/dnt_1 (nt_1 + nt_2) = 1 * grad_output
        self.assertEqual(nt_1.grad, grad_output)

    def test_backward_for_sub_op(self, device):
        nt_1 = self._create_nested_tensor_from_mask(device)
        nt_2 = self._create_nested_tensor_from_mask(device)

        nt_1.requires_grad_()
        nt_2.requires_grad_()
        c = nt_1 - nt_2

        assert nt_1.requires_grad
        assert nt_2.requires_grad
        assert c.requires_grad
        grad_output = self._create_nested_tensor_from_mask(device)
        c.backward(grad_output)

        self.assertEqual(nt_1.grad, grad_output)
        self.assertEqual(nt_2.grad, -1 * grad_output)

    def test_backward_sub_strided(self, device):
        a = torch.nested.nested_tensor(
            [torch.randn(9, 2, 4), torch.randn(12, 2, 4)],
            requires_grad=True,
            device=device,
        )
        b = torch.nested.nested_tensor(
            [torch.randn(9, 4, 2), torch.randn(12, 4, 2)],
            requires_grad=True,
            device=device,
        )
        c = a - b.transpose(-1, -2)
        grad_output = c.clone()
        c.backward(grad_output)
        self.assertEqual(a.grad, grad_output)
        self.assertEqual(b.grad, -1 * grad_output.transpose(-1, -2))

    def test_backward_add_strided(self, device):
        a = torch.nested.nested_tensor(
            [torch.randn(9, 2, 4), torch.randn(12, 2, 4)],
            requires_grad=True,
            device=device,
        )
        b = torch.nested.nested_tensor(
            [torch.randn(9, 4, 2), torch.randn(12, 4, 2)],
            requires_grad=True,
            device=device,
        )
        c = a + b.transpose(-1, -2)
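        # Illustrative note (added, not in the original test): for c = a + b.transpose(-1, -2),
        # dL/da = grad_output and dL/db = grad_output.transpose(-1, -2), which is exactly
        # what the two assertions below check per nested component.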
        grad_output = c.clone()
        c.backward(grad_output)
        self.assertEqual(a.grad, grad_output)
        self.assertEqual(b.grad, grad_output.transpose(-1, -2))

    # Test Factory Functions
    def test_nested_tensor_to_padded_tensor(self, device):
        for padding_val in [0, 1]:
            nt = self._create_leaf_nested_tensor_from_list(
                tensor_device=device, requires_grad=True
            )

            out = torch.nested.to_padded_tensor(nt, padding_val)
            grad_output = torch.ones(out.shape, device=device)
            out.backward(grad_output)

            self.assertEqual(
                nt.grad,
                torch.nested.nested_tensor(
                    [torch.ones(1, 2), torch.ones(7, 8)], device=device
                ),
            )

    def test_nested_tensor_from_mask_and_to_padded(self, device):
        N, L, D = 2, 4, 4
        mask = torch.ones(N, L, device=device)
        for i in range(1, N):
            end = torch.randint(1, L - 1, (1,), device=device)
            mask[i, end:] = 0

        mask[0, :] = 1
        mask = mask.bool()

        data = torch.randn(
            N, L, D, requires_grad=True, dtype=torch.float64, device=device
        )

        def grad_test_func(inpt):
            nt = torch._nested_tensor_from_mask(inpt, mask)
            # This implicitly tests to_padded_tensor grads
            return torch.nested.to_padded_tensor(nt, 0)

        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    def test_nested_tensor_from_padded(self, device):
        nested_size = torch.tensor([[1, 2], [2, 2]])
        padded_tensor = torch.randn(2, 2, 2, dtype=torch.float64, device=device)
        padded_tensor[0, 1, :] = 0
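        # Illustrative note (added, not in the original test): nested_size [[1, 2], [2, 2]]
        # says the first component only occupies row 0 of its (2, 2) padded slice, so row
        # [0, 1, :] is zeroed out above to act as padding before grads are checked.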
        padded_tensor.requires_grad_()

        def grad_test_func(tensor, nested_size):
            nt = torch._nested_from_padded(
                tensor, nested_size, fuse_transform_0213=False
            )
            # This implicitly tests to_padded_tensor grads
            return torch.nested.to_padded_tensor(nt, 0)

        data = (padded_tensor, nested_size)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    def test_nested_tensor_from_padded_fused(self, device):
        nested_size = torch.tensor([[1, 8], [2, 8]])
        padded_tensor = torch.randn(2, 2, 2, 4, dtype=torch.float64, device=device)
        padded_tensor[0, 1, :] = 0
        padded_tensor.requires_grad_()

        def grad_test_func(tensor, nested_size):
            nt = torch._nested_from_padded(
                tensor, nested_size, fuse_transform_0213=True
            )
            # This implicitly tests to_padded_tensor grads
            return torch.nested.to_padded_tensor(nt, 0)

        data = (padded_tensor, nested_size)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    def test_nested_tensor_from_list(self, device):
        a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(10, 2, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c):
            c = torch.nested.as_nested_tensor([a, b, c])
            # This implicitly tests to_padded_tensor grads
            return torch.nested.to_padded_tensor(c, 0)

        data = (a, b, c)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    @parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name)
    def test_dropout_backward(self, layout):
        if layout == torch.jagged:
            nt = torch.nested.nested_tensor(
                [torch.randn((2, 5)), torch.randn((3, 5))],
                requires_grad=True,
                layout=layout,
            )
        else:
            nt = torch.nested.nested_tensor(
                [torch.randn((2, 5)), torch.randn((3, 4))],
                requires_grad=True,
                layout=layout,
            )
        p = 0.2
        y = torch.nn.functional.dropout(nt, p)
        y.backward(nt.clone().detach())
        self.assertEqual(nt.grad, y)

    def test_nested_tensor_bmm_gradcheck(self, device):
        a = torch.randn(2, 6, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(3, 6, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(6, 4, requires_grad=True, dtype=torch.float64, device=device)
        d = torch.randn(6, 5, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c, d):
            nt0 = torch.nested.as_nested_tensor([a, b])
            nt1 = torch.nested.as_nested_tensor([c, d])
            result = nt0.bmm(nt1)
            return torch.nested.to_padded_tensor(result, 0.0)

        data = (a, b, c, d)
        assert torch.autograd.gradcheck(grad_test_func, inputs=data)

    def test_nested_tensor_bmm_backward(self, device):
        nt0 = torch.nested.nested_tensor(
            [torch.randn((2, 6)), torch.randn((3, 6))],
            requires_grad=True,
            device=device,
        )
        nt1 = torch.nested.nested_tensor(
            [torch.randn((6, 4)), torch.randn((6, 5))],
            requires_grad=True,
            device=device,
        )
        with torch.no_grad():
            pt0 = torch.nested.to_padded_tensor(nt0, 0.0).requires_grad_(True)
            pt1 = torch.nested.to_padded_tensor(nt1, 0.0).requires_grad_(True)

        ynt = nt0.bmm(nt1)
        ypt = pt0.bmm(pt1)
        ynt.backward(ynt.clone())
        ypt.backward(ypt.clone())

        self.assertEqual(torch.nested.to_padded_tensor(nt0.grad, 0.0), pt0.grad)
        self.assertEqual(torch.nested.to_padded_tensor(nt1.grad, 0.0), pt1.grad)

    def test_nested_tensor_matmul_gradcheck(self, device):
        a = torch.randn(2, 6, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(3, 6, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(6, 4, requires_grad=True, dtype=torch.float64, device=device)
        d = torch.randn(6, 5, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c, d):
            nt0 = torch.nested.as_nested_tensor([a, b])
            nt1 = torch.nested.as_nested_tensor([c, d])
            result = torch.matmul(nt0, nt1)
            return torch.nested.to_padded_tensor(result, 0.0)

        data = (a, b, c, d)
        assert torch.autograd.gradcheck(grad_test_func, inputs=data)

    def test_nested_tensor_matmul_backward(self, device):
        nt0 = torch.nested.nested_tensor(
            [torch.randn((7, 2, 6)), torch.randn((7, 3, 6))],
            requires_grad=True,
            device=device,
        )
        nt1 = torch.nested.nested_tensor(
            [torch.randn((7, 6, 4)), torch.randn((7, 6, 5))],
            requires_grad=True,
            device=device,
        )
        with torch.no_grad():
            pt0 = torch.nested.to_padded_tensor(nt0, 0.0).requires_grad_(True)
            pt1 = torch.nested.to_padded_tensor(nt1, 0.0).requires_grad_(True)

        ynt = torch.matmul(nt0, nt1)
        ypt = torch.matmul(pt0, pt1)
        ynt.backward(ynt.clone())
        ypt.backward(ypt.clone())

        self.assertEqual(torch.nested.to_padded_tensor(nt0.grad, 0.0), pt0.grad)
        self.assertEqual(torch.nested.to_padded_tensor(nt1.grad, 0.0), pt1.grad)

    def test_nested_tensor_transpose_gradcheck(self, device):
        a = torch.randn(2, 5, requires_grad=True, device=device)
        b = torch.randn(3, 4, requires_grad=True, device=device)

        def grad_test_func(a, b):
            nt = torch.nested.as_nested_tensor([a, b])
            result = nt.transpose(-2, -1).transpose(-2, -1)
            return torch.nested.to_padded_tensor(result, 0.0)

        data = (a, b)
        assert torch.autograd.gradcheck(grad_test_func, inputs=data, eps=1e-3)

    def test_nested_tensor_transpose_backward(self, device):
        nt = torch.nested.nested_tensor(
            [torch.randn((2, 5)), torch.randn((3, 4))],
            requires_grad=True,
            device=device,
        )
        with torch.no_grad():
            pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)

        ynt = nt.transpose(-2, -1)
        ypt = pt.transpose(-2, -1)
        ynt.backward(ynt.clone())
        ypt.backward(ypt.clone())

        self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)

    def test_nested_tensor_reshape_gradcheck(self, device):
        a = torch.randn(2, 6, requires_grad=True, device=device)
        b = torch.randn(3, 6, requires_grad=True, device=device)

        def grad_test_func(a, b):
            nt = torch.nested.as_nested_tensor([a, b])
            result = nt.reshape(2, -1, 2, 3)
            return torch.nested.to_padded_tensor(result, 0.0)

        data = (a, b)
        assert torch.autograd.gradcheck(grad_test_func, inputs=data, eps=1e-3)

    def test_nested_tensor_reshape_backward(self):
        nt = torch.nested.nested_tensor(
            [torch.randn((2, 6)), torch.randn((3, 6))], requires_grad=True
        )
        with torch.no_grad():
            pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)

        ynt = nt.reshape(2, -1, 2, 3)
        ypt = pt.reshape(2, -1, 2, 3)
        ynt.backward(ynt.clone())
        ypt.backward(ypt.clone())

        self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)

    def test_nested_tensor_squeeze_backward(self, device):
        nt = torch.nested.nested_tensor(
            [torch.randn((2, 6, 1)), torch.randn((3, 6, 1))],
            requires_grad=True,
            device=device,
        )
        with torch.no_grad():
            pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)

        ynt = nt.squeeze(-1)
        ypt = pt.squeeze(-1)
        ynt.backward(ynt.clone())
        ypt.backward(ypt.clone())

        self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)

    def test_nested_tensor_squeeze_gradcheck(self, device):
        a = torch.randn(
            (2, 6, 1), dtype=torch.float64, requires_grad=True, device=device
        )
        b = torch.randn(
            (3, 6, 1), dtype=torch.float64, requires_grad=True, device=device
        )

        def grad_test_func(a, b):
            nt = torch.nested.as_nested_tensor([a, b])
            result = nt.squeeze(-1)
            return torch.nested.to_padded_tensor(result, 0.0)

        assert torch.autograd.gradcheck(grad_test_func, inputs=(a, b), eps=1e-3)

    def test_nested_tensor_unsqueeze_backward(self, device):
        nt = torch.nested.nested_tensor(
            [torch.randn((2, 6)), torch.randn((3, 6))],
            requires_grad=True,
            device=device,
        )
        with torch.no_grad():
            pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)

        ynt = nt.unsqueeze(2)
        ypt = pt.unsqueeze(2)
        ynt.backward(ynt.clone())
        ypt.backward(ypt.clone())

        self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)

    def test_nested_tensor_unsqueeze_gradcheck(self, device):
        a = torch.randn((2, 6), dtype=torch.float64, requires_grad=True, device=device)
        b = torch.randn((3, 6), dtype=torch.float64, requires_grad=True, device=device)

        def grad_test_func(a, b):
            nt = torch.nested.as_nested_tensor([a, b])
            result = nt.unsqueeze(-1)
            return torch.nested.to_padded_tensor(result, 0.0)

        assert torch.autograd.gradcheck(grad_test_func, inputs=(a, b), eps=1e-3)

    def test_nested_tensor_linear(self, device):
        a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device)

        weight = torch.randn(
            2, 2, requires_grad=True, dtype=torch.float64, device=device
        )
        bias = torch.randn(2, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c, weight, bias=None):
            nt = torch.nested.as_nested_tensor([a, b, c])
            # This implicitly tests to_padded_tensor grads
            d = torch.functional.F.linear(nt, weight, bias)
            return torch.nested.to_padded_tensor(d, 0)

        data = (a, b, c, weight, bias)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

        # Test linear with no bias added
        data = (a, b, c, weight)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    def test_nested_tensor_linear_plus_transpose(self, device):
        a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device)

        weight = torch.randn(
            2, 2, requires_grad=True, dtype=torch.float64, device=device
        )
        bias = torch.randn(2, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c, weight, bias=None):
            nt = torch.nested.as_nested_tensor([a, b, c])
            # This implicitly tests to_padded_tensor grads
            d = torch.functional.F.linear(nt, weight, bias)
            d = d.transpose(-1, -2).contiguous()
            return torch.nested.to_padded_tensor(d, 0)

        data = (a, b, c, weight, bias)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

        # Test linear with no bias added
        data = (a, b, c, weight)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    def test_nested_tensor_softmax(self, device):
        a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c, dim):
            nt = torch.nested.as_nested_tensor([a, b, c])
            # This implicitly tests to_padded_tensor grads
            d = torch.functional.F.softmax(nt, dim=dim)
            return torch.nested.to_padded_tensor(d, 0)

        # softmax over last dim
        data = (a, b, c, -1)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

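    # Unlike the gradcheck-based tests above, the next test drives backward
    # directly and only checks which leaves receive gradients: weight and bias
    # (which require grad) should end up with grads, while the component tensors
    # created with requires_grad=False should not.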
    def test_nested_tensor_linear_backward(self, device):
        a = torch.randn(1, 2, requires_grad=False, device=device)
        b = torch.randn(2, 2, requires_grad=False, device=device)
        c = torch.randn(3, 2, requires_grad=False, device=device)

        weight = torch.randn(2, 2, requires_grad=True, device=device)
        bias = torch.randn(2, requires_grad=True, device=device)
        nt = torch.nested.as_nested_tensor([a, b, c], device=device)

        out = torch.functional.F.linear(nt, weight, bias)

        out.backward(out.clone())

        assert weight.grad is not None
        assert bias.grad is not None

        assert a.grad is None
        assert b.grad is None
        assert c.grad is None

    def test_values_grad_with_broadcast(self, device):
        a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c):
            nt = torch.nested.as_nested_tensor([a, b, c])
            buffer = nt.values()
            return buffer.sum()

        data = (a, b, c)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    def test_to_buffer_series_ops_grad_with_broadcast(self, device):
        a = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c):
            nt = torch.nested.as_nested_tensor([a, b, c])
            buffer = nt.values()
            buffer = buffer * 2
            return buffer.exp()

        data = (a, b, c)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    def test_unbind_flow_through(self, device):
        a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c):
            nt = torch.nested.as_nested_tensor([a, b, c])
            ntT = nt.transpose(-1, -2)
            unbound = ntT.unbind()
            d = unbound[0]
            d = torch.pow(d, 2)
            return d

        data = (a, b, c)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    def test_split_with_sizes_flow_through(self, device):
        a = torch.randn(2, 5, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(3, 5, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(4, 5, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c):
            nt = torch.nested.as_nested_tensor([a, b, c])
            splits = nt.split_with_sizes([2, 3], dim=-1)
            unbound = splits[1].unbind()
            d = unbound[0]
            d = torch.pow(d, 2)
            return d

        data = (a, b, c)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    def test_indexing_backward(self, device):
        x0 = torch.randn((2, 5))
        x1 = torch.randn((3, 4))
        nt = torch.nested.nested_tensor([x0, x1], device=device, requires_grad=True)
        self.assertEqual(nt[0], x0)
        self.assertEqual(nt[-1], x1)
        grad_x0 = torch.randn((2, 5), device=device)
        nt[0].backward(grad_x0)
        expected_grad = torch.nested.nested_tensor(
            [grad_x0, torch.zeros((3, 4), device=device)]
        )
        self.assertEqual(nt.grad, expected_grad)

    def test_masked_fill_backward(self, device):
        a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c):
            nt = torch.nested.as_nested_tensor([a, b, c])
            mask = nt.detach().clone().to(bool)
            out = nt.masked_fill(mask, 0)
            out = torch.nested.to_padded_tensor(out, 0)
            return out

        data = (a, b, c)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    def test_gelu_backward(self, device):
        a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c):
            nt = torch.nested.as_nested_tensor([a, b, c])
            nt_gelu = torch.nn.functional.gelu(nt)
            return torch.nested.to_padded_tensor(nt_gelu, 0)

        data = (a, b, c)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    def test_relu_backward(self, device):
        a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c):
            nt = torch.nested.as_nested_tensor([a, b, c])
            nt_relu = torch.nn.functional.relu(nt)
            return torch.nested.to_padded_tensor(nt_relu, 0)

        data = (a, b, c)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    def test_selu_backward(self, device):
        a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c):
            nt = torch.nested.as_nested_tensor([a, b, c])
            nt_relu = torch.nn.functional.silu(nt)
            return torch.nested.to_padded_tensor(nt_relu, 0)

        data = (a, b, c)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    def test_abs_backward(self, device):
        a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c):
            nt = torch.nested.as_nested_tensor([a, b, c])
            nt_abs = torch.abs(nt)
            return torch.nested.to_padded_tensor(nt_abs, 0)

        data = (a, b, c)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    # Previously would error when input NT doesn't require grad
    # NotImplementedError: Cannot access storage of UndefinedTensorImpl
    def test_layer_norm_backward_edge_case(self, device):
        size = 4
        a = torch.randn(
            1, 2, size, requires_grad=False, dtype=torch.float64, device=device
        )
        nt = torch.nested.nested_tensor([a])
        nt_layer_norm = torch.nn.LayerNorm(
            nt.size(-1), device=device, dtype=torch.float64
        )
        out = nt_layer_norm(nt)
        out.backward(out.clone())
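
    # The next test exercises gradient accumulation into a single leaf when the
    # two backward paths (through nt_1 and through its clone nt_2) can produce
    # gradients with different strides; scaled_dot_product_attention on nested
    # tensors is used to trigger that accumulation.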

    def test_accumulate_grad_different_strides(self, device):
        a = torch.rand(1, 4, 2, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.rand(1, 8, 2, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b):
            nt_1 = torch.nested.as_nested_tensor([a, b])
            nt_2 = nt_1.clone()
            out = torch.nn.functional.scaled_dot_product_attention(nt_1, nt_2, nt_2)
            return torch.nested.to_padded_tensor(out, 0)

        data = (a, b)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    # https://github.com/pytorch/pytorch/issues/95562
    @skipIfSlowGradcheckEnv
    @parametrize("size", [1024, 1023, 513, 512, 256, 128, 32, 4, 2])
    def test_layer_norm_backward(self, device, size):
        a = torch.randn(
            1, 2, size, requires_grad=True, dtype=torch.float64, device=device
        )
        b = torch.randn(
            2, 2, size, requires_grad=True, dtype=torch.float64, device=device
        )
        c = torch.randn(
            3, 2, size, requires_grad=True, dtype=torch.float64, device=device
        )

        def grad_test_func(a, b, c):
            nt = torch.nested.as_nested_tensor([a, b, c])
            layer_norm = torch.nn.LayerNorm(
                nt.size(-1), device=device, dtype=torch.float64
            )
            nt_layer_norm = layer_norm(nt)
            return torch.nested.to_padded_tensor(nt_layer_norm, 0)

        data = (a, b, c)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    # https://github.com/pytorch/pytorch/issues/95562
    @skipIfSlowGradcheckEnv
    # Could either mark slow or reduce size
    @parametrize("size", [128, 32, 4, 2])
    def test_layer_norm_backward_5d(self, device, size):
        a = torch.randn(
            4, size, size, 4, requires_grad=True, dtype=torch.float64, device=device
        )
        b = torch.randn(
            7, size, size, 4, requires_grad=True, dtype=torch.float64, device=device
        )
        c = torch.randn(
            10, size, size, 4, requires_grad=True, dtype=torch.float64, device=device
        )

        def grad_test_func(a, b, c):
            nt = torch.nested.as_nested_tensor([a, b, c])
            layer_norm = torch.nn.LayerNorm(
                (size, size, nt.size(-1)), device=device, dtype=torch.float64
            )
            nt_layer_norm = layer_norm(nt)
            return torch.nested.to_padded_tensor(nt_layer_norm, 0)

        data = (a, b, c)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)


# Found in torch/testing/_comparison.py
default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float32: 1e-5}
default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float32: 1.3e-6}


def get_rtol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
    deviation = true_value - computed_value
    deviation = torch.abs(deviation / true_value)
    # Fill in the nans with the default rtol
    torch.nan_to_num_(deviation, nan=default_rtol[computed_value.dtype])
    return deviation.max().item()


def get_atol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
    deviation = true_value - computed_value
    atol = torch.abs(deviation).max().item()
    return atol
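

# A minimal usage sketch (not called by any test here) of how the tolerance
# helpers in this file are typically combined; `ref` and `out` are hypothetical
# stand-ins for a reference result and the value under test, and the fudge
# factor of 4.0 is an arbitrary illustrative choice rather than a value these
# tests mandate.
def _tolerance_usage_sketch(ref: torch.Tensor, out: torch.Tensor) -> None:
    # get_tolerances (defined below) widens the observed deviation to at least
    # the per-dtype defaults above, then scales both tolerances by the fudge
    # factor before they are handed to a comparison such as assert_close.
    atol, rtol = get_tolerances(ref, out, fudge_factor=4.0)
    torch.testing.assert_close(out, ref, atol=atol, rtol=rtol)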
Coastguard Worker """Returns the absolute and relative tolerances for comparing two tensors.""" 3505*da0073e9SAndroid Build Coastguard Worker fudge_factor = fudge_factor if fudge_factor is not None else 1.0 3506*da0073e9SAndroid Build Coastguard Worker atol = get_atol(true_value, computed_value) 3507*da0073e9SAndroid Build Coastguard Worker rtol = get_rtol(true_value, computed_value) 3508*da0073e9SAndroid Build Coastguard Worker 3509*da0073e9SAndroid Build Coastguard Worker atol = fudge_factor * max(atol, default_atol[computed_value.dtype]) 3510*da0073e9SAndroid Build Coastguard Worker rtol = fudge_factor * max(rtol, default_rtol[computed_value.dtype]) 3511*da0073e9SAndroid Build Coastguard Worker # torch.isclose() has weird behavior around see: 3512*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/102400 3513*da0073e9SAndroid Build Coastguard Worker if rtol > 1e30: 3514*da0073e9SAndroid Build Coastguard Worker rtol = default_rtol[computed_value.dtype] 3515*da0073e9SAndroid Build Coastguard Worker return atol, rtol 3516*da0073e9SAndroid Build Coastguard Worker 3517*da0073e9SAndroid Build Coastguard Worker 3518*da0073e9SAndroid Build Coastguard Worker# We can probably parametrizing existing tests instead of having a separate 3519*da0073e9SAndroid Build Coastguard Worker# test class as we begin to support more ops. Also maybe rewrite with OpInfos. 3520*da0073e9SAndroid Build Coastguard Worker@markDynamoStrictTest 3521*da0073e9SAndroid Build Coastguard Workerclass TestNestedTensorSubclass(NestedTensorTestCase): 3522*da0073e9SAndroid Build Coastguard Worker # TODO: consolidate with the below 3523*da0073e9SAndroid Build Coastguard Worker def _get_list_for_jagged_tensor(self, nested_size, device, requires_grad=True): 3524*da0073e9SAndroid Build Coastguard Worker Ds = nested_size[1:] 3525*da0073e9SAndroid Build Coastguard Worker out = [] 3526*da0073e9SAndroid Build Coastguard Worker for s in nested_size[0]: 3527*da0073e9SAndroid Build Coastguard Worker out.append( 3528*da0073e9SAndroid Build Coastguard Worker torch.randn( 3529*da0073e9SAndroid Build Coastguard Worker s, 3530*da0073e9SAndroid Build Coastguard Worker *Ds, 3531*da0073e9SAndroid Build Coastguard Worker requires_grad=requires_grad, 3532*da0073e9SAndroid Build Coastguard Worker device=device, 3533*da0073e9SAndroid Build Coastguard Worker dtype=torch.float64, 3534*da0073e9SAndroid Build Coastguard Worker ) 3535*da0073e9SAndroid Build Coastguard Worker ) 3536*da0073e9SAndroid Build Coastguard Worker return out 3537*da0073e9SAndroid Build Coastguard Worker 3538*da0073e9SAndroid Build Coastguard Worker def _get_example_tensor_lists( 3539*da0073e9SAndroid Build Coastguard Worker self, 3540*da0073e9SAndroid Build Coastguard Worker include_list_of_lists=True, 3541*da0073e9SAndroid Build Coastguard Worker include_requires_grad=True, 3542*da0073e9SAndroid Build Coastguard Worker include_inner_dim_size_1=False, 3543*da0073e9SAndroid Build Coastguard Worker include_2d_tensor=False, 3544*da0073e9SAndroid Build Coastguard Worker ): 3545*da0073e9SAndroid Build Coastguard Worker def _make_tensor( 3546*da0073e9SAndroid Build Coastguard Worker *shape, include_requires_grad=include_requires_grad, requires_grad=True 3547*da0073e9SAndroid Build Coastguard Worker ): 3548*da0073e9SAndroid Build Coastguard Worker return torch.randn( 3549*da0073e9SAndroid Build Coastguard Worker *shape, 3550*da0073e9SAndroid Build Coastguard Worker requires_grad=(requires_grad if include_requires_grad else False), 3551*da0073e9SAndroid 
            )

        # Purposefully introduce mixed requires_grad settings for the components
        # when include_requires_grad=True.
        example_lists = [
            # (B, *, D) with B=4
            [
                _make_tensor(2, 5),
                _make_tensor(3, 5, requires_grad=False),
                _make_tensor(4, 5, requires_grad=False),
                _make_tensor(6, 5),
            ],
            # (B, *, D_0, D_1) with B=5
            [
                _make_tensor(2, 5, 6),
                _make_tensor(3, 5, 6),
                _make_tensor(4, 5, 6, requires_grad=False),
                _make_tensor(5, 5, 6),
                _make_tensor(6, 5, 6),
            ],
            # (B, *, D_0, D_1, D_2) with B=6
            [
                _make_tensor(2, 5, 6, 7),
                _make_tensor(3, 5, 6, 7),
                _make_tensor(4, 5, 6, 7, requires_grad=False),
                _make_tensor(5, 5, 6, 7),
                _make_tensor(6, 5, 6, 7),
                _make_tensor(7, 5, 6, 7),
            ],
        ]

        if include_list_of_lists:
            example_lists.append(
                # (B, *, D) with B=3 in list form
                [
                    _make_tensor(2, 5, requires_grad=False).tolist(),
                    _make_tensor(3, 5).tolist(),
                    _make_tensor(4, 5).tolist(),
                ]
            )

        if include_inner_dim_size_1:
            example_lists.append(
                [
                    _make_tensor(2, 1),
                    _make_tensor(3, 1, requires_grad=False),
                    _make_tensor(4, 1, requires_grad=False),
                    _make_tensor(6, 1),
                ]  # (B, *, 1)
            )
            example_lists.append(
                [
                    _make_tensor(2, 5, 1),
                    _make_tensor(3, 5, 1, requires_grad=False),
                    _make_tensor(4, 5, 1, requires_grad=False),
                    _make_tensor(6, 5, 1),
                ]  # (B, *, 5, 1)
            )

        if include_2d_tensor:
            example_lists.append(
                [
                    _make_tensor(2),
                    _make_tensor(3, requires_grad=False),
                    _make_tensor(4, requires_grad=False),
                    _make_tensor(6),
                ]  # (B, *)
            )

        return example_lists

    def test_tensor_attributes(self, device):
        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
        nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
        _offsets = nt.offsets()

        for op in (
            torch.ops.aten.is_non_overlapping_and_dense.default,
            torch.ops.aten.sym_size.default,
            torch.ops.aten.dim.default,
            torch.ops.aten.numel.default,
            torch.ops.aten.sym_numel.default,
            torch.ops.aten.sym_stride.default,
            torch.ops.aten.sym_storage_offset.default,
        ):
            op(nt)

        with self.assertRaisesRegex(
            RuntimeError, "directly calling torch.ops.aten.size"
        ):
            torch.ops.aten.size.default(nt)

        nested_int = torch.nested._internal.nested_tensor.get_tensor_symint(
            _offsets, coeff=1
        )
        self.assertEqual(nt.size(), (3, nested_int, 3))
        self.assertEqual(nt.shape, (3, nested_int, 3))
        self.assertEqual(nt.dim(), 3)
        self.assertEqual(nt.numel(), 27)

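    # For the jagged layout, the ragged dimension is reported as a symbolic
    # "nested int" tied to the offsets tensor, so the sizes above compare equal
    # to (3, nested_int, 3) rather than to a concrete middle dimension; numel()
    # is still concrete: (2 + 3 + 4) * 3 = 27 elements across the components.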
    @parametrize("nt_dim", [3, 4, 5])
    def test_linear(self, device, nt_dim):
        if nt_dim == 3:
            fixed_shape = (3,)
        elif nt_dim == 4:
            fixed_shape = (4, 3)
        elif nt_dim == 5:
            fixed_shape = (5, 4, 3)

        a = torch.randn(
            2, *fixed_shape, requires_grad=True, dtype=torch.float64, device=device
        )
        b = torch.randn(
            3, *fixed_shape, requires_grad=True, dtype=torch.float64, device=device
        )
        c = torch.randn(
            4, *fixed_shape, requires_grad=True, dtype=torch.float64, device=device
        )
        weight = torch.randn(
            4, 3, requires_grad=True, dtype=torch.float64, device=device
        )

        def grad_test_func(a, b, c, weight):
            nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
            out = torch.nn.functional.linear(nt, weight)
            return out.values()

        gradcheck(grad_test_func, inputs=(a, b, c, weight), check_batched_grad=False)

    def test_unary_pointwise(self, device):
        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)

        def grad_test_func(a, b, c):
            nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
            out = torch.nn.functional.silu(nt.sin().cos())
            return out.values()

        gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)

    def test_unary_pointwise_transposed_inputs(self, device):
        a, b, c = (
            torch.randn(
                i + 2, 5, requires_grad=True, dtype=torch.float64, device=device
            )
            for i in range(3)
        )

        nt = torch.nested.nested_tensor(
            [a.detach(), b.detach(), c.detach()], layout=torch.jagged
        )
        nt_t = nt.transpose(1, 2)
        self.assertFalse(nt_t.is_contiguous())
        out = torch.nn.functional.silu(nt_t.sin().cos())
        self.assertEqual(
            out.is_contiguous(),
            torch.nn.functional.silu(b.transpose(-1, -2).sin().cos()).is_contiguous(),
        )

        self.assertEqual(nt_t.shape, out.shape)

        a, b, c = (
            torch.randn(
                i + 2, 5, requires_grad=True, dtype=torch.float64, device=device
            )
            for i in range(3)
        )

        def grad_test_func(a, b, c):
            nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
            nt_t = nt.transpose(1, 2)
            out = torch.nn.functional.silu(nt_t.sin().cos())
            return out.values()

        gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)

    def test_binary_pointwise(self, device):
        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)

        # Incorrect usage: the shape check will fail if the offsets tensor is not
        # the same exact tensor object
        nt1 = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
        nt2 = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)

        self.assertRaisesRegex(
            RuntimeError,
            "cannot call binary pointwise function .* with inputs of shapes",
            lambda: nt1 * nt2,
        )
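        # The two as_nested_tensor calls above create distinct offsets tensors,
        # so nt1 and nt2 carry different symbolic ragged sizes (nested ints) and
        # the multiplication is rejected as a shape mismatch.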

        # Correct usage: chain the calls using the same offsets tensor object
        def grad_test_func(a, b, c):
            nt1 = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
            # TODO: Switch to public API that takes in (values, offsets) once it exists
            nt2, offsets = jagged_from_list([a, b, c], nt1.offsets())
            out = nt1 * nt2
            return out.values()

        gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)

    def test_binary_pointwise_transposed(self, device):
        a, b, c = (
            torch.randn(i + 2, 5, dtype=torch.float64, device=device) for i in range(3)
        )

        nt1, offsets = jagged_from_list([a, b, c], None)
        nt2, offsets = jagged_from_list([a, b, c], offsets)

        nt1_t = nt1.transpose(1, 2)
        nt2_t = nt2.transpose(1, 2)

        # out = nt1_t * nt2_t
        # self.assertFalse(nt1_t.is_contiguous())
        # self.assertEqual(out.is_contiguous(), (b.transpose(-1, -2) * b.transpose(-1, -2)).is_contiguous())
        # self.assertEqual(out.shape, nt1_t.shape)

        self.assertRaisesRegex(
            RuntimeError,
            "cannot call binary pointwise function mul.Tensor with inputs of shapes",
            lambda: nt1 * nt2_t,
        )

        a, b, c = (
            torch.randn(
                i + 2, 5, requires_grad=True, dtype=torch.float64, device=device
            )
            for i in range(3)
        )

        # Correct usage: chain the calls using the same offsets tensor object
        def grad_test_func(a, b, c):
            nt1, offsets = jagged_from_list([a, b, c], None)
            nt2, offsets = jagged_from_list([a, b, c], offsets)
            nt1_t = nt1.transpose(1, 2)
            nt2_t = nt2.transpose(1, 2)
    def test_binary_pointwise_transposed(self, device):
        a, b, c = (
            torch.randn(i + 2, 5, dtype=torch.float64, device=device) for i in range(3)
        )

        nt1, offsets = jagged_from_list([a, b, c], None)
        nt2, offsets = jagged_from_list([a, b, c], offsets)

        nt1_t = nt1.transpose(1, 2)
        nt2_t = nt2.transpose(1, 2)

        # out = nt1_t * nt2_t
        # self.assertFalse(nt1_t.is_contiguous())
        # self.assertEqual(out.is_contiguous(), (b.transpose(-1, -2) * b.transpose(-1, -2)).is_contiguous())
        # self.assertEqual(out.shape, nt1_t.shape)

        self.assertRaisesRegex(
            RuntimeError,
            "cannot call binary pointwise function mul.Tensor with inputs of shapes",
            lambda: nt1 * nt2_t,
        )

        a, b, c = (
            torch.randn(
                i + 2, 5, requires_grad=True, dtype=torch.float64, device=device
            )
            for i in range(3)
        )

        # Correct usage: chain the calls using the same offsets tensor object
        def grad_test_func(a, b, c):
            nt1, offsets = jagged_from_list([a, b, c], None)
            nt2, offsets = jagged_from_list([a, b, c], offsets)
            nt1_t = nt1.transpose(1, 2)
            nt2_t = nt2.transpose(1, 2)
            out = nt1_t * nt2_t
            return out.values()

        gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)

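    # Note (illustrative): split()/split_with_sizes() on an NJT operate on non-ragged dims
    # only; splitting along the ragged dim (dim=1 here) raises. For example, assuming nt
    # has shape (B, j0, 3):
    #
    #   left, right = torch.split(nt, 2, dim=-1)  # shapes (B, j0, 2) and (B, j0, 1)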
    def test_split(self, device):
        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)

        nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
        out = torch.split(nt, 2, -1)
        self.assertEqual(len(out), 2)
        self.assertEqualIgnoringNestedInts(
            out[0],
            torch.nested.as_nested_tensor(
                [a[:, 0:2], b[:, 0:2], c[:, 0:2]], layout=torch.jagged
            ),
        )
        self.assertEqualIgnoringNestedInts(
            out[1],
            torch.nested.as_nested_tensor(
                [a[:, 2:], b[:, 2:], c[:, 2:]], layout=torch.jagged
            ),
        )

        with self.assertRaisesRegex(
            RuntimeError,
            r"split\(\): not supported for NestedTensor on dim=1",
        ):
            torch.split(nt, 2, 1)

    def test_split_with_sizes(self, device):
        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)

        nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
        out = torch.split(nt, [1, 2], -1)
        self.assertEqual(len(out), 2)
        self.assertEqualIgnoringNestedInts(
            out[0],
            torch.nested.as_nested_tensor(
                [a[:, 0:1], b[:, 0:1], c[:, 0:1]], layout=torch.jagged
            ),
        )
        self.assertEqualIgnoringNestedInts(
            out[1],
            torch.nested.as_nested_tensor(
                [a[:, 1:], b[:, 1:], c[:, 1:]], layout=torch.jagged
            ),
        )
        with self.assertRaisesRegex(
            RuntimeError,
            r"split_with_sizes\(\): not supported for NestedTensor on dim=1",
        ):
            torch.split(nt, [1, 2], 1)

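    # Note (illustrative): softmax over a non-ragged dim of a jagged NJT should match
    # applying softmax to each unbound component, e.g. nt.softmax(dim=2) compares against
    # t.softmax(dim=1) for each t in nt.unbind() (the batch dim disappears after unbind,
    # shifting dims down by one).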
    def test_softmax(self, device):
        nt = random_nt_from_dims(
            [3, None, 5],
            device=device,
            dtype=torch.float32,
            layout=torch.jagged,
            requires_grad=True,
        )

        # operate on dim=2
        output = nt.softmax(dim=2)

        @torch._dynamo.disable
        def _compare_to_ref(nt, output, dim):
            for in_component, out_component in zip(nt.unbind(), output.unbind()):
                self.assertEqual(in_component.softmax(dim=dim), out_component)

        # dim=2 -> dim=1 after unbind
        _compare_to_ref(nt, output, dim=1)

        # operate on dim=-1
        output2 = nt.softmax(dim=-1)
        torch._dynamo.disable(self.assertEqual)(output, output2)
        _compare_to_ref(nt, output2, dim=-1)

        def grad_test_func(a, b):
            nt = torch.nested.as_nested_tensor([a, b], layout=torch.jagged)
            out = nt.softmax(dim=-1)
            return out.values()

        a = torch.rand(4, 5, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.rand(8, 5, requires_grad=True, dtype=torch.float64, device=device)
        gradcheck(grad_test_func, inputs=(a, b), check_batched_grad=False)

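    # Note (illustrative): view()/expand() can't materialize a concrete size for the
    # ragged dim, so -1 is used to inherit it, e.g. nt.view(4, -1, 80) keeps nt.shape[1].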
    def test_views_inherit_ragged_dim(self, device):
        # view
        nt = random_nt_from_dims(
            [4, None, 8, 10], device=device, dtype=torch.float32, layout=torch.jagged
        )
        # inherit ragged dim via -1
        view = nt.view(4, -1, 80)
        self.assertEqual(nt.shape[1], view.shape[1])
        # inherit batch and ragged dims via -1
        view2 = nt.view(-1, -1, 80)
        self.assertEqual(nt.shape[:2], view2.shape[:2])

        # expand
        nt = random_nt_from_dims(
            [3, None, 1], device=device, dtype=torch.float32, layout=torch.jagged
        )
        # inherit batch and ragged dims via -1
        view = nt.expand(-1, -1, 5)
        self.assertEqual(nt.shape[:2], view.shape[:2])

    def test_view_ragged_idx_not_one(self, device):
        nt = random_nt_from_dims(
            [2, None, 20], device=device, dtype=torch.float32, layout=torch.jagged
        )

        view_transposed = nt.transpose(1, 2).view(2, 20, nt.size(1))
        self.assertEqual((2, 20, nt.size(1)), (view_transposed.size()))
        self.assertEqual(view_transposed._base, nt._base)

    def test_unsafe_view(self, device):
        nt = random_nt_from_dims(
            [4, None, 8, 10], device=device, dtype=torch.float32, layout=torch.jagged
        )
        # basic view
        view1 = torch.ops.aten._unsafe_view(nt, (4, -1, 80))
        self.assertEqual((4, nt.size(1), 80), tuple(view1.size()))
        # _unsafe_view differs from view in that the view information is not tracked
        self.assertTrue(view1._base is None)

        # test an unsafe_view when ragged_idx != 1, currently only supports identity view
        nt_t = nt.transpose(1, 2)
        view2 = torch.ops.aten._unsafe_view(nt_t, (4, 8, nt.size(1), 10))
        self.assertEqual((4, 8, nt.size(1), 10), tuple(view2.size()))
        self.assertTrue(view2._base is None)

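    # Note (illustrative): reshape() decomposes to view() when the NJT is contiguous
    # (the result aliases the input) and to a contiguous copy otherwise; both paths
    # should propagate gradients back to the original tensor.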
    @xfailIfTorchDynamo
    @parametrize("requires_grad", [False, True])
    def test_reshape_decomp(self, device, requires_grad):
        # contiguous NT should result in view.
        nt = (
            random_nt_from_dims(
                [3, None, 10],
                device=device,
                dtype=torch.float32,
                layout=torch.jagged,
            )
            .detach()
            .requires_grad_(requires_grad)
        )
        view = nt.reshape(-1, -1, 5, 2)
        self.assertEqual(view.shape[:2], nt.shape[:2])
        self.assertTrue(view._is_view() and view._base is nt)
        # make sure gradients flow back
        if requires_grad:
            view.backward(torch.ones_like(view))
            self.assertEqual(nt.grad, torch.ones_like(nt))

        # non-contiguous NT should result in contiguous copy
        nt = random_nt_from_dims(
            [3, None, 5, 2],
            device=device,
            dtype=torch.float32,
            layout=torch.jagged,
            requires_grad=requires_grad,
        )
        nt_noncontig = nt.transpose(-1, -2)
        self.assertFalse(nt_noncontig.is_contiguous())
        copy = nt_noncontig.reshape(-1, -1, 10)
        self.assertTrue(copy.is_contiguous())
        self.assertEqual(copy.shape[:2], nt.shape[:2])
        # make sure gradients flow back
        if requires_grad:
            copy.backward(torch.ones_like(copy))
            self.assertEqual(nt.grad, torch.ones_like(nt))

    def test_flatten_decomp(self, device):
        nt = random_nt_from_dims(
            [3, None, 5, 2], device=device, dtype=torch.float32, layout=torch.jagged
        )
        flattened = nt.flatten(-2, -1)
        self.assertEqual(flattened.shape, nt.view(3, -1, 10).shape)

        nt = random_nt_from_dims(
            [3, None, 5, 2, 6], device=device, dtype=torch.float32, layout=torch.jagged
        )
        flattened = nt.flatten(-3, -2)
        self.assertEqual(flattened.shape, nt.view(3, -1, 10, 6).shape)

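    # Note (illustrative): chunk() is supported on the last dim (each chunk keeps the
    # ragged structure) and on the batch dim (offsets are re-based per chunk), but not
    # on the ragged dim itself. A sketch for the batch-dim case, assuming 8 components
    # and 3 chunks: chunk sizes are ceil(8 / 3) = 3, 3, 2.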
    def test_chunk(self, device):
        # non-NJT case
        t = torch.randn(10, 4, 5, requires_grad=True)
        t_list = t.chunk(3, dim=0)
        loss = t_list[0].sum() + t_list[2].sum()
        loss.backward()

        # normal case
        D = 30
        B = 8
        nt = random_nt_from_dims(
            [B, None, D],
            device=device,
            dtype=torch.float32,
            layout=torch.jagged,
            requires_grad=True,
        )
        NUM_CHUNKS = 3
        chunks = nt.chunk(NUM_CHUNKS, dim=-1)
        self.assertEqual(len(chunks), NUM_CHUNKS)
        for i in range(NUM_CHUNKS):
            self.assertEqual(chunks[i].shape[-1], D // NUM_CHUNKS)

        # test chunk_backward
        values = torch.randn(
            5, 11, dtype=torch.float64, device=device, requires_grad=True
        )
        offsets = torch.tensor([0, 2, 3, 5], device=device)

        def grad_test_func(values, offsets):
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
            chunks = nt.chunk(3, dim=-1)
            return chunks[0].values().sum()

        assert gradcheck(
            grad_test_func,
            inputs=(values, offsets),
            check_batched_grad=False,
        )

        # chunk on batch dim
        chunks = nt.chunk(NUM_CHUNKS, dim=0)
        self.assertEqual(len(chunks), NUM_CHUNKS)
        chunk_size = math.ceil(B / NUM_CHUNKS)
        for i in range(NUM_CHUNKS):
            if i < NUM_CHUNKS - 1:
                self.assertEqual(chunks[i].shape[0], chunk_size)
            else:
                self.assertEqual(chunks[i].shape[0], B - chunk_size * (NUM_CHUNKS - 1))
            offsets_expected = (
                nt._offsets[i * chunk_size + 1 : (i + 1) * chunk_size + 1]
                - nt._offsets[i * chunk_size]
            )
            self.assertEqual(chunks[i]._offsets[1:], offsets_expected)
        self.assertEqual(nt._values, torch.cat([x._values for x in chunks], dim=0))

        with self.assertRaisesRegex(
            RuntimeError,
            "dim != 0 INTERNAL ASSERT FAILED .* Nested Tensor doesn't support chunk backward on dim=0 yet.",
        ):
            # doesn't support backward for chunk (dim=0) yet
            loss = (
                chunks[0].values().sum()
                + chunks[1].values().sum()
                + chunks[2].values().sum()
            )
            loss.backward()

        # chunk on ragged dim not supported
        with self.assertRaisesRegex(
            RuntimeError, "chunk.* not supported for NestedTensor on dim=1"
        ):
            nt.chunk(2, dim=1)

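    # Note (illustrative): squeeze() only removes size-1 dims that are neither the batch
    # dim nor the ragged dim; unsqueeze() at the same spot should round-trip back to the
    # original NJT.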
    def test_squeeze(self, device):
        B = 4
        D = 6
        # squeeze middle dim
        nt = random_nt_from_dims(
            [B, None, 1, D], device=device, dtype=torch.float32, layout=torch.jagged
        )
        j0 = nt.shape[1]

        for dim_arg in [-2, 2]:
            out = nt.squeeze(dim_arg)
            self.assertEqual(out.shape, (B, j0, D))
            self.assertEqual(out.unsqueeze(-2), nt)

        # squeeze last dim
        nt = random_nt_from_dims(
            [B, None, 1], device=device, dtype=torch.float32, layout=torch.jagged
        )
        j1 = nt.shape[1]

        for dim_arg in [-1, 2]:
            out = nt.squeeze(dim_arg)
            self.assertEqual(out.shape, (B, j1))
            self.assertEqual(out.unsqueeze(-1), nt)

        # squeeze on batch dim not supported
        with self.assertRaisesRegex(
            RuntimeError, "squeeze.* not supported for NestedTensor on dim=0"
        ):
            nt.squeeze(0)

        # squeeze on ragged dim not supported
        with self.assertRaisesRegex(
            RuntimeError, "squeeze.* not supported for NestedTensor on dim=1"
        ):
            nt.squeeze(1)

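    # Note (illustrative): a dense operand broadcasts against an NJT of shape (B, j0, 3, 4)
    # as long as it doesn't add dims to the left of the batch dim, e.g. shapes (4,), (1, 4),
    # (3, 1), (1, 3, 1), and (1, 1, 1, 4) all work, while (1, 1, 1, 1, 4) would need a 5D
    # result and is unsupported today.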
    def test_binary_pointwise_broadcasting(self, device):
        # (B, j0, 3, 4)
        ts = self._get_list_for_jagged_tensor(
            ((2, 3, 4), 3, 4), device, requires_grad=True
        )
        # (B, j0, ?, ?) + (?) -> (B, j0, ?, ?)
        # (B, j0, ?, ?) + (?, ?) -> (B, j0, ?, ?)
        # (B, j0, ?, ?) + (1, ?, ?) -> (B, j0, ?, ?)
        # Unsupported: (B, j0, ?, ?) + (1, 1, 1, ?, ?) -> (1, B, j0, ?, ?)
        t_sizes = (
            (4,),
            (1, 4),
            (3, 1),
            (1, 3, 1),
            (1, 1, 1, 4),
            # (1, 1, 1, 1, 4), (unsupported today)
        )

        def grad_test_func(t, *ts):
            nt = torch.nested.as_nested_tensor(list(ts), layout=torch.jagged)
            out = nt + t
            return out.values()

        for t_size in t_sizes:
            t = torch.rand(
                t_size, requires_grad=True, device=device, dtype=torch.float64
            )
            gradcheck(grad_test_func, inputs=(t, *ts), check_batched_grad=False)

    def test_threshold_backward(self, device):
        ts1 = self._get_list_for_jagged_tensor(
            ((2, 3, 4), 16), device=device, requires_grad=False
        )
        ts2 = self._get_list_for_jagged_tensor(
            ((2, 3, 4), 16), device=device, requires_grad=False
        )

        nt1, offsets = jagged_from_list(ts1, None)
        nt2, offsets = jagged_from_list(ts2, offsets)
        buf1 = nt1.values().detach().clone()
        buf2 = nt2.values().detach().clone()

        res_nt = torch.ops.aten.threshold_backward(nt1, nt2, 0.0)
        res_dense = torch.ops.aten.threshold_backward(buf1, buf2, 0.0)

        self.assertEqual(res_dense, res_nt.values())

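    # Note (illustrative): for sum/mean on a jagged NJT, reducing over the batch and ragged
    # dims together collapses the jaggedness into a dense result, while reducing only over
    # trailing dims keeps the output nested. The reduce_dims tuples below encode
    # (dims, expected shape, expected keepdim shape, equivalent dims on nt.values()), with
    # the ragged size shown as None.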
    @dtypes(torch.float32)
    @parametrize(
        "func",
        [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
        name_fn=get_op_name,
    )
    @parametrize("keepdim", [False, True])
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_jagged_op_different_output_shape_dim(
        self, device, dtype, keepdim, requires_grad, components_require_grad, func
    ):
        """
        Operator passes when reducing on valid reduction dimensions.
        This test is for operators which return an output tensor with a shape different from the input tensor.
        """
        if get_op_name(func) == "mean" and not keepdim:
            return

        op_name = get_op_name(func)

        ts = self._get_list_for_jagged_tensor(
            ((2, 3, 4), 3, 4), device=device, requires_grad=True
        )  # (B, j0, 3, 4)

        # verify correctness of shapes (assuming that ragged_idx == 1)
        if op_name == "sum":
            reduce_dims = (
                ((0, 1), (3, 4), (1, 1, 3, 4), (0,)),  # batch, ragged
                ((2, 3), (3, None), (3, None, 1, 1), (1, 2)),  # non-batch, non-batch
                ((0, 1, 3), (3,), (1, 1, 3, 1), (0, 2)),  # batch, ragged, non-batch
                ((0, 1, 2), (4,), (1, 1, 1, 4), (0, 1)),  # batch, ragged, non-batch
                (
                    (0, 1, 2, 3),
                    (),
                    (1, 1, 1, 1),
                    (0, 1, 2),
                ),  # batch, ragged, non-batch, non-batch
                ((2,), (3, None, 4), (3, None, 1, 4), (1,)),  # non-batch
            )  # (dims, expected shape, expected keepdim shape, reduce_dim_expected), where j0 is represented as None
        elif op_name == "mean":
            reduce_dims = (
                ((2,), (3, None, 4), (3, None, 1, 4), (1,)),
                ((3,), (3, None, 3), (3, None, 3, 1), (2,)),
            )

        for rd, ref_shape_no_keepdim, ref_shape_keepdim, _ in reduce_dims:
            nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged)
            out = func(nt, dim=rd, keepdim=keepdim)
            ref_shape = ref_shape_keepdim if keepdim else ref_shape_no_keepdim
            if not torch.compiler.is_compiling():  # if not using torch dynamo
                self.assertEqual(len(out.shape), len(ref_shape))
                for o, r in zip(out.shape, ref_shape):
                    if r is not None:
                        self.assertEqual(o, r)
                    else:
                        self.assertTrue(isinstance(o, torch.SymInt))

        # verify correctness of values
        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False,
            include_requires_grad=components_require_grad,
            include_inner_dim_size_1=True,
        )
        for tensor_list, reduce_dim_tuple in itertools.product(
            tensor_lists, reduce_dims
        ):
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            )

            reduce_dim, _, _, reduce_dim_expected = reduce_dim_tuple

            if nt.dim() > reduce_dim[-1]:
                out_actual = func(nt, dim=reduce_dim, keepdim=keepdim)
                if nt._ragged_idx in reduce_dim:  # raggedness reduced away
                    out_expected = func(
                        nt.values(), dim=reduce_dim_expected, keepdim=keepdim
                    )
                    self.assertTrue(torch.allclose(out_actual, out_expected))
                else:  # raggedness preserved
                    out_expected = func(nt.values(), dim=reduce_dim_expected)
                    self.assertTrue(
                        torch.allclose(
                            out_actual.values().view(-1), out_expected.view(-1)
                        )
                    )

    @dtypes(torch.float32)
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_softmax_dim(
        self,
        device,
        dtype,
        requires_grad,
        components_require_grad,
    ):
        """
        Softmax passes when reducing on valid reduction dimensions.
        """
        ts = self._get_list_for_jagged_tensor(
            ((2, 3, 4), 3, 4), device=device, requires_grad=True
        )  # (B, j0, 3, 4)

        output_shape = (3, None, 3, 4)

        # verify correctness of shapes (assuming that ragged_idx == 1)
        reduce_dims = (
            (2, 1),
            (3, 2),
        )  # (reduction dimension, effective reduction dimension for baseline)

        for reduce_dim, _ in reduce_dims:
            nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged)
            out_actual = torch.nn.functional.softmax(nt, dim=reduce_dim)
            torch._dynamo.disable(self.assertEqual)(
                len(out_actual.shape), len(output_shape)
            )  # disable if running on dynamo
            for dim_actual, dim_expected in zip(out_actual.shape, output_shape):
                if dim_expected is not None:
                    self.assertEqual(dim_actual, dim_expected)
                else:
                    self.assertTrue(isinstance(dim_actual, torch.SymInt))

        # verify correctness of values
        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False,
            include_requires_grad=components_require_grad,
            include_inner_dim_size_1=True,
        )
        for tensor_list, reduce_dim_tuple in itertools.product(
            tensor_lists, reduce_dims
        ):
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            )

            reduce_dim, reduce_dim_expected = reduce_dim_tuple

            if nt.dim() > reduce_dim:
                out_actual = torch.nn.functional.softmax(
                    nt, dim=reduce_dim
                )  # nested tensor
                out_expected = torch.nn.functional.softmax(
                    nt.values(), dim=reduce_dim_expected
                )  # dense tensor of dimensions 1 less than out_actual
                self.assertTrue(
                    torch.allclose(out_actual.values().view(-1), out_expected.view(-1))
                )

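    # Note (illustrative): reducing an NJT over just the ragged dim (dim=1 when
    # ragged_idx == 1) produces a regular dense tensor; the expected value can be built
    # by reducing each unbound component and stacking, e.g.
    #   torch.cat([t.sum(dim=0).unsqueeze(0) for t in nt.unbind()])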
    @dtypes(torch.float32)
    @parametrize(
        "func",
        [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
        name_fn=get_op_name,
    )
    @parametrize("keepdim", [False, True])
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_op_dim_reduce_ragged_idx_1_different_output_shape(
        self, device, dtype, keepdim, requires_grad, components_require_grad, func
    ):
        """
        Operator on NestedTensor passes when trying to reduce across the ragged dimension, where ragged_idx == 1.
        This test is for operators which return an output tensor with a shape different from the input tensor.
        """
        if get_op_name(func) == "mean" and not keepdim:
            return

        op_name = get_op_name(func)

        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False,
            include_requires_grad=components_require_grad,
            include_inner_dim_size_1=True,  # (B, *, 1)
        )
        reduce_dim = (1,)  # ragged

        for tensor_list in tensor_lists:
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            )

            out_actual = func(nt, dim=reduce_dim, keepdim=keepdim)
            out_expected = torch.cat(
                [func(t, dim=(reduce_dim[0] - 1)).unsqueeze(0) for t in nt.unbind()]
            )

            self.assertFalse(
                out_actual.is_nested,
                f"{op_name}(): the result of reducing a nested tensor along the ragged dimension is a dense tensor",
            )  # output is a dense tensor
            self.assertTrue(torch.allclose(out_actual, out_expected))

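    # Note (illustrative): unlike sum/mean, softmax along the ragged dim keeps the output
    # nested (same shape as the input), matching a per-component softmax over each unbound
    # tensor's first dim.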
    @dtypes(torch.float32)
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_softmax_dim_reduce_ragged_idx_1(
        self, device, dtype, requires_grad, components_require_grad
    ):
        """
        Softmax on NestedTensor passes when trying to reduce across the ragged dimension, where ragged_idx == 1.
        """
        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False,
            include_requires_grad=components_require_grad,
            include_inner_dim_size_1=True,  # (B, *, 1)
            include_2d_tensor=True,  # (B, *)
        )
        reduce_dim = 1  # ragged

        for tensor_list in tensor_lists:
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            )

            out_actual = torch.nn.functional.softmax(nt, dim=reduce_dim)
            out_expected = torch.cat(
                [
                    torch.nn.functional.softmax(t, dim=reduce_dim - 1)
                    for t in nt.unbind()
                ]
            )

            self.assertTrue(
                out_actual.is_nested,
                "softmax(): the result of reducing a nested tensor along the ragged dimension is a nested tensor",
            )  # output is a nested tensor
            self.assertTrue(torch.allclose(out_actual.values(), out_expected))

    @dtypes(torch.float32)
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_softmax_reduce_batch_dim(
        self, device, dtype, requires_grad, components_require_grad
    ):
        """
        Softmax on NestedTensor fails when trying to reduce across the batch dimension.
        """
        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False,
            include_requires_grad=components_require_grad,
            include_inner_dim_size_1=True,  # (B, *, 1)
        )
        reduce_dim = 0  # batch

        for tensor_list in tensor_lists:
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            )

            with self.assertRaisesRegex(
                RuntimeError,
                "not supported when reducing across the batch dimension for NestedTensor",
            ):
                out = torch.nn.functional.softmax(nt, dim=reduce_dim)

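    # Note (illustrative): layer_norm over an NJT normalizes each component over the
    # ragged dim plus all trailing dims, i.e. normalized_shape = nt.shape[nt._ragged_idx:];
    # the result stays nested and matches per-component layer_norm on nt.unbind().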
    @dtypes(torch.float32)
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_layer_norm_reduce_ragged_idx_1(
        self, device, dtype, requires_grad, components_require_grad
    ):
        """
        Layer normalization on NestedTensor passes when trying to normalize across the ragged dimension, where ragged_idx == 1.
        """

        # requires_grad = False does not currently work with dynamo tests and throws this error:
        # AssertionError: SymInts must use SymNodeVariable.
        # If the underlying value is static, we will create a ConstantVariable and specialize.
        if torch._dynamo.is_compiling() and not requires_grad:
            return

        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False,
            include_requires_grad=components_require_grad,
            include_inner_dim_size_1=True,  # (B, *, 1)
        )

        for tensor_list in tensor_lists:
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            )

            if (
                nt.dim() >= 3
            ):  # layer norm only works for tensors with 3 or more dimensions
                normalized_shape = nt.shape[nt._ragged_idx :]

                out_actual = torch.nn.functional.layer_norm(
                    nt, normalized_shape=normalized_shape
                )
                out_expected = torch.cat(
                    [
                        torch.nn.functional.layer_norm(t, normalized_shape=t.shape)
                        for t in nt.unbind()
                    ]
                )  # e.g. in a 3D tensor (B, *, M), performs layer normalization on B 2D tensors (*, M)

                self.assertTrue(
                    out_actual.is_nested,
                    "layer_norm(): the result of reducing a nested tensor along the ragged dimension is a nested tensor",
                )  # output is a nested tensor
                self.assertEqual(out_actual._values.shape, out_expected.shape)
                self.assertTrue(torch.allclose(out_actual.values(), out_expected))

    @dtypes(torch.float32)
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_layer_norm_2d_input(
        self,
        device,
        dtype,
        requires_grad,
        components_require_grad,
    ):
        """
        Layer normalization on NestedTensor fails when trying to operate on a 2-dimensional tensor.
        """
        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False,
            include_requires_grad=components_require_grad,
            include_inner_dim_size_1=True,  # (B, *, 1)
            include_2d_tensor=True,  # (B, *)
        )

        for tensor_list in tensor_lists:
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            )

            if nt.dim() <= 2:
                with self.assertRaisesRegex(
                    RuntimeError,
                    "not supported for NestedTensor objects with 2 or fewer dimensions",
                ):
                    out = torch.nn.functional.layer_norm(
                        nt, normalized_shape=(nt.shape[nt._ragged_idx],)
                    )

    @dtypes(torch.float32)
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_layer_norm_operate_on_batch_dim(
        self,
        device,
        dtype,
        requires_grad,
        components_require_grad,
    ):
        """
        Layer normalization on NestedTensor fails when trying to operate on the batch dimension.
        """
        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False,
            include_requires_grad=components_require_grad,
            include_inner_dim_size_1=True,  # (B, *, 1)
            include_2d_tensor=True,  # (B, *)
        )

        for tensor_list in tensor_lists:
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            )

            if nt.dim() > 2:  # cannot perform layer normalization on 2D tensors
                with self.assertRaisesRegex(
                    RuntimeError,
                    "not supported when normalizing over the batch dimension for NestedTensor",
                ):
                    out = torch.nn.functional.layer_norm(nt, normalized_shape=nt.shape)

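    # Note (illustrative): after transposing the ragged dim outward (ragged_idx > 1),
    # sum/mean along that dim still reduce it away and return a dense tensor; the
    # reference value is again built from per-component reductions over nt.unbind().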
    @dtypes(torch.float32)
    @parametrize(
        "func",
        [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
        name_fn=get_op_name,
    )
    @parametrize(
        "transpose_offset", [1, 2]
    )  # [transpose consecutive dimensions, transpose nonconsecutive dimensions]
    @parametrize("keepdim", [False, True])
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape(
        self,
        device,
        dtype,
        keepdim,
        requires_grad,
        components_require_grad,
        func,
        transpose_offset,
    ):
        """
        Operator on NestedTensor passes when trying to reduce across a transposed ragged dimension, i.e. ragged_idx > 1.
        This test is for operators which return an output tensor with a shape different from the input tensor.
        """
        if get_op_name(func) == "mean" and not keepdim:
            return

        op_name = get_op_name(func)

        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False,
            include_requires_grad=components_require_grad,
            include_inner_dim_size_1=True,  # (B, *, 1)
            include_2d_tensor=True,  # (B, *)
        )

        for tensor_list in tensor_lists:
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            )

            if nt.dim() > nt._ragged_idx + transpose_offset:
                nt_transposed = nt.transpose(
                    nt._ragged_idx, nt._ragged_idx + transpose_offset
                )
                reduce_dim = (nt_transposed._ragged_idx,)  # ragged

                out_actual = func(nt_transposed, dim=reduce_dim, keepdim=keepdim)
                out_expected = torch.cat(
                    [
                        func(t, dim=(reduce_dim[0] - 1)).unsqueeze(0)
                        for t in nt_transposed.unbind()
                    ]
                )

                self.assertFalse(
                    out_actual.is_nested,
                    f"{op_name}(): the result of reducing a nested tensor along the ragged dimension is a dense tensor",
                )  # output is a dense tensor
                self.assertTrue(torch.allclose(out_actual, out_expected, rtol=1e-4))

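    # Note (illustrative): softmax is stricter than sum/mean here; reducing along a
    # transposed ragged dim (ragged_idx > 1) is expected to raise rather than fall back
    # to a per-component computation.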
    @dtypes(torch.float32)
    @parametrize(
        "transpose_offset", [1, 2]
    )  # [transpose consecutive dimensions, transpose nonconsecutive dimensions]
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_softmax_dim_reduce_ragged_idx_greater_than_1_same_output_shape(
        self,
        device,
        dtype,
        requires_grad,
        components_require_grad,
        transpose_offset,
    ):
        """
        Softmax on NestedTensor fails when trying to reduce across a transposed ragged dimension, i.e. ragged_idx > 1.
        This test is for operators which return an output tensor with the same shape as the input tensor.
        """
    @dtypes(torch.float32)
    @parametrize(
        "func",
        [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
        name_fn=get_op_name,
    )
    @parametrize(
        "transpose_offset", [1, 2]
    )  # [transpose consecutive dimensions, transpose nonconsecutive dimensions]
    @parametrize("keepdim", [False, True])
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape(
        self,
        device,
        dtype,
        keepdim,
        requires_grad,
        components_require_grad,
        func,
        transpose_offset,
    ):
        """
        Operator on NestedTensor passes when trying to reduce across a transposed ragged dimension, i.e. ragged_idx > 1
        This test is for operators which return an output tensor with a shape different from the input tensor.
        """
        if get_op_name(func) == "mean" and not keepdim:
            return

        op_name = get_op_name(func)

        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False,
            include_requires_grad=components_require_grad,
            include_inner_dim_size_1=True,  # (B, *, 1)
            include_2d_tensor=True,  # (B, *)
        )

        for tensor_list in tensor_lists:
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            )

            if nt.dim() > nt._ragged_idx + transpose_offset:
                nt_transposed = nt.transpose(
                    nt._ragged_idx, nt._ragged_idx + transpose_offset
                )
                reduce_dim = (nt_transposed._ragged_idx,)  # ragged

                out_actual = func(nt_transposed, dim=reduce_dim, keepdim=keepdim)
                out_expected = torch.cat(
                    [
                        func(t, dim=(reduce_dim[0] - 1)).unsqueeze(0)
                        for t in nt_transposed.unbind()
                    ]
                )

                self.assertFalse(
                    out_actual.is_nested,
                    f"{op_name}(): the result of reducing a nested tensor along the ragged dimension is a dense tensor",
                )  # output is a dense tensor
                self.assertTrue(torch.allclose(out_actual, out_expected, rtol=1e-4))

    @dtypes(torch.float32)
    @parametrize(
        "transpose_offset", [1, 2]
    )  # [transpose consecutive dimensions, transpose nonconsecutive dimensions]
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_softmax_dim_reduce_ragged_idx_greater_than_1_same_output_shape(
        self,
        device,
        dtype,
        requires_grad,
        components_require_grad,
        transpose_offset,
    ):
        """
        Softmax on NestedTensor fails when trying to reduce across a transposed ragged dimension, i.e. ragged_idx > 1
        This test is for operators which return an output tensor with the same shape as the input tensor.
        """
        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False,
            include_requires_grad=components_require_grad,
            include_inner_dim_size_1=True,  # (B, *, 1)
        )

        for tensor_list in tensor_lists:
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            )

            if nt.dim() > nt._ragged_idx + transpose_offset:
                nt_transposed = nt.transpose(
                    nt._ragged_idx, nt._ragged_idx + transpose_offset
                )
                reduce_dim = nt_transposed._ragged_idx  # ragged

                with self.assertRaisesRegex(
                    RuntimeError,
                    "not supported when reducing along the ragged dimension for ragged_idx > 1 for NestedTensor",
                ):
                    out = torch.nn.functional.softmax(nt_transposed, dim=reduce_dim)

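    # For the transpose tests below, the dense baseline is computed on
    # nt.values(): with ragged_idx == 1, the batch and ragged dims of the NT
    # are packed into dim 0 of values(), so NT dim d (for d >= 1) corresponds
    # to values() dim d - 1. That is the mapping encoded by reduce_dim_expected.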
    @dtypes(torch.float32)
    @parametrize(
        "func",
        [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
        name_fn=get_op_name,
    )
    @parametrize("keepdim", [False, True])
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_op_dim_transpose_non_ragged_dim_different_output_shape(
        self, device, dtype, keepdim, requires_grad, components_require_grad, func
    ):
        """
        Operator passes when reducing transposed nested tensors on valid reduction dimensions.
        This test is for operators which return an output tensor with a shape different from the input tensor.
        """
        if get_op_name(func) == "mean" and not keepdim:
            return

        # verify correctness of shapes (assuming that ragged_idx == 1)
        if get_op_name(func) == "sum":
            reduce_dims = (
                ((0, 1), (3, 4), (1, 1, 3, 4), (0,)),  # batch, ragged
                ((2, 3), (3, None), (3, None, 1, 1), (1, 2)),  # non-batch, non-batch
                ((0, 1, 3), (3,), (1, 1, 3, 1), (0, 2)),  # batch, ragged, non-batch
                ((0, 1, 2), (4,), (1, 1, 1, 4), (0, 1)),  # batch, ragged, non-batch
                (
                    (0, 1, 2, 3),
                    (),
                    (1, 1, 1, 1),
                    (0, 1, 2),
                ),  # batch, ragged, non-batch, non-batch
                ((2,), (3, None, 4), (3, None, 1, 4), (1,)),  # non-batch
            )  # (dims, expected shape, expected keepdim shape, reduce_dim_expected), where j0 is represented as None
        elif get_op_name(func) == "mean":
            reduce_dims = (
                ((2,), (3, None, 4), (3, None, 1, 4), (1,)),
                ((3,), (3, None, 3), (3, None, 3, 1), (2,)),
            )

        # verify correctness of values
        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False,
            include_requires_grad=components_require_grad,
        )
        for tensor_list, reduce_dim_tuple in itertools.product(
            tensor_lists, reduce_dims
        ):
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            ).transpose(-1, -2)

            reduce_dim, _, _, reduce_dim_expected = reduce_dim_tuple

            if nt.dim() > max(
                reduce_dim[-1], nt._ragged_idx + 2
            ):  # ensure that transposed dimensions are non-batch, non-ragged dimensions
                out_actual = func(nt, dim=reduce_dim, keepdim=keepdim)
                if nt._ragged_idx in reduce_dim:  # raggedness reduced away
                    out_expected = func(
                        nt.values(), dim=reduce_dim_expected, keepdim=keepdim
                    )
                    self.assertTrue(torch.allclose(out_actual, out_expected))
                else:  # raggedness preserved
                    out_expected = func(nt.values(), dim=reduce_dim_expected)
                    self.assertTrue(
                        torch.allclose(
                            out_actual.values().view(-1), out_expected.view(-1)
                        )
                    )

    @dtypes(torch.float32)
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_softmax_dim_transpose_non_ragged_dim(
        self,
        device,
        dtype,
        requires_grad,
        components_require_grad,
    ):
        """
        Softmax passes when reducing transposed nested tensors on valid reduction dimensions.
        This test is for operators which return an output tensor with the same shape as the input tensor.
        """
        # verify correctness of shapes (assuming that ragged_idx == 1)
        reduce_dims = (
            (2, 1),
            (3, 2),
        )  # (reduction dimension, effective reduction dimension for baseline)

        # verify correctness of values
        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False,
            include_requires_grad=components_require_grad,
            include_inner_dim_size_1=True,  # (B, *, 1)
        )
        for tensor_list, reduce_dim_tuple in itertools.product(
            tensor_lists, reduce_dims
        ):
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            ).transpose(-1, -2)

            reduce_dim, reduce_dim_expected = reduce_dim_tuple

            if nt.dim() > max(reduce_dim, nt._ragged_idx + 2):
                out_actual = torch.nn.functional.softmax(
                    nt, dim=reduce_dim
                )  # nested tensor
                out_expected = torch.nn.functional.softmax(
                    nt.values(), dim=reduce_dim_expected
                )  # dense tensor of dimensions 1 less than out_actual

                self.assertTrue(
                    torch.allclose(out_actual.values().view(-1), out_expected.view(-1))
                )

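    # The next several tests pin down reduction-dim combinations that the
    # jagged layout currently rejects: ragged together with a non-batch dim,
    # batch together with a non-batch dim, and the batch dim on its own.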
    @dtypes(torch.float32)
    @parametrize("keepdim", [False, True])
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_sum_dim_reduce_ragged_and_non_batch(
        self,
        device,
        dtype,
        keepdim,
        requires_grad,
        components_require_grad,
    ):
        """
        Sum on NestedTensor fails when trying to reduce across ragged and non-batch dimensions
        """
        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False, include_requires_grad=components_require_grad
        )
        reduce_dims = (
            (1, 2),  # ragged, non-batch
            (1, 3),  # ragged, non-batch
        )

        for tensor_list, reduce_dim in itertools.product(tensor_lists, reduce_dims):
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            )

            if nt.dim() > reduce_dim[-1]:
                with self.assertRaisesRegex(
                    RuntimeError,
                    "not supported along a ragged and non-batch dimension for NestedTensor",
                ):
                    out = torch.sum(nt, dim=reduce_dim, keepdim=keepdim)

    @dtypes(torch.float32)
    @parametrize("keepdim", [False, True])
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_sum_dim_reduce_batch_and_non_batch(
        self,
        device,
        dtype,
        keepdim,
        requires_grad,
        components_require_grad,
    ):
        """
        Sum on NestedTensor fails when trying to reduce across batch and non-batch dimensions
        """
        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False, include_requires_grad=components_require_grad
        )
        reduce_dims = (
            (0, 2),  # batch, non-batch
            (0, 3),  # batch, non-batch
        )

        for tensor_list, reduce_dim in itertools.product(tensor_lists, reduce_dims):
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            )

            if nt.dim() > reduce_dim[-1]:
                with self.assertRaisesRegex(
                    RuntimeError,
                    "not supported along the batch dimension but not the ragged dimension for NestedTensor",
                ):
                    out = torch.sum(nt, dim=reduce_dim, keepdim=keepdim)

    @dtypes(torch.float32)
    @parametrize(
        "func",
        [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
        name_fn=get_op_name,
    )
    @parametrize("keepdim", [False, True])
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_op_dim_reduce_batch_only_different_output_shape(
        self, device, dtype, keepdim, requires_grad, components_require_grad, func
    ):
        """
        Operator on NestedTensor fails when trying to reduce across the batch dimension
        """
        if get_op_name(func) == "mean" and not keepdim:
            return

        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False, include_requires_grad=components_require_grad
        )
        reduce_dim = (0,)  # batch

        for tensor_list in tensor_lists:
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            )

            with self.assertRaisesRegex(
                RuntimeError,
                "not supported along the batch dimension but not the ragged dimension for NestedTensor",
            ):
                out = func(nt, dim=reduce_dim, keepdim=keepdim)

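    # The "with lengths" tests below build a jagged NT with holes: offsets say
    # where each row starts in values, while lengths (made shorter than
    # offsets.diff() here) says how many of those entries are actually used.
    # As a hypothetical illustration, offsets=[0, 5, 9] with lengths=[3, 2]
    # means row 0 covers values[0:3] and row 1 covers values[5:7], leaving
    # unused entries before the next offset.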
    @dtypes(torch.float32)
    @parametrize(
        "func",
        [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
        name_fn=get_op_name,
    )
    @parametrize("keepdim", [False, True])
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_op_dim_with_lengths_different_output_shape(
        self,
        device,
        dtype,
        keepdim,
        requires_grad,
        components_require_grad,
        func,
    ):
        """
        Operator on NestedTensor fails when trying to reduce a nested tensor with lengths,
        i.e. a nested tensor with holes, if reducing on the ragged dimension.
        This test is for operators which return an output tensor with a shape different from the input tensor.
        """
        if get_op_name(func) == "mean" and not keepdim:
            return

        reduce_dims = ((1,), (2,), (2, 3))

        lengths = torch.randint(5, 10, (20,), device=device)
        offsets = torch.zeros((21,), device=device, dtype=torch.int)
        torch.cumsum(lengths, dim=0, out=offsets[1:])

        values = torch.randn(
            (offsets[-1].item(), 20),
            device=device,
            dtype=dtype,
            requires_grad=requires_grad,
        )

        nt_with_holes = torch.nested.nested_tensor_from_jagged(
            values,
            offsets,
            lengths=offsets.diff() - 2,  # arbitrary subtraction to create holes
        )

        for reduce_dim in reduce_dims:
            if nt_with_holes.dim() > reduce_dim[-1]:
                if nt_with_holes._ragged_idx in reduce_dim:
                    with self.assertRaisesRegex(
                        RuntimeError,
                        "not supported where lengths is not None "
                        + "if reducing across the ragged dimension for NestedTensor",
                    ):
                        out = func(nt_with_holes, dim=reduce_dim, keepdim=keepdim)
                else:
                    out = func(nt_with_holes, dim=reduce_dim, keepdim=keepdim)

    @dtypes(torch.float32)
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_softmax_dim_with_lengths(
        self,
        device,
        dtype,
        requires_grad,
        components_require_grad,
    ):
        """
        Softmax on NestedTensor fails when trying to reduce a nested tensor with lengths,
        i.e. a nested tensor with holes, if reducing on the ragged dimension.
        """
        reduce_dims = (1, 2, 3)

        lengths = torch.randint(5, 10, (20,), device=device)
        offsets = torch.zeros((21,), device=device, dtype=torch.int)
        torch.cumsum(lengths, dim=0, out=offsets[1:])

        values = torch.randn(
            (offsets[-1].item(), 20),
            device=device,
            dtype=dtype,
            requires_grad=requires_grad,
        )

        nt_with_holes = torch.nested.nested_tensor_from_jagged(
            values,
            offsets,
            lengths=offsets.diff() - 2,  # arbitrary subtraction to create holes
        )

        for reduce_dim in reduce_dims:
            if nt_with_holes.dim() > reduce_dim:
                if nt_with_holes._ragged_idx == reduce_dim:
                    with self.assertRaisesRegex(
                        RuntimeError,
                        "not supported where lengths is not None "
                        + "if reducing across the ragged dimension for NestedTensor",
                    ):
                        out = torch.nn.functional.softmax(nt_with_holes, dim=reduce_dim)
                else:
                    out = torch.nn.functional.softmax(nt_with_holes, dim=reduce_dim)

    @skipIfTorchDynamo(
        "ragged_size = nt_with_holes.shape[nt_with_holes._ragged_idx] does not currently work "
        + "with dynamo tests and throws this error: `AssertionError: SymInts must use SymNodeVariable. "
        + "If the underlying value is static, we will create a ConstantVariable and specialize.`"
    )
    @dtypes(torch.float32)
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_layer_norm_with_lengths(
        self,
        device,
        dtype,
        requires_grad,
        components_require_grad,
    ):
        """
        Layer normalization on NestedTensor fails when trying to operate on a nested tensor with lengths,
        i.e. a nested tensor with holes, if operating on the ragged dimension.
        """

        # create components for nested tensor
        lengths = torch.randint(5, 10, (20,), device=device)
        offsets = torch.zeros((21,), device=device, dtype=torch.int)
        torch.cumsum(lengths, dim=0, out=offsets[1:])
        values = torch.randn(
            (offsets[-1].item(), 10, 30),
            device=device,
            dtype=dtype,
            requires_grad=requires_grad,
        )

        nt_with_holes = torch.nested.nested_tensor_from_jagged(
            values,
            offsets,
            lengths=offsets.diff() - 2,  # arbitrary subtraction to create holes
        )

        ragged_size = nt_with_holes.shape[nt_with_holes._ragged_idx]

        normalized_shapes = (
            (10, 30),  # normalization on non-ragged dimension passes
            (ragged_size, 10, 30),  # normalization on ragged dimension fails
        )

        for normalized_shape in normalized_shapes:
            if ragged_size in normalized_shape:
                with self.assertRaisesRegex(
                    RuntimeError,
                    "not supported where lengths is not None if operating on the ragged dimension for NestedTensor",
                ):
                    out = torch.nn.functional.layer_norm(
                        nt_with_holes, normalized_shape=normalized_shape
                    )
            else:
                out = torch.nn.functional.layer_norm(
                    nt_with_holes, normalized_shape=normalized_shape
                )

    @dtypes(torch.float32)
    @parametrize("keepdim", [True])
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_mean_dim_reduce_multiple_dims(
        self,
        device,
        dtype,
        keepdim,
        requires_grad,
        components_require_grad,
    ):
        """
        Mean on NestedTensor fails when trying to reduce across multiple dimensions
        """
        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False, include_requires_grad=components_require_grad
        )
        reduce_dims = ((0, 1), (2, 3), (2, 3, 4))

        for tensor_list, reduce_dim in itertools.product(tensor_lists, reduce_dims):
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            )

            if nt.dim() > reduce_dim[-1]:
                with self.assertRaisesRegex(
                    RuntimeError,
                    "not supported across multiple dimensions for NestedTensor",
                ):
                    out = torch.mean(nt, dim=reduce_dim, keepdim=keepdim)

    @dtypes(torch.float32)
    @parametrize("keepdim", [False, True])
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_mean_dim_keepdim_False(
        self,
        device,
        dtype,
        keepdim,
        requires_grad,
        components_require_grad,
    ):
        """
        Mean on NestedTensor fails when keepdim=False
        """
        tensor_lists = self._get_example_tensor_lists(
            include_list_of_lists=False, include_requires_grad=components_require_grad
        )
        reduce_dims = ((1,), (2,), (3,))

        for tensor_list, reduce_dim in itertools.product(tensor_lists, reduce_dims):
            nt = torch.nested.nested_tensor(
                tensor_list,
                device=device,
                dtype=dtype,
                layout=torch.jagged,
                requires_grad=requires_grad,
            )

            if nt.dim() > reduce_dim[-1]:
                if not keepdim:
                    with self.assertRaisesRegex(
                        RuntimeError,
                        "not supported when keepdim=False for NestedTensor",
                    ):
                        out = torch.mean(nt, dim=reduce_dim, keepdim=keepdim)
                else:
                    out = torch.mean(nt, dim=reduce_dim, keepdim=keepdim)

    @dtypes(torch.float, torch.double, torch.half)
    @parametrize("requires_grad", [False, True])
    @parametrize("weights_only", [False, True])
    def test_serialization(self, device, dtype, requires_grad, weights_only):
        def compare_metadata(nt1, nt2):
            self.assertEqual(nt1._nested_tensor_size(), nt2._nested_tensor_size())
            self.assertEqual(nt1._nested_tensor_strides(), nt2._nested_tensor_strides())
            self.assertEqual(
                nt1._nested_tensor_storage_offsets(),
                nt2._nested_tensor_storage_offsets(),
            )

        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7))
        for a in [nt_contiguous, nt_noncontiguous]:
            buffer = io.BytesIO()
            serialized = torch.save(a, buffer)
            buffer.seek(0)
            b = torch.load(buffer, weights_only=weights_only)
            # should be both conceptually equal and metadata equivalent
            self.assertEqual(a, b)
            compare_metadata(a, b)
            # should be conceptually equal but not necessarily metadata equivalent
            self.assertEqual(b, nt_contiguous)
            self.assertEqual(b, nt_noncontiguous)

    @unittest.skipIf(
        PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property"
    )
    @onlyCUDA
    def test_pin_memory(self, device):
        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7))
        for nt in [nt_contiguous, nt_noncontiguous]:
            self.assertFalse(nt.is_pinned())
            pinned = nt.pin_memory(device)
            self.assertTrue(pinned.is_pinned())
            self.assertEqual(nt, pinned)
            self.assertNotEqual(nt.data_ptr(), pinned.data_ptr())
            # test that pin_memory on already pinned tensor has no effect
            self.assertIs(pinned, pinned.pin_memory())
            self.assertEqual(pinned.data_ptr(), pinned.pin_memory().data_ptr())

    @torch.compiler.disable
    def _validate_nt(
        self,
        nt,
        device,
        dtype,
        layout,
        requires_grad,
        dim,
        batch_size,
        contiguous,
        cached_min_seqlen=None,
        cached_max_seqlen=None,
        base=None,
        ref_nt=None,
    ):
        # Validate a bunch of properties after NT construction.
        device = torch.device(device)
        self.assertEqual(nt.dim(), dim)
        self.assertEqual(nt.device, device)
        self.assertEqual(nt.dtype, dtype)
        self.assertEqual(nt.layout, layout)
        self.assertEqual(nt.requires_grad, requires_grad)
        self.assertEqual(nt.is_contiguous(), contiguous)

        if layout == torch.jagged:
            self.assertEqual(nt._values.device, device)
            self.assertEqual(nt._offsets.device, device)
            self.assertEqual(nt.shape[0], batch_size)
            self.assertTrue(isinstance(nt.shape[1], torch.SymInt))

            if base is not None:
                self.assertTrue(nt._is_view() and nt._base is base)
                replay_cache = nt._view_func(torch.randn_like(nt._base))._metadata_cache
                self.assertEqual(
                    "min_seqlen" in replay_cache, cached_min_seqlen is not None
                )
                self.assertEqual(
                    "max_seqlen" in replay_cache, cached_max_seqlen is not None
                )

            self.assertEqual(
                "min_seqlen" in nt._metadata_cache, cached_min_seqlen is not None
            )
            self.assertEqual(
                "max_seqlen" in nt._metadata_cache, cached_max_seqlen is not None
            )

            if cached_min_seqlen is not None:
                self.assertEqual(nt._min_seqlen, cached_min_seqlen)

            if cached_max_seqlen is not None:
                self.assertEqual(nt._max_seqlen, cached_max_seqlen)

        if ref_nt is not None:
            self.assertEqual(nt.size(0), ref_nt.size(0))
            for n1, n2 in zip(nt.unbind(), ref_nt.unbind()):
                self.assertEqual(n1, n2)

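    # Construction tests: for the jagged layout, a list of component tensors
    # with shapes (2, 5) and (3, 5), for example, gets packed into a values
    # buffer of shape (5, 5) with offsets [0, 2, 5]; the expected min / max
    # seqlen checked through _validate_nt would then be 2 and 3.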
test_jagged_layout_construction_nested_tensor( 5237*da0073e9SAndroid Build Coastguard Worker self, device, dtype, requires_grad, components_require_grad 5238*da0073e9SAndroid Build Coastguard Worker ): 5239*da0073e9SAndroid Build Coastguard Worker for tensor_list in self._get_example_tensor_lists( 5240*da0073e9SAndroid Build Coastguard Worker include_list_of_lists=True, include_requires_grad=components_require_grad 5241*da0073e9SAndroid Build Coastguard Worker ): 5242*da0073e9SAndroid Build Coastguard Worker nt = torch.nested.nested_tensor( 5243*da0073e9SAndroid Build Coastguard Worker tensor_list, 5244*da0073e9SAndroid Build Coastguard Worker device=device, 5245*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 5246*da0073e9SAndroid Build Coastguard Worker layout=torch.jagged, 5247*da0073e9SAndroid Build Coastguard Worker requires_grad=requires_grad, 5248*da0073e9SAndroid Build Coastguard Worker ) 5249*da0073e9SAndroid Build Coastguard Worker 5250*da0073e9SAndroid Build Coastguard Worker expected_dim = torch.as_tensor(tensor_list[0]).dim() + 1 5251*da0073e9SAndroid Build Coastguard Worker expected_batch_size = len(tensor_list) 5252*da0073e9SAndroid Build Coastguard Worker expected_contiguous = True 5253*da0073e9SAndroid Build Coastguard Worker expected_min_seqlen = min( 5254*da0073e9SAndroid Build Coastguard Worker (torch.tensor(t) if isinstance(t, list) else t).shape[0] 5255*da0073e9SAndroid Build Coastguard Worker for t in tensor_list 5256*da0073e9SAndroid Build Coastguard Worker ) 5257*da0073e9SAndroid Build Coastguard Worker expected_max_seqlen = max( 5258*da0073e9SAndroid Build Coastguard Worker (torch.tensor(t) if isinstance(t, list) else t).shape[0] 5259*da0073e9SAndroid Build Coastguard Worker for t in tensor_list 5260*da0073e9SAndroid Build Coastguard Worker ) 5261*da0073e9SAndroid Build Coastguard Worker self._validate_nt( 5262*da0073e9SAndroid Build Coastguard Worker nt, 5263*da0073e9SAndroid Build Coastguard Worker device, 5264*da0073e9SAndroid Build Coastguard Worker dtype, 5265*da0073e9SAndroid Build Coastguard Worker torch.jagged, 5266*da0073e9SAndroid Build Coastguard Worker requires_grad, 5267*da0073e9SAndroid Build Coastguard Worker expected_dim, 5268*da0073e9SAndroid Build Coastguard Worker expected_batch_size, 5269*da0073e9SAndroid Build Coastguard Worker expected_contiguous, 5270*da0073e9SAndroid Build Coastguard Worker expected_min_seqlen, 5271*da0073e9SAndroid Build Coastguard Worker expected_max_seqlen, 5272*da0073e9SAndroid Build Coastguard Worker ) 5273*da0073e9SAndroid Build Coastguard Worker 5274*da0073e9SAndroid Build Coastguard Worker # Make sure grads -don't- flow back into original tensors for nested_tensor() 5275*da0073e9SAndroid Build Coastguard Worker if requires_grad: 5276*da0073e9SAndroid Build Coastguard Worker (nt * 2).backward(torch.ones_like(nt)) 5277*da0073e9SAndroid Build Coastguard Worker for t in tensor_list: 5278*da0073e9SAndroid Build Coastguard Worker t = t if isinstance(t, torch.Tensor) else torch.as_tensor(t) 5279*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t.grad is None) 5280*da0073e9SAndroid Build Coastguard Worker 5281*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double, torch.half) 5282*da0073e9SAndroid Build Coastguard Worker @parametrize("components_require_grad", [False, True]) 5283*da0073e9SAndroid Build Coastguard Worker def test_jagged_layout_construction_as_nested_tensor( 5284*da0073e9SAndroid Build Coastguard Worker self, device, dtype, components_require_grad 5285*da0073e9SAndroid Build 
Coastguard Worker ): 5286*da0073e9SAndroid Build Coastguard Worker # NB: as_nested_tensor(tensor_list) doesn't support lists of lists for tensor_list 5287*da0073e9SAndroid Build Coastguard Worker for tensor_list in self._get_example_tensor_lists( 5288*da0073e9SAndroid Build Coastguard Worker include_list_of_lists=False, include_requires_grad=components_require_grad 5289*da0073e9SAndroid Build Coastguard Worker ): 5290*da0073e9SAndroid Build Coastguard Worker nt = torch.nested.as_nested_tensor( 5291*da0073e9SAndroid Build Coastguard Worker tensor_list, device=device, dtype=dtype, layout=torch.jagged 5292*da0073e9SAndroid Build Coastguard Worker ) 5293*da0073e9SAndroid Build Coastguard Worker 5294*da0073e9SAndroid Build Coastguard Worker # nt.requires_grad=True should be set if at least one component requires grad 5295*da0073e9SAndroid Build Coastguard Worker expected_dim = tensor_list[0].dim() + 1 5296*da0073e9SAndroid Build Coastguard Worker expected_batch_size = len(tensor_list) 5297*da0073e9SAndroid Build Coastguard Worker expected_contiguous = True 5298*da0073e9SAndroid Build Coastguard Worker expected_min_seqlen = min( 5299*da0073e9SAndroid Build Coastguard Worker (torch.tensor(t) if isinstance(t, list) else t).shape[0] 5300*da0073e9SAndroid Build Coastguard Worker for t in tensor_list 5301*da0073e9SAndroid Build Coastguard Worker ) 5302*da0073e9SAndroid Build Coastguard Worker expected_max_seqlen = max( 5303*da0073e9SAndroid Build Coastguard Worker (torch.tensor(t) if isinstance(t, list) else t).shape[0] 5304*da0073e9SAndroid Build Coastguard Worker for t in tensor_list 5305*da0073e9SAndroid Build Coastguard Worker ) 5306*da0073e9SAndroid Build Coastguard Worker self._validate_nt( 5307*da0073e9SAndroid Build Coastguard Worker nt, 5308*da0073e9SAndroid Build Coastguard Worker device, 5309*da0073e9SAndroid Build Coastguard Worker dtype, 5310*da0073e9SAndroid Build Coastguard Worker torch.jagged, 5311*da0073e9SAndroid Build Coastguard Worker components_require_grad, 5312*da0073e9SAndroid Build Coastguard Worker expected_dim, 5313*da0073e9SAndroid Build Coastguard Worker expected_batch_size, 5314*da0073e9SAndroid Build Coastguard Worker expected_contiguous, 5315*da0073e9SAndroid Build Coastguard Worker expected_min_seqlen, 5316*da0073e9SAndroid Build Coastguard Worker expected_max_seqlen, 5317*da0073e9SAndroid Build Coastguard Worker ) 5318*da0073e9SAndroid Build Coastguard Worker 5319*da0073e9SAndroid Build Coastguard Worker # Make sure grads flow back into original tensors for as_nested_tensor() 5320*da0073e9SAndroid Build Coastguard Worker if components_require_grad: 5321*da0073e9SAndroid Build Coastguard Worker (nt * 2).backward(torch.ones_like(nt)) 5322*da0073e9SAndroid Build Coastguard Worker for t in tensor_list: 5323*da0073e9SAndroid Build Coastguard Worker if t.requires_grad: 5324*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.grad, torch.ones_like(t) * 2) 5325*da0073e9SAndroid Build Coastguard Worker else: 5326*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t.grad is None) 5327*da0073e9SAndroid Build Coastguard Worker 5328*da0073e9SAndroid Build Coastguard Worker @xfailIfTorchDynamo 5329*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 5330*da0073e9SAndroid Build Coastguard Worker PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property" 5331*da0073e9SAndroid Build Coastguard Worker ) 5332*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 5333*da0073e9SAndroid Build Coastguard Worker def 
    @xfailIfTorchDynamo
    @unittest.skipIf(
        PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property"
    )
    @onlyCUDA
    def test_jagged_layout_construction_with_pinned_memory(self, device):
        for tensor_list in self._get_example_tensor_lists():
            nt = torch.nested.nested_tensor(
                tensor_list, layout=torch.jagged, device="cpu", pin_memory=True
            )

            expected_dim = torch.as_tensor(tensor_list[0]).dim() + 1
            expected_batch_size = len(tensor_list)
            expected_min_seqlen = min(
                (torch.tensor(t) if isinstance(t, list) else t).shape[0]
                for t in tensor_list
            )
            expected_max_seqlen = max(
                (torch.tensor(t) if isinstance(t, list) else t).shape[0]
                for t in tensor_list
            )
            self._validate_nt(
                nt,
                device="cpu",
                dtype=torch.float32,
                layout=torch.jagged,
                requires_grad=False,
                dim=expected_dim,
                batch_size=expected_batch_size,
                contiguous=True,
                cached_min_seqlen=expected_min_seqlen,
                cached_max_seqlen=expected_max_seqlen,
            )
            self.assertTrue(nt.is_pinned())

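    # nested_view_from_values_offsets() should produce a proper view over the values buffer
    # (or over its base, when values is itself a view), with an empty min / max seqlen cache
    # when none is provided.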
    @dtypes(torch.float, torch.double, torch.half)
    @parametrize("requires_grad", [False, True])
    @parametrize("values_is_view", [False, True])
    def test_jagged_view_from_values_offsets(
        self, device, dtype, requires_grad, values_is_view
    ):
        if values_is_view:
            # make values a view of base
            base = torch.randn(
                2, 3, 4, 5, 6, device=device, dtype=dtype, requires_grad=requires_grad
            )
            values = base.flatten(0, -2)
        else:
            values = torch.randn(
                10, 5, device=device, dtype=dtype, requires_grad=requires_grad
            )
        offsets = torch.tensor([0, 2, 4, 6, 10], device=device, dtype=torch.int64)

        nt = nested_view_from_values_offsets(values, offsets)

        expected_dim = values.dim() + 1
        expected_batch_size = offsets.shape[0] - 1
        expected_base = base if values_is_view else values
        lengths = offsets.diff()
        self._validate_nt(
            nt,
            device,
            dtype,
            torch.jagged,
            requires_grad,
            expected_dim,
            expected_batch_size,
            # ensure NT is a proper view
            base=expected_base,
            contiguous=True,
            # if no min / max are passed, expect the metadata cache to be empty
            cached_min_seqlen=None,
            cached_max_seqlen=None,
        )

        if requires_grad:
            # Make sure grads flow back
            (nt * 2).backward(torch.ones_like(nt))

            @torch.compiler.disable
            def _check_grad(t):
                self.assertTrue(t.grad is not None)
                self.assertEqual(t.grad, torch.ones_like(t) * 2)

            _check_grad(base if values_is_view else values)

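    # The public nested_tensor_from_jagged() constructor accepts offsets, offsets + lengths,
    # or lengths alone; in all cases the resulting NJT is validated as a view over the
    # passed-in values buffer.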
    @dtypes(torch.float)
    @parametrize("pass_min_max", [False, True])
    def test_nested_tensor_from_jagged(self, device, dtype, pass_min_max):
        # === construct from (values, offsets) ===
        values = torch.randn(10, 5, device=device, dtype=dtype)
        offsets = torch.tensor([0, 2, 4, 6, 10], device=device, dtype=torch.int64)

        # compute min / max seqlen
        lengths = offsets.diff()
        min_seqlen = lengths.min().item()
        max_seqlen = lengths.max().item()

        if pass_min_max:
            nt = torch.nested.nested_tensor_from_jagged(
                values, offsets=offsets, min_seqlen=min_seqlen, max_seqlen=max_seqlen
            )
        else:
            nt = torch.nested.nested_tensor_from_jagged(values, offsets=offsets)
        self._validate_nt(
            nt,
            device,
            dtype,
            torch.jagged,
            requires_grad=False,
            dim=3,
            batch_size=4,
            contiguous=True,
            cached_min_seqlen=(min_seqlen if pass_min_max else None),
            cached_max_seqlen=(max_seqlen if pass_min_max else None),
            base=values,
        )

        # === construct from (values, offsets, lengths) ===
        lengths = torch.tensor([2, 1, 1, 2], device=device)

        # compute min / max seqlen
        min_seqlen = lengths.min().item()
        max_seqlen = lengths.max().item()

        if pass_min_max:
            nt = torch.nested.nested_tensor_from_jagged(
                values,
                offsets=offsets,
                lengths=lengths,
                min_seqlen=min_seqlen,
                max_seqlen=max_seqlen,
            )
        else:
            nt = torch.nested.nested_tensor_from_jagged(
                values, offsets=offsets, lengths=lengths
            )

        # when both offsets / lengths are specified, expect non-contiguous
        self._validate_nt(
            nt,
            device,
            dtype,
            torch.jagged,
            requires_grad=False,
            dim=3,
            batch_size=4,
            contiguous=False,
            cached_min_seqlen=(min_seqlen if pass_min_max else None),
            cached_max_seqlen=(max_seqlen if pass_min_max else None),
            base=values,
        )
        self.assertIs(nt.lengths(), lengths)

        # === construct from (values, lengths) ===
        values = torch.randn(14, 5, device=device, dtype=dtype)
        lengths = torch.tensor([2, 3, 4, 5], device=device)

        # compute min / max seqlen
        min_seqlen = lengths.min().item()
        max_seqlen = lengths.max().item()

        if pass_min_max:
            nt = torch.nested.nested_tensor_from_jagged(
                values, lengths=lengths, min_seqlen=min_seqlen, max_seqlen=max_seqlen
            )
        else:
            nt = torch.nested.nested_tensor_from_jagged(values, lengths=lengths)

        # for now, if only lengths is specified, convert to offsets to integrate best with the
        # existing kernels
        expected_offsets = torch.tensor([0, 2, 5, 9, 14], device=device)
        expected_nt = torch.nested.nested_tensor_from_jagged(
            values, offsets=expected_offsets
        )
        self._validate_nt(
            nt,
            device,
            dtype,
            torch.jagged,
            requires_grad=False,
            dim=3,
            batch_size=4,
            contiguous=True,
            cached_min_seqlen=(min_seqlen if pass_min_max else None),
            cached_max_seqlen=(max_seqlen if pass_min_max else None),
            base=values,
            ref_nt=expected_nt,
        )

        # error case: no offsets or lengths
        with self.assertRaisesRegex(
            RuntimeError, "At least one of offsets or lengths is required"
        ):
            torch.nested.nested_tensor_from_jagged(values, offsets=None, lengths=None)

    @onlyCPU
    def test_nested_tensor_from_jagged_fx_trace(self, device):
        def fn(x, y):
            return torch.nested.nested_tensor_from_jagged(x, y)

        def user_unwrapped(x, y):
            return fn(x, y)

        with self.assertRaisesRegex(
            RuntimeError,
            "torch.nested.nested_tensor_from_jagged does not support tracing with fx.symbolic_trace",
        ):
            torch.fx.symbolic_trace(user_unwrapped)

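    # as_nested_tensor() also accepts a single dense tensor: dim 0 becomes the batch dim,
    # and when no device / dtype conversion is needed the result should be a view of the
    # input rather than a copy.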
    @dtypes(torch.float, torch.double, torch.half)
    @parametrize("dim", range(5))
    @parametrize(
        "layout",
        [torch.strided, torch.jagged],
        name_fn=lambda l: f"layout_{str(l).split('.')[1]}",
    )
    @parametrize("requires_grad", [False, True])
    @parametrize("contiguous", [False, True])
    def test_as_nested_tensor_from_tensor(
        self, device, dtype, dim, layout, requires_grad, contiguous
    ):
        if dim == 0:
            t = torch.tensor(3.0, requires_grad=requires_grad)
        else:
            t = torch.randn(*(3 for _ in range(dim)), requires_grad=requires_grad)
        assert t.dim() == dim

        if dim < 2:
            # 0-1 dim tensors can't be converted to NTs
            with self.assertRaisesRegex(
                RuntimeError, "Expected tensor argument to have dim"
            ):
                nt = torch.nested.as_nested_tensor(
                    t, device=device, dtype=dtype, layout=layout
                )
            return

        orig_t = t
        if not contiguous:
            t = t.transpose(0, 1)

        nt = torch.nested.as_nested_tensor(t, device=device, dtype=dtype, layout=layout)
        expected_dim = t.dim()
        expected_batch_size = t.size(0)
        expected_seqlen = t.size(1) if layout == torch.jagged else None
        self._validate_nt(
            nt,
            device,
            dtype,
            layout,
            requires_grad=requires_grad,
            dim=dim,
            batch_size=expected_batch_size,
            contiguous=True,
            cached_min_seqlen=expected_seqlen,
            cached_max_seqlen=expected_seqlen,
        )

        if torch.device(device) == t.device and dtype == t.dtype and contiguous:
            # should be the non-copying (view) case
            self.assertTrue(nt._is_view() and nt._base is t)

        # should have equivalent components to construction from unbound tensor list
        nt_from_unbind = torch.nested.as_nested_tensor(
            list(t.unbind(0)), device=device, dtype=dtype, layout=layout
        )
        self.assertEqualIgnoringNestedInts(nt, nt_from_unbind)

        # ensure call on a NT with the same properties returns the NT directly
        nt2 = torch.nested.as_nested_tensor(
            nt, device=device, dtype=dtype, layout=layout
        )
        self.assertTrue(nt is nt2)

        # ensure call with device=None uses input tensor device
        nt3 = torch.nested.as_nested_tensor(
            t.to(device=device, dtype=dtype),
            device=None,
            dtype=None,
            layout=layout,
        )
        self._validate_nt(
            nt3,
            device,
            dtype,
            layout,
            requires_grad=requires_grad,
            dim=dim,
            batch_size=expected_batch_size,
            contiguous=True,
            cached_min_seqlen=expected_seqlen,
            cached_max_seqlen=expected_seqlen,
        )

        # we don't support conversion between layouts this way atm
        other_layout = torch.strided if layout == torch.jagged else torch.jagged
        with self.assertRaisesRegex(
            RuntimeError, "Converting between nested tensor layouts is not supported"
        ):
            torch.nested.as_nested_tensor(
                nt, device=device, dtype=dtype, layout=other_layout
            )

        if requires_grad:
            # make sure gradients flow back into inputs
            (nt * 2).backward(torch.ones_like(nt))
            self.assertEqual(orig_t.grad, torch.ones_like(orig_t) * 2)

    @dtypes(torch.double, torch.half)
    @onlyCUDA
    def test_device_dtype_transfer_updates_offsets(self, device, dtype):
        for tensor_list in self._get_example_tensor_lists():
            orig_device = torch.device("cpu")
            orig_dtype = torch.float32
            nt = torch.nested.nested_tensor(
                tensor_list, layout=torch.jagged, device=orig_device, dtype=orig_dtype
            )

            self.assertEqual(torch.int64, nt.offsets().dtype)
            nt = nt.to(device=device).to(dtype=dtype)

            # offsets should still be int64 on the new device
            self.assertEqual(nt.values().device, nt.offsets().device)
            self.assertEqual(torch.int64, nt.offsets().dtype)

    def test_unbind(self, device):
        for tensor_list in self._get_example_tensor_lists():
            nt = torch.nested.nested_tensor(
                tensor_list, layout=torch.jagged, device=device
            )  # ragged_idx = 1
            out = nt.unbind()
            self.assertEqual(len(out), len(tensor_list))
            for i, t in enumerate(out):
                self.assertEqual(t, tensor_list[i])

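    # unbind() should recover the original components even when the ragged dim has been
    # moved via transpose; each returned component is transposed back before comparison.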
    @parametrize("ragged_idx", [2, 3])
    def test_unbind_transpose(self, device, ragged_idx):
        for tensor_list in self._get_example_tensor_lists():
            nt = torch.nested.nested_tensor(
                tensor_list, layout=torch.jagged, device=device
            )
            if ragged_idx < nt.dim():
                nt = nt.transpose(1, ragged_idx)  # set ragged_idx
                out = nt.unbind()
                self.assertEqual(len(out), len(tensor_list))
                for i, t in enumerate(out):
                    self.assertEqual(
                        t.transpose(0, ragged_idx - 1), tensor_list[i]
                    )  # transpose back each element of result

    def test_unbind_transpose_ragged_idx_last_dim(self, device):
        for tensor_list in self._get_example_tensor_lists():
            nt = torch.nested.nested_tensor(
                tensor_list, layout=torch.jagged, device=device
            ).transpose(1, -1)  # set ragged_idx = last dimension
            out = nt.unbind()
            self.assertEqual(len(out), len(tensor_list))
            for i, t in enumerate(out):
                self.assertEqual(
                    t.transpose(0, -1), tensor_list[i]
                )  # transpose back each element of result

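    # The cases below exercise unbind() on NJTs built from explicit offsets + lengths, where
    # component i is the slice values[offsets[i] : offsets[i] + lengths[i]] along the ragged dim.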
    def test_unbind_lengths(self, device):
        values = torch.randn(16, 128, device=device)
        offsets = torch.tensor([0, 8, 12, 13, 16], device=device)
        lengths = torch.tensor([6, 2, 1, 2], device=device)
        nt = torch.nested.nested_tensor_from_jagged(
            values, offsets=offsets, lengths=lengths
        )  # 3D nested tensor

        tensor_list = []
        for i in range(offsets.shape[0] - 1):
            tensor_list.append(values[offsets[i] : (offsets[i] + lengths[i])])

        out = nt.unbind()
        self.assertEqual(len(out), len(tensor_list))
        for i, t in enumerate(out):
            self.assertEqual(t, tensor_list[i])

    def test_unbind_lengths_ragged_idx_1(self, device):
        values = torch.randn(16, 8, 128, device=device)
        offsets = torch.tensor([0, 8, 12, 13, 16], device=device)
        lengths = torch.tensor([6, 2, 1, 2], device=device)
        ragged_idx = 1
        nt = torch.nested._internal.nested_tensor.NestedTensor(
            values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx
        )  # 4D nested tensor

        tensor_list = []
        for i in range(offsets.shape[0] - 1):
            tensor_list.append(values[offsets[i] : (offsets[i] + lengths[i]), :, :])

        out = nt.unbind()

        self.assertEqual(len(out), len(tensor_list))
        for i, t in enumerate(out):
            self.assertEqual(t, tensor_list[i])

    def test_unbind_lengths_ragged_idx_equals_2_bad_dim(self, device):
        values = torch.randn(16, 8, 128, device=device)
        offsets = torch.tensor([0, 8, 12, 13, 16], device=device)
        lengths = torch.tensor([6, 2, 1, 2], device=device)
        ragged_idx = 2
        nt = torch.nested._internal.nested_tensor.NestedTensor(
            values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx
        )  # 4D nested tensor

        self.assertRaisesRegex(
            RuntimeError,
            r"unbind\(\): nested tensor offsets and lengths.*",
            lambda: nt.unbind(),
        )

    def test_unbind_lengths_ragged_idx_2(self, device):
        values = torch.randn(16, 8, 128, device=device)
        offsets = torch.tensor([0, 2, 4, 8], device=device)
        lengths = torch.tensor([2, 1, 3], device=device)
        ragged_idx = 2
        nt = torch.nested._internal.nested_tensor.NestedTensor(
            values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx
        )  # 4D nested tensor

        tensor_list = []
        for i in range(offsets.shape[0] - 1):
            tensor_list.append(values[:, offsets[i] : (offsets[i] + lengths[i]), :])

        out = nt.unbind()

        self.assertEqual(len(out), len(tensor_list))
        for i, t in enumerate(out):
            self.assertEqual(t, tensor_list[i])

    def test_unbind_lengths_ragged_idx_3(self, device):
        values = torch.randn(16, 8, 128, device=device)
        offsets = torch.tensor([0, 100, 128], device=device)
        lengths = torch.tensor([50, 28], device=device)
        ragged_idx = 3
        nt = torch.nested._internal.nested_tensor.NestedTensor(
            values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx
        )  # 4D nested tensor

        tensor_list = []
        for i in range(offsets.shape[0] - 1):
            tensor_list.append(values[:, :, offsets[i] : (offsets[i] + lengths[i])])

        out = nt.unbind()

        self.assertEqual(len(out), len(tensor_list))
        for i, t in enumerate(out):
            self.assertEqual(t, tensor_list[i])

    @skipIfTorchDynamo(
        "TorchDynamo raises an error for ragged_idx == 0 earlier than Torch"
    )
    def test_unbind_lengths_ragged_idx_0(self, device):
        values = torch.randn(16, 8, 128, device=device)
        offsets = torch.tensor([0, 100, 128], device=device)
        lengths = torch.tensor([50, 28], device=device)
        ragged_idx = 0
        nt = torch.nested._internal.nested_tensor.NestedTensor(
            values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx
        )  # 4D nested tensor

        tensor_list = []
        for i in range(offsets.shape[0] - 1):
            tensor_list.append(values[:, :, offsets[i] : (offsets[i] + lengths[i])])

        self.assertRaisesRegex(
            RuntimeError,
            r"unbind\(\): nested tensor.*out of bounds",
            lambda: nt.unbind(),
        )

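    # narrow() with layout=torch.jagged should return an NJT that is a view over the dense
    # input buffer, described by the given (starts, lengths).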
    def test_narrow(self, device):
        starts = torch.tensor([0, 1, 2, 3, 4], device=device, dtype=torch.int64)
        lengths = torch.tensor([3, 2, 2, 1, 5], device=device, dtype=torch.int64)
        buffer = (
            torch.arange(0, 10, device=device, dtype=torch.int64)
            .unsqueeze(0)
            .expand(5, -1)
            .clone()
            .detach()
        )
        nt = torch.nested.narrow(buffer, 1, starts, lengths, layout=torch.jagged)

        self.assertTrue(nt._is_view() and nt._base is buffer)

        # TODO: Use this approach when unbind is functional
        # unbinded_nt = nt.unbind()
        # for i in range(starts.shape[0]):
        #     self.assertEqual(torch.arange(starts[i], starts[i] + lengths[i], device=device, dtype=torch.int64), unbinded_nt[i])
        for i in range(starts.shape[0]):
            self.assertEqual(
                torch.arange(
                    starts[i], starts[i] + lengths[i], device=device, dtype=torch.int64
                ),
                nt.values()[nt.offsets()[i] : (nt.offsets()[i] + nt.lengths()[i])],
            )

    def test_njt_cat(self, device):
        offsets = torch.tensor([0, 2, 3], device=device, dtype=torch.int64)
        values_1 = torch.randn(
            3, 2, dtype=torch.float64, device=device, requires_grad=True
        )
        values_2 = torch.randn(
            3, 4, dtype=torch.float64, device=device, requires_grad=True
        )

        def grad_test_func(values_1, values_2, offsets):
            nt_1 = torch.nested.nested_tensor_from_jagged(values_1, offsets)
            nt_2 = torch.nested.nested_tensor_from_jagged(values_2, offsets)
            nt_3 = torch.cat([nt_1, nt_2], dim=-1)
            return nt_3.values()

        assert gradcheck(
            grad_test_func,
            inputs=(values_1, values_2, offsets),
            check_batched_grad=False,
        )

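    # Contiguity checks: an NJT built from a list of components is contiguous, while a
    # narrow()'d NJT is only contiguous when its (starts, lengths) windows line up
    # back-to-back in the base buffer (the nt_contiguous_narrow case below).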
    def test_is_contiguous(self, device):
        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
        nt_contiguous = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)

        starts_nc = torch.tensor([0, 1, 2, 3, 4], device=device, dtype=torch.int64)
        lengths_nc = torch.tensor([3, 2, 2, 1, 5], device=device, dtype=torch.int64)
        narrow_base = (
            torch.arange(0, 10, device=device, dtype=torch.int64)
            .unsqueeze(0)
            .expand(5, -1)
            .clone()
        )
        nt_noncontiguous = torch.nested.narrow(
            narrow_base, 1, starts_nc, lengths_nc, layout=torch.jagged
        )

        starts_c = torch.tensor([1, 0, 0, 0, 0], device=device, dtype=torch.int64)
        lengths_c = torch.tensor([9, 10, 10, 10, 8], device=device, dtype=torch.int64)
        nt_contiguous_narrow = torch.nested.narrow(
            narrow_base, 1, starts_c, lengths_c, layout=torch.jagged
        )

        # Test contiguous case
        assert nt_contiguous.is_contiguous()

        # Test narrow case
        assert not nt_noncontiguous.is_contiguous()
        assert nt_contiguous_narrow.is_contiguous()

        # Test querying by memory_format
        self.assertTrue(
            nt_contiguous.is_contiguous(memory_format=torch.contiguous_format)
        )
        self.assertTrue(
            not nt_noncontiguous.is_contiguous(memory_format=torch.contiguous_format)
        )
        self.assertTrue(
            nt_contiguous_narrow.is_contiguous(memory_format=torch.contiguous_format)
        )

    def test_layout_under_torch_dispatch_mode(self):
        from torch.testing._internal.logging_tensor import (
            capture_logs_with_logging_tensor_mode,
        )

        nt = random_nt_from_dims(
            [2, None, 3], torch.device("cpu"), torch.float32, layout=torch.jagged
        )

        with capture_logs_with_logging_tensor_mode():
            self.assertEqual(nt.layout, torch.jagged)

    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    @parametrize(
        "func", [torch.empty_like, torch.randn_like], name_fn=lambda f: f.__name__
    )
    def test_like_shape(self, func):
        nt = random_nt_from_dims(
            [2, None, 3], torch.device("cpu"), torch.float32, layout=torch.jagged
        )
        nt_like = func(nt)

        for nt_ub in nt_like.unbind():
            t_like = func(nt_ub)
            self.assertEqual(nt_ub.shape, t_like.shape)

    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    @parametrize(
        "func", [torch.ones_like, torch.zeros_like], name_fn=lambda f: f.__name__
    )
    def test_like_value(self, func):
        nt = random_nt_from_dims(
            [2, None, 3], torch.device("cpu"), torch.float32, layout=torch.jagged
        )
        nt_like = func(nt)

        for nt_ub in nt_like.unbind():
            t_like = func(nt_ub)
            self.assertEqual(nt_ub, t_like)

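    # clone() on a transposed (non-contiguous) NJT preserves the non-contiguous layout by
    # default; clone(memory_format=torch.contiguous_format) should produce a contiguous copy
    # with matching values / offsets / ragged_idx.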
    def test_noncontiguous_pointwise(self, device):
        a = torch.randn(2, 3, 4, requires_grad=True, dtype=torch.float64, device=device)
        b = torch.randn(3, 3, 4, requires_grad=True, dtype=torch.float64, device=device)
        c = torch.randn(4, 3, 4, requires_grad=True, dtype=torch.float64, device=device)
        nt = torch.nested.nested_tensor([a, b, c], layout=torch.jagged)
        # transpose ragged dim
        transposed = nt.transpose(1, 2)
        self.assertFalse(transposed.is_contiguous())
        clone = transposed.clone()

        def check_nt_equality(x, y):
            self.assertEqual(x.values(), y.values())
            self.assertEqual(x.offsets(), y.offsets())
            self.assertEqual(x._ragged_idx, y._ragged_idx)
            self.assertEqual(x.shape, y.shape)

        self.assertFalse(clone.is_contiguous())
        check_nt_equality(clone, transposed)

        clone_contig = transposed.clone(memory_format=torch.contiguous_format)
        self.assertTrue(clone_contig.is_contiguous())
        check_nt_equality(clone_contig, transposed)

        detached = transposed.detach()
        self.assertFalse(clone.is_contiguous())
        check_nt_equality(detached, transposed)

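    # permute() on a jagged NT may move the ragged dim (updating _ragged_idx accordingly) but
    # is not allowed to move the batch dim.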
    def test_permute(self, device):
        nt = random_nt_from_dims(
            [2, None, 3, 5], device, torch.float32, layout=torch.jagged
        )
        nt_shape = nt.shape
        nt_inner_shape = nt.values().shape
        with self.assertRaisesRegex(
            ValueError,
            r"permute\(\): number of dimensions in the tensor input \(4\) "
            + r"does not match the length of the desired ordering of dimensions \(3\).",
        ):
            nt.permute(0, 2, 1)
        with self.assertRaisesRegex(
            ValueError, r"permute\(\): duplicate dims are not allowed."
        ):
            nt.permute(0, 2, -2, 3)
        with self.assertRaisesRegex(
            ValueError, "Permute is not supported on the batch dimension for jagged NT"
        ):
            nt.permute(1, 0, 2, 3)
        nt_permute = nt.permute(0, 2, 1, -1)
        self.assertEqual(
            nt_permute.shape, (nt_shape[0], nt_shape[2], nt_shape[1], nt_shape[3])
        )
        self.assertEqual(
            nt_permute.values().shape,
            (nt_inner_shape[1], nt_inner_shape[0], nt_inner_shape[2]),
        )
        self.assertEqual(nt_permute._ragged_idx, 2)
        self.assertEqual(nt_permute.permute(0, 2, 1, 3), nt)

    def test_to_dtype(self, device):
        nt = random_nt_from_dims(
            [2, None, 3], device, torch.float32, layout=torch.jagged
        )
        nt_after = nt.to(torch.float64)
        self.assertEqual(torch.float32, nt.dtype)
        self.assertEqual(torch.float64, nt_after.dtype)
        self.assertEqual(torch.float64, nt_after.values().dtype)
        self.assertEqual(torch.int64, nt_after.offsets().dtype)

        noncontiguous_nt = nt.transpose(1, 2)
        noncontiguous_nt_after = noncontiguous_nt.to(torch.bfloat16)
        self.assertEqual(torch.bfloat16, noncontiguous_nt_after.dtype)
        self.assertEqual(torch.bfloat16, noncontiguous_nt_after.values().dtype)
        self.assertEqual(torch.int64, noncontiguous_nt_after.offsets().dtype)

    def test_to_copy(self, device):
        nt = torch.nested.nested_tensor(
            [
                torch.randn(
                    i + 2, 3, 4, requires_grad=True, dtype=torch.float64, device=device
                )
                for i in range(3)
            ],
            layout=torch.jagged,
        )

        nt_copy_dtype = torch.ops.aten._to_copy(nt, dtype=torch.float16)
        self.assertEqual(torch.float16, nt_copy_dtype.dtype)

        nt_t = nt.transpose(1, 2)
        nt_t_copy_dtype = torch.ops.aten._to_copy(nt_t, dtype=torch.float16)
        self.assertEqual(torch.float16, nt_t_copy_dtype.dtype)

    def test_copy_(self, device):
        offsets = torch.tensor([0, 2, 4], device=device)
        a = torch.nested.nested_tensor_from_jagged(
            torch.zeros(4, 3, device=device), offsets
        )
        b = torch.nested.nested_tensor_from_jagged(
            torch.ones(4, 3, device=device), offsets
        )
        a.copy_(b)
        torch._dynamo.disable(self.assertEqual)(a, b)

        offsets_2 = torch.tensor([0, 2, 4], device=device)
        c = torch.nested.nested_tensor_from_jagged(
            torch.ones(4, 3, device=device), offsets_2
        )
        # fail when tensors have the same size but not the exact same offset tensor.
        with self.assertRaisesRegex(
            RuntimeError,
            "copy_ only supports Nested Tensors that have same size and the exact same offset tensor.",
        ):
            a.copy_(c)

        # fail when tensors have different sizes
        a = a.transpose(1, 2)
        with self.assertRaisesRegex(
            RuntimeError,
            "copy_ only supports Nested Tensors that have same size and the exact same offset tensor.",
        ):
            a.copy_(b)

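    # The profiler should assign matching sequence numbers to the forward linear op and its
    # corresponding backward op so the two can be correlated.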
    @skipIfTorchDynamo("Dynamo doesn't know how to trace prof.events()")
    def test_profiler_sequence_nr(self):
        with torch.profiler.profile() as prof:
            values = torch.randn(4, 6, requires_grad=True)
            offsets = torch.tensor([0, 2, 4])
            values = values * 2
            l = torch.nn.Linear(6, 8)
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)

            nt = l(nt)
            val = nt.values()

            loss = val.sum()
            loss.backward()

        fwd_seq_nrs = []
        for evt in prof.events():
            if (
                "linear" in evt.name.lower()
                and "backward" not in evt.name.lower()
                and evt.sequence_nr != -1
            ):
                fwd_seq_nrs.append(evt.sequence_nr)

        bwd_seq_nrs = []
        for evt in prof.events():
            if (
                "linear" in evt.name.lower()
                and "backward" in evt.name.lower()
                and "evaluate_function" not in evt.name.lower()
                and evt.sequence_nr != -1
            ):
                bwd_seq_nrs.append(evt.sequence_nr)

        # There should only be one such event with a sequence number:
        # the PythonTLSSnapshot event - but, note that it's not terrible if
        # we end up with multiple events with the same sequence number - so we
        # could relax this check if it becomes inconvenient to maintain this
        # property.
        self.assertEqual(len(fwd_seq_nrs), 1)
        self.assertEqual(len(bwd_seq_nrs), 1)
        self.assertEqual(fwd_seq_nrs[0], bwd_seq_nrs[0])

6092*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(fwd_seq_nrs), 1) 6093*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(bwd_seq_nrs), 1) 6094*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fwd_seq_nrs[0], bwd_seq_nrs[0]) 6095*da0073e9SAndroid Build Coastguard Worker 6096*da0073e9SAndroid Build Coastguard Worker def test_is_same_size(self, device): 6097*da0073e9SAndroid Build Coastguard Worker def get_3_tensors(): 6098*da0073e9SAndroid Build Coastguard Worker return [ 6099*da0073e9SAndroid Build Coastguard Worker torch.randn( 6100*da0073e9SAndroid Build Coastguard Worker i + 2, 3, 4, requires_grad=True, dtype=torch.float64, device=device 6101*da0073e9SAndroid Build Coastguard Worker ) 6102*da0073e9SAndroid Build Coastguard Worker for i in range(3) 6103*da0073e9SAndroid Build Coastguard Worker ] 6104*da0073e9SAndroid Build Coastguard Worker 6105*da0073e9SAndroid Build Coastguard Worker nt1, offsets1 = jagged_from_list(get_3_tensors(), None) 6106*da0073e9SAndroid Build Coastguard Worker nt2, offsets1 = jagged_from_list(get_3_tensors(), offsets1) 6107*da0073e9SAndroid Build Coastguard Worker 6108*da0073e9SAndroid Build Coastguard Worker nt3, offsets2 = jagged_from_list(get_3_tensors(), None) 6109*da0073e9SAndroid Build Coastguard Worker nt4, offsets2 = jagged_from_list(get_3_tensors(), offsets2) 6110*da0073e9SAndroid Build Coastguard Worker 6111*da0073e9SAndroid Build Coastguard Worker def check_size(nt1, nt2, nt3, nt4): 6112*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.ops.aten.is_same_size(nt1, nt2)) 6113*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.ops.aten.is_same_size(nt3, nt4)) 6114*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.ops.aten.is_same_size(nt1, nt3)) 6115*da0073e9SAndroid Build Coastguard Worker 6116*da0073e9SAndroid Build Coastguard Worker check_size(nt1, nt2, nt3, nt4) 6117*da0073e9SAndroid Build Coastguard Worker 6118*da0073e9SAndroid Build Coastguard Worker nt1_t, nt2_t, nt3_t, nt4_t = (x.transpose(1, 2) for x in (nt1, nt2, nt3, nt4)) 6119*da0073e9SAndroid Build Coastguard Worker check_size(nt1_t, nt2_t, nt3_t, nt4_t) 6120*da0073e9SAndroid Build Coastguard Worker 6121*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("compiles internally") 6122*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") 6123*da0073e9SAndroid Build Coastguard Worker @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") 6124*da0073e9SAndroid Build Coastguard Worker def test_specialize_dynamic_shape(self, device): 6125*da0073e9SAndroid Build Coastguard Worker values = torch.randn((18, 16), device=device) 6126*da0073e9SAndroid Build Coastguard Worker offsets = torch.tensor([0, 2, 3, 6, 15, 18], device=device) 6127*da0073e9SAndroid Build Coastguard Worker like_values = torch.randn_like(values) 6128*da0073e9SAndroid Build Coastguard Worker 6129*da0073e9SAndroid Build Coastguard Worker # this marks values as dynamic 6130*da0073e9SAndroid Build Coastguard Worker nt = torch.nested.nested_tensor_from_jagged(values, offsets) 6131*da0073e9SAndroid Build Coastguard Worker 6132*da0073e9SAndroid Build Coastguard Worker def fn(values, same_size): 6133*da0073e9SAndroid Build Coastguard Worker # here, the dynamic shape is specialized by same_size's shape 6134*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/127097 6135*da0073e9SAndroid Build Coastguard Worker # make sure this doesn't error out in 
            return values + same_size

        self.assertEqual(
            fn(values, like_values),
            torch.compile(fn)(values, like_values),
        )

    @skipIfTorchDynamo("compiles internally")
    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    def test_specialize_dynamic_shape_recompile(self, device):
        def generate_inp(total_len):
            values = torch.randn((total_len, 16), device=device)
            offsets = torch.tensor([0, 2, 3, 6, 15, total_len], device=device)
            like_values = torch.randn_like(values)
            return values, offsets, like_values

        def check_results(ref_fn, res_fn, args):
            values, offsets, like_values = args
            # this may add dynamic shape markings
            # goal of this test is to make sure that whatever markings are there,
            # we eventually stop recompiling as shape changes.
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)

            self.assertEqual(ref_fn(values, like_values), res_fn(values, like_values))

        def fn(values, same_size):
            return values + same_size

        compile_counter = torch._dynamo.testing.CompileCounter()

        compiled_fn = torch._dynamo.optimize(compile_counter, nopython=True)(fn)
        check_results(fn, compiled_fn, generate_inp(18))
        self.assertEqual(compile_counter.frame_count, 1)

        check_results(fn, compiled_fn, generate_inp(19))
        # we'll probably recompile here with dynamic shapes - it's okay if not though.
        frame_count_2 = compile_counter.frame_count
        self.assertIn(frame_count_2, [1, 2])

        # make sure that by now we've already compiled with dynamic shapes, so additional
        # shapes should not trigger additional recompiles.
        check_results(fn, compiled_fn, generate_inp(20))
        self.assertEqual(compile_counter.frame_count, frame_count_2)

    # Note 1: Math fallback doesn't work with bfloat16 on CUDA
    # Note 2: ROCm doesn't support flash attention or mem_efficient attention for NT
    @unittest.skipIf(
        TEST_WITH_ROCM,
        "ROCm doesn't support flash attention or mem_efficient attention for NT",
    )
    @dtypes(
        *(
            [torch.float16, torch.bfloat16, torch.float32]
            if SM80OrLater
            else [torch.float16, torch.float32]
        )
    )
    def test_sdpa(self, device, dtype):
        batch_size = 1
        emb_dims = 128
        n_heads = 8
        head_dims = emb_dims // n_heads

        sen1 = torch.randn(11, emb_dims, dtype=dtype, device=device)
        sen2 = torch.randn(13, emb_dims, dtype=dtype, device=device)

        query = torch.nn.Linear(
            emb_dims, emb_dims, bias=False, device=device, dtype=dtype
        )
        key = torch.nn.Linear(
            emb_dims, emb_dims, bias=False, device=device, dtype=dtype
        )
        value = torch.nn.Linear(
            emb_dims, emb_dims, bias=False, device=device, dtype=dtype
        )

        # Simplest case: 1 sentence, no batching
        x_d1 = sen1.unsqueeze(0)
        x_nt = torch.nested.as_nested_tensor([sen1], layout=torch.jagged)

        # See note below for why we detach here.
        q_d1 = (
            query(x_d1)
            .view(batch_size, -1, n_heads, head_dims)
            .detach()
            .requires_grad_(True)
        )
        q_d1_t = q_d1.transpose(1, 2)
        k_d1 = (
            key(x_d1)
            .view(batch_size, -1, n_heads, head_dims)
            .detach()
            .requires_grad_(True)
        )
        k_d1_t = k_d1.transpose(1, 2)
        v_d1 = (
            value(x_d1)
            .view(batch_size, -1, n_heads, head_dims)
            .detach()
            .requires_grad_(True)
        )
        v_d1_t = v_d1.transpose(1, 2)

        q_nt = (
            query(x_nt)
            .view(*x_nt.size()[0:2], n_heads, head_dims)
            .detach()
            .requires_grad_(True)
        )
        q_nt_t = q_nt.transpose(1, 2)
        k_nt = (
            key(x_nt)
            .view(*x_nt.size()[0:2], n_heads, head_dims)
            .detach()
            .requires_grad_(True)
        )
        k_nt_t = k_nt.transpose(1, 2)
        v_nt = (
            value(x_nt)
            .view(*x_nt.size()[0:2], n_heads, head_dims)
            .detach()
            .requires_grad_(True)
        )
        v_nt_t = v_nt.transpose(1, 2)

        # High Precision Math Reference
        q_d1_f32 = q_d1.to(torch.float32)
        k_d1_f32 = k_d1.to(torch.float32)
        v_d1_f32 = v_d1.to(torch.float32)
        q_d1_f32_t = q_d1_f32.transpose(1, 2)
        k_d1_f32_t = k_d1_f32.transpose(1, 2)
        v_d1_f32_t = v_d1_f32.transpose(1, 2)
        out_ref = torch.ops.aten._scaled_dot_product_attention_math(
            q_d1_f32_t, k_d1_f32_t, v_d1_f32_t
        )[0]
        grads_ref = torch.autograd.grad(out_ref.sum(), (q_d1_f32, k_d1_f32, v_d1_f32))

        # Low Precision Math Reference
        out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
            q_d1_t, k_d1_t, v_d1_t
        )[0]
        grads_lp_ref = torch.autograd.grad(out_lp_ref.sum(), (q_d1, k_d1, v_d1))

        # Compute tolerances
        output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref)
        grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(grads_ref[0], grads_lp_ref[0])
        grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(grads_ref[1], grads_lp_ref[1])
        grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(grads_ref[2], grads_lp_ref[2])
        grad_atols = [grad_q_ref_atol, grad_k_ref_atol, grad_v_ref_atol]
        grad_rtols = [grad_q_ref_rtol, grad_k_ref_rtol, grad_v_ref_rtol]

        attn_d1 = torch.nn.functional.scaled_dot_product_attention(
            q_d1_t, k_d1_t, v_d1_t
        ).transpose(1, 2)
        attn_nt = torch.nn.functional.scaled_dot_product_attention(
            q_nt_t, k_nt_t, v_nt_t
        ).transpose(1, 2)

        self.assertEqual(
            attn_d1,
            attn_nt.unbind()[0].unsqueeze(0),
            atol=output_ref_atol,
            rtol=output_ref_rtol,
        )

        # Simple case: 2 sentences, no extra params
        x_d2 = sen2.unsqueeze(0)
        x_nt = torch.nested.as_nested_tensor([sen1, sen2], layout=torch.jagged)

        # NB: we make sure the leaf tensor we compute gradients for is the view-ed tensor before
        # it is transposed. This is because today we cannot backward through view or unbind a
        # transposed tensor.
        q_d2 = (
            query(x_d2)
            .view(batch_size, -1, n_heads, head_dims)
            .detach()
            .requires_grad_(True)
        )
        q_d2_t = q_d2.transpose(1, 2)
        k_d2 = (
            key(x_d2)
            .view(batch_size, -1, n_heads, head_dims)
            .detach()
            .requires_grad_(True)
        )
        k_d2_t = k_d2.transpose(1, 2)
        v_d2 = (
            value(x_d2)
            .view(batch_size, -1, n_heads, head_dims)
            .detach()
            .requires_grad_(True)
        )
        v_d2_t = v_d2.transpose(1, 2)

        q_nt = (
            query(x_nt)
            .view(*x_nt.size()[0:2], n_heads, head_dims)
            .detach()
            .requires_grad_(True)
        )
        q_nt_t = q_nt.transpose(1, 2)
        k_nt = (
            key(x_nt)
            .view(*x_nt.size()[0:2], n_heads, head_dims)
            .detach()
            .requires_grad_(True)
        )
        k_nt_t = k_nt.transpose(1, 2)
        v_nt = (
            value(x_nt)
            .view(*x_nt.size()[0:2], n_heads, head_dims)
            .detach()
            .requires_grad_(True)
        )
        v_nt_t = v_nt.transpose(1, 2)

        attn_d2 = torch.nn.functional.scaled_dot_product_attention(
            q_d2_t, k_d2_t, v_d2_t
        ).transpose(1, 2)
        d1_grads = torch.autograd.grad(attn_d1.sum(), (q_d1, k_d1, v_d1))
        d2_grads = torch.autograd.grad(attn_d2.sum(), (q_d2, k_d2, v_d2))

        # Simple case 3: batch_size = 1, seq_len = 1
        q_3 = torch.randn(1, 8, 16, dtype=dtype, device=device)
        q_nt_3 = torch.nested.as_nested_tensor([q_3], layout=torch.jagged)
        q_nt_3 = q_nt_3.transpose(1, 2)
        attn_out = torch.nn.functional.scaled_dot_product_attention(
            q_nt_3, q_nt_3, q_nt_3
        )
        self.assertEqual(attn_out.shape, q_nt_3.shape)

        def check_forward_backward():
            attn_nt = torch.nn.functional.scaled_dot_product_attention(
                q_nt_t, k_nt_t, v_nt_t
            ).transpose(1, 2)

            attn_nts = attn_nt.unbind()
            self.assertEqual(
                attn_d1,
                attn_nts[0].unsqueeze(0),
                atol=output_ref_atol,
                rtol=output_ref_rtol,
            )
            self.assertEqual(
                attn_d2,
                attn_nts[1].unsqueeze(0),
                atol=output_ref_atol,
                rtol=output_ref_rtol,
            )

            nt_grads = torch.autograd.grad(attn_nt.values().sum(), (q_nt, k_nt, v_nt))
            for nt_grad, d1_grad, d2_grad, grad_atol, grad_rtol in zip(
                nt_grads, d1_grads, d2_grads, grad_atols, grad_rtols
            ):
                unbound_nt_grads = nt_grad.unbind()
                self.assertEqual(
                    d1_grad,
                    unbound_nt_grads[0].unsqueeze(0),
                    atol=grad_atol,
                    rtol=grad_rtol,
                )
                self.assertEqual(
                    d2_grad,
                    unbound_nt_grads[1].unsqueeze(0),
                    atol=grad_atol,
                    rtol=grad_rtol,
                )

        # Default
        check_forward_backward()

        # Test that the dispatcher works by calling only mem-efficient and math (as they are safe for all devices)
        with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_mem_efficient=True, enable_math=True
        ):
            check_forward_backward()

        # Test math fallback
        with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_mem_efficient=False, enable_math=True
        ):
            # Math fallback doesn't work with bfloat16 on CUDA because
            # "group_gemm_dispatch" not implemented for 'BFloat16'
            if not (str(device).startswith("cuda") and dtype == torch.bfloat16):
                check_forward_backward()

    @skipIfTorchDynamo("SDPA test compiles internally")
    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    # Guarding with sqrt() doesn't work on ROCm?
    @skipCUDAIfRocm
    @onlyCUDA
    @dtypes(
        *(
            [torch.float16, torch.bfloat16, torch.float32]
            if SM80OrLater
            else [torch.float16, torch.float32]
        )
    )
    def test_sdpa_compile(self, device, dtype):
        batch_size = 1
        emb_dims = 1024
        n_heads = 8
        head_dims = emb_dims // n_heads

        sen1 = torch.randn(11, emb_dims, dtype=dtype, device=device)
        sen2 = torch.randn(13, emb_dims, dtype=dtype, device=device)

        query = torch.nn.Linear(
            emb_dims, emb_dims, bias=False, device=device, dtype=dtype
        )
        key = torch.nn.Linear(
            emb_dims, emb_dims, bias=False, device=device, dtype=dtype
        )
        value = torch.nn.Linear(
            emb_dims, emb_dims, bias=False, device=device, dtype=dtype
        )

        # Simplest case: 1 sentence, no batching
        x_d1 = sen1.unsqueeze(0)
        x_d2 = sen2.unsqueeze(0)
        x_nt = torch.nested.as_nested_tensor([sen1, sen2], layout=torch.jagged)
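
        # Project the dense and jagged inputs into per-head (B, n_heads, S, head_dim)
        # form; the dense paths use a fixed sequence length while the NT path keeps
        # the ragged sequence dim from x_nt.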
        q_d1 = query(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
        k_d1 = key(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
        v_d1 = value(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
        q_d2 = query(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
        k_d2 = key(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
        v_d2 = value(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)

        q_nt = (
            query(x_nt)
            .view(*x_nt.size()[0:2], n_heads, head_dims)
            .detach()
            .transpose(1, 2)
        )
        k_nt = (
            key(x_nt)
            .view(*x_nt.size()[0:2], n_heads, head_dims)
            .detach()
            .transpose(1, 2)
        )
        v_nt = (
            value(x_nt)
            .view(*x_nt.size()[0:2], n_heads, head_dims)
            .detach()
            .transpose(1, 2)
        )

        # High Precision Math Reference
        q_d1_f32 = q_d1.to(torch.float32)
        k_d1_f32 = k_d1.to(torch.float32)
        v_d1_f32 = v_d1.to(torch.float32)
        out_ref = torch.ops.aten._scaled_dot_product_attention_math(
            q_d1_f32, k_d1_f32, v_d1_f32
        )[0]
        # Low Precision Math Reference
        out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
            q_d1, k_d1, v_d1
        )[0]
        output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref)

        attn_d1 = torch.nn.functional.scaled_dot_product_attention(
            q_d1, k_d1, v_d1
        ).transpose(1, 2)
        attn_d2 = torch.nn.functional.scaled_dot_product_attention(
            q_d2, k_d2, v_d2
        ).transpose(1, 2)
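
        # Compile SDPA and run it over the jagged inputs; each sequence of the result is
        # compared below against the uncompiled dense references using the tolerances above.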
        compiled_sdpa = torch.compile(torch.nn.functional.scaled_dot_product_attention)
        attn_nt = compiled_sdpa(q_nt, k_nt, v_nt).transpose(1, 2)

        attn_nts = attn_nt.unbind()
        self.assertEqual(
            attn_d1,
            attn_nts[0].unsqueeze(0),
            atol=output_ref_atol,
            rtol=output_ref_rtol,
        )
        self.assertEqual(
            attn_d2,
            attn_nts[1].unsqueeze(0),
            atol=output_ref_atol,
            rtol=output_ref_rtol,
        )

    @dtypes(torch.float32, torch.double, torch.half)
    def test_sdpa_with_constant_sequence_length(self, device, dtype):
        # shape (B, P*, S, D)
        # B: batch size
        # P*: ragged number of prompts
        # S: (constant) sequence length
        # D: embedding size
        query = random_nt_from_dims(
            [4, None, 8, 10],
            device=device,
            dtype=dtype,
            layout=torch.jagged,
            requires_grad=True,
        )
        key = random_nt_from_similar(query)
        value = random_nt_from_similar(query)
        output = F.scaled_dot_product_attention(query, key, value)
        self.assertTrue(isinstance(output, NestedTensor))
        output.values().sum().backward()

        query_dense = query.clone().detach().requires_grad_(True)
        # should be equivalent to just running the buffers through
        output_dense = F.scaled_dot_product_attention(
            query_dense.values(), key.values(), value.values()
        )
        torch._dynamo.disable(self.assertEqual)(output._values, output_dense)
        output_dense.sum().backward()
        torch._dynamo.disable(self.assertEqual)(query.grad, query_dense.grad)

    @onlyCUDA
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FUSED_ATTENTION,
        "Platform doesn't support flash or mem-efficient attention",
    )
    @dtypes(
        *(
            [torch.float16, torch.bfloat16, torch.float32]
            if SM80OrLater
            else [torch.float16, torch.float32]
        )
    )
    def test_sdpa_with_packed_in_proj(self, device, dtype):
        # shape (B, *, D)
        input_packed = random_nt_from_dims(
            [5, None, 10], device=device, dtype=dtype, layout=torch.jagged
        )

        # Do input projection.
        num_heads = 2
        # should be multiple of 4 for efficient kernels (e.g. flash / mem-efficient)
        head_dim = 8
        qkv_linear = torch.nn.Linear(10, num_heads * head_dim * 3).to(
            device=device, dtype=dtype
        )

        def in_proj(input_packed, qkv_linear=qkv_linear):
            qkv_post_proj = qkv_linear(input_packed)
            # these are non-contiguous to trigger _is_safe_to_get_storage_as_tensor()
            q, k, v = qkv_post_proj.chunk(3, dim=-1)
            q = q.unflatten(-1, [num_heads, head_dim]).transpose(-2, -3)
            k = k.unflatten(-1, [num_heads, head_dim]).transpose(-2, -3)
            v = v.unflatten(-1, [num_heads, head_dim]).transpose(-2, -3)
            return q, k, v

        q, k, v = in_proj(input_packed)
        output = F.scaled_dot_product_attention(q, k, v, attn_mask=None)

        # compare to individually running unbound components through
        for in_component, out_component in zip(
            input_packed.unbind(), output.transpose(-2, -3).unbind()
        ):
            q, k, v = in_proj(in_component)
            out = F.scaled_dot_product_attention(q, k, v).transpose(-2, -3)

            # Low Precision Math Reference
            out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(q, k, v)[
                0
            ].transpose(-2, -3)
            output_ref_atol, output_ref_rtol = get_tolerances(
                out, out_lp_ref, fudge_factor=2
            )

            self.assertEqual(
                out, out_component, atol=output_ref_atol, rtol=output_ref_rtol
            )

    @skipIfTorchDynamo("SDPA test compiles internally")
    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    # mha_varlen_fwd not supported on ROCm
    @skipCUDAIfRocm
    @onlyCUDA
    @dtypes(
        *(
            [torch.float16, torch.bfloat16, torch.float32]
            if SM80OrLater
            else [torch.float16, torch.float32]
        )
    )
    def test_sdpa_backwards(self, device, dtype):
        values = torch.randn(9, 3, 256, requires_grad=True, device=device, dtype=dtype)
        offsets = torch.tensor([0, 1, 3, 5, 9], device=device, dtype=torch.int64)

        @torch.compile
        def f(values, offsets):
            nt = convert_jagged_to_nested_tensor(values, offsets, max_length=4)
            nt = nt.transpose(-2, -3)
            # purposefully graph break to trigger view replay for subclass view input
            torch.tensor(1).item()
            output = F.scaled_dot_product_attention(nt, nt, nt).transpose(-2, -3)
            return convert_nt_to_jagged(output)

        output = f(values, offsets)
        output.sum().backward()
        self.assertEqual(values.grad, torch.ones_like(values))

    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FUSED_ATTENTION,
        "Platform doesn't support flash or mem-efficient attention",
    )
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    @skipCUDAIfRocm
    @onlyCUDA
    @skipIfTorchDynamo()
    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
    def test_sdpa_autocast(self, device):
        def fn_nt(values32, values16, offsets):
            nt32 = convert_jagged_to_nested_tensor(values32, offsets, max_length=16)
            nt16 = convert_jagged_to_nested_tensor(values16, offsets, max_length=16)
            nt32 = nt32.transpose(1, 2)
            nt16 = nt16.transpose(1, 2)
            return F.scaled_dot_product_attention(nt32, nt16, nt32)

        def fn_dense(x32, x16):
            x32 = x32.view(8, 16, 4, 16).transpose(1, 2)
            x16 = x16.view(8, 16, 4, 16).transpose(1, 2)
            return F.scaled_dot_product_attention(x32, x16, x32)

        values32 = torch.randn((8 * 16, 4, 16), device=device, dtype=torch.float32)
        values16 = torch.randn((8 * 16, 4, 16), device=device, dtype=torch.float16)
        offsets = torch.arange(0, 8 * 16 + 1, 16, device=device, dtype=torch.int32)

        x32 = values32.clone()
        x16 = values16.clone()

        with torch.autocast(device_type="cuda", dtype=torch.float16):
            out_dense_eager = fn_dense(x32, x16)
            out_dense_compiled = torch.compile(fn_dense)(x32, x16)
            out_nt_eager = fn_nt(values32, values16, offsets)
            out_nt_compiled = torch.compile(fn_nt)(values32, values16, offsets)

        self.assertEqual(out_dense_eager, out_dense_compiled)
        self.assertEqual(
            out_dense_eager.transpose(1, 2),
            out_nt_eager.values().transpose(0, 1).view(8, 16, 4, 16),
        )
        self.assertEqual(
            out_dense_eager.transpose(1, 2),
            out_nt_compiled.values().transpose(0, 1).view(8, 16, 4, 16),
        )

        def get_values():
            return tuple(
                x.clone().detach().requires_grad_(True) for x in (values32, values16)
            )

        v32_dense_eager, v16_dense_eager = get_values()
        v32_dense_compile, v16_dense_compile = get_values()
        v32_nt_eager, v16_nt_eager = get_values()
        v32_nt_compile, v16_nt_compile = get_values()

        with torch.autocast(device_type="cuda", dtype=torch.float16):
            loss_dense_eager = fn_dense(v32_dense_eager, v16_dense_eager).sum()
            loss_dense_compile = torch.compile(fn_dense)(
                v32_dense_compile, v16_dense_compile
            ).sum()
            loss_nt_eager = fn_nt(v32_nt_eager, v16_nt_eager, offsets).values().sum()
            loss_nt_compile = (
                torch.compile(fn_nt)(v32_nt_compile, v16_nt_compile, offsets)
                .values()
                .sum()
            )

        loss_dense_eager.backward()
        loss_dense_compile.backward()
        loss_nt_eager.backward()
        loss_nt_compile.backward()

        self.assertEqual(v32_dense_eager.grad, v32_dense_compile.grad)
        self.assertEqual(v32_dense_eager.grad, v32_nt_eager.grad)
        self.assertEqual(v32_dense_eager.grad, v32_nt_compile.grad)

        self.assertEqual(v16_dense_eager.grad, v16_dense_compile.grad)
        self.assertEqual(v16_dense_eager.grad, v16_nt_eager.grad)
        self.assertEqual(v16_dense_eager.grad, v16_nt_compile.grad)

    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FUSED_ATTENTION,
        "Platform doesn't support flash or mem-efficient attention",
    )
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    @skipCUDAIfRocm
    @onlyCUDA
    @skipIfTorchDynamo()
    def test_sdpa_flop_counter(self, device):
        from torch.utils.flop_counter import FlopCounterMode

        def get_flops(nt):
            flop_counter = FlopCounterMode(display=False)
            with flop_counter:
                ret = torch.nn.functional.scaled_dot_product_attention(nt, nt, nt)
                ret.values().sum().backward()
            return flop_counter.get_total_flops()

        values = torch.randn(
            (8 * 16, 4, 16), requires_grad=True, device=device, dtype=torch.float16
        )
        offsets = torch.arange(0, 8 * 16 + 1, 16, device=device, dtype=torch.int32)
        nt = convert_jagged_to_nested_tensor(values, offsets, max_length=16)

        values_meta = torch.randn(
            (8 * 16, 4, 16), requires_grad=True, device="meta", dtype=torch.float16
        )
        offsets_meta = torch.arange(0, 8 * 16 + 1, 16, device="meta", dtype=torch.int32)
        nt_meta = convert_jagged_to_nested_tensor(values, offsets, max_length=16)

        self.assertEqual(get_flops(nt), get_flops(nt_meta))

    @skipIfTorchDynamo()
    def test_nested_tensor_activation_checkpoint(self, device):
        values = torch.randn(
            9, 3, 256, requires_grad=True, device=device, dtype=torch.float32
        )
        lengths = torch.tensor([1, 2, 3, 3], device=device, dtype=torch.int64)
        offsets = F.pad(lengths, pad=(1, 0)).cumsum(dim=0)

        def fn(values, offsets):
            nt = convert_jagged_to_nested_tensor(values, offsets, max_length=4)
            return convert_nt_to_jagged(nt).sum()

        checkpoint(fn, values, offsets, use_reentrant=False).backward()
        self.assertIsNotNone(values.grad)

        context_fn = partial(
            create_selective_checkpoint_contexts, [torch.ops.aten.cumsum.default]
        )

        values.grad = None

        def fn(values, lengths):
            offsets = F.pad(lengths, pad=(1, 0)).cumsum(dim=0)
            nt = convert_jagged_to_nested_tensor(values, offsets, max_length=4)
            return convert_nt_to_jagged(nt).sum()

        checkpoint(
            fn, values, lengths, use_reentrant=False, context_fn=context_fn
        ).backward()
        self.assertIsNotNone(values.grad)

    # Internally-defined NT use cases are lifted to here for maximum test realism.
    # TODO: Remove these when ViewNestedFromBuffer, etc. are deprecated.
    @skipCUDAIfRocm  # not needed
    @skipIfTorchDynamo("compiles internally")
    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    @parametrize("use_legacy_api", [True, False])
    @skipCPUIf(True, "SDPA Math NT fallback causes failure: see issue #133644")
    def test_dummy_mha_with_nt(self, device, use_legacy_api):
        bs = 3
        d1 = 2
        d2 = 4
        d3 = 16
        n_heads = 2
        d_head = d3 // n_heads
        max_length_1 = 10
        max_length_2 = 20
        torch.manual_seed(0)

        class mha(torch.nn.Module):
            def __init__(self, use_legacy_api) -> None:
                super().__init__()
                torch.manual_seed(0)
                self.linear = torch.nn.Linear(d2, d3, device=device)
                self.use_legacy_api = use_legacy_api

            def forward(self, query, value, offsets):
                value = self.linear(value)
                if self.use_legacy_api:
                    key = convert_jagged_to_nested_tensor_legacy(
                        value, offsets, max_length_1
                    )
                    value = convert_jagged_to_nested_tensor_legacy(
                        value, offsets, max_length_2
                    )
                    query = convert_dense_to_nested_tensor_legacy(query)
                else:
                    key = convert_jagged_to_nested_tensor(value, offsets, max_length_1)
                    value = convert_jagged_to_nested_tensor(
                        value, offsets, max_length_2
                    )
                    query = convert_dense_to_nested_tensor(query)
                q = query.view(bs, -1, n_heads, d_head).transpose(1, 2)
query.view(bs, -1, n_heads, d_head).transpose(1, 2) 6825*da0073e9SAndroid Build Coastguard Worker k = key.view(bs, -1, n_heads, d_head).transpose(1, 2) 6826*da0073e9SAndroid Build Coastguard Worker v = value.view(bs, -1, n_heads, d_head).transpose(1, 2) 6827*da0073e9SAndroid Build Coastguard Worker 6828*da0073e9SAndroid Build Coastguard Worker with torch.nn.attention.sdpa_kernel( 6829*da0073e9SAndroid Build Coastguard Worker [ 6830*da0073e9SAndroid Build Coastguard Worker torch.nn.attention.SDPBackend.FLASH_ATTENTION, 6831*da0073e9SAndroid Build Coastguard Worker torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION, 6832*da0073e9SAndroid Build Coastguard Worker ] 6833*da0073e9SAndroid Build Coastguard Worker ): 6834*da0073e9SAndroid Build Coastguard Worker attn_output = torch.nn.functional.scaled_dot_product_attention( 6835*da0073e9SAndroid Build Coastguard Worker q, 6836*da0073e9SAndroid Build Coastguard Worker k, 6837*da0073e9SAndroid Build Coastguard Worker v, 6838*da0073e9SAndroid Build Coastguard Worker attn_mask=None, 6839*da0073e9SAndroid Build Coastguard Worker dropout_p=0.0, 6840*da0073e9SAndroid Build Coastguard Worker is_causal=False, 6841*da0073e9SAndroid Build Coastguard Worker ) 6842*da0073e9SAndroid Build Coastguard Worker attn_output = attn_output.transpose(1, 2) 6843*da0073e9SAndroid Build Coastguard Worker if self.use_legacy_api: 6844*da0073e9SAndroid Build Coastguard Worker attn_output = convert_nt_to_jagged_legacy(attn_output) 6845*da0073e9SAndroid Build Coastguard Worker else: 6846*da0073e9SAndroid Build Coastguard Worker attn_output = convert_nt_to_jagged(attn_output) 6847*da0073e9SAndroid Build Coastguard Worker return attn_output, key._max_seqlen, value._max_seqlen 6848*da0073e9SAndroid Build Coastguard Worker 6849*da0073e9SAndroid Build Coastguard Worker query = torch.rand(bs, d1, d3, device=device) 6850*da0073e9SAndroid Build Coastguard Worker value = torch.rand(30, d2, requires_grad=True, device=device) 6851*da0073e9SAndroid Build Coastguard Worker # total_length must > than max_length otherwise flash_attn backwark will fail 6852*da0073e9SAndroid Build Coastguard Worker offsets = torch.tensor([0, 2, 3, 30], device=device) 6853*da0073e9SAndroid Build Coastguard Worker 6854*da0073e9SAndroid Build Coastguard Worker m = mha(use_legacy_api) 6855*da0073e9SAndroid Build Coastguard Worker symbolic_traced: torch.fx.GraphModule = torch.fx.symbolic_trace(m) 6856*da0073e9SAndroid Build Coastguard Worker m = torch.compile(symbolic_traced) 6857*da0073e9SAndroid Build Coastguard Worker attn_output, cached_key_max_seqlen, cached_value_max_seqlen = m( 6858*da0073e9SAndroid Build Coastguard Worker query, value, offsets 6859*da0073e9SAndroid Build Coastguard Worker ) 6860*da0073e9SAndroid Build Coastguard Worker loss = attn_output.sum() 6861*da0073e9SAndroid Build Coastguard Worker # Check that NT can be fx traced and torch.compile, and backward works 6862*da0073e9SAndroid Build Coastguard Worker loss.backward() 6863*da0073e9SAndroid Build Coastguard Worker 6864*da0073e9SAndroid Build Coastguard Worker # Check that value.requires_grad is not lost after tracing and compiling 6865*da0073e9SAndroid Build Coastguard Worker value_grad = value.grad # save for comparison later 6866*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(value_grad) 6867*da0073e9SAndroid Build Coastguard Worker # check that max_seqlen is cached properly 6868*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cached_key_max_seqlen, max_length_1) 6869*da0073e9SAndroid Build Coastguard Worker 
        self.assertEqual(cached_value_max_seqlen, max_length_2)

        # check that the output is numerically equivalent to eager mode
        m_eager = mha(use_legacy_api)

        value.grad = None
        attn_output_eager, _, _ = m_eager(query, value, offsets)
        attn_output_eager.sum().backward()
        self.assertTrue(torch.allclose(attn_output_eager, attn_output))
        self.assertTrue(torch.allclose(value_grad, value.grad))

    @dtypes(torch.float32)
    def test_apply_(self, device, dtype):
        nt = random_nt_from_dims(
            [5, None, 10],
            device=device,
            dtype=dtype,
            layout=torch.jagged,
            requires_grad=True,
        )

        def f(x):
            return x * 2

        if device != "cpu":
            with self.assertRaisesRegex(
                TypeError, "apply_ is only implemented on CPU tensors"
            ):
                nt.apply_(f)
            return

        before = nt._values.clone().detach()

        nt.apply_(f)
        expected = f(before)
        self.assertEqual(expected, nt._values)
        # apply_ should swap values in-place without appending to the autograd graph
        self.assertIsNone(nt.grad)
        self.assertIsNone(nt._values.grad_fn)

    @dtypes(torch.float64, torch.float32, torch.half)
    def test_jagged_padded_dense_conversion_kernels(self, device, dtype):
        values = torch.randn(10, 5, device=device, dtype=dtype)
        offsets = torch.tensor([0, 1, 3, 8, 10], device=device, dtype=torch.int64)
        max_length = offsets.diff().max().item()
        padding_value = 1.3

        # convert jagged -> padded dense
        padded = torch.ops.aten._jagged_to_padded_dense_forward(
            values, [offsets], [max_length], padding_value
        )
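
        # Illustrative sketch: offsets encode per-item sequence lengths as consecutive
        # differences. For offsets = [0, 1, 3, 8, 10], the jagged batch holds items of
        # lengths [1, 2, 5, 2], so max_length == 5 and the padded result above has
        # shape (batch_size=4, max_length=5, values.shape[-1]=5). For example:
        #
        #     seq_lens = offsets.diff()         # tensor([1, 2, 5, 2])
        #     assert padded.shape == (4, 5, 5)  # holds for the values defined above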

        batch_size = offsets.shape[0] - 1
        expected_padded_shape = (batch_size, max_length, values.shape[-1])
        self.assertEqual(padded.shape, expected_padded_shape)

        # convert padded dense -> jagged
        total_L = values.shape[0]
        output_jagged = torch.ops.aten._padded_dense_to_jagged_forward(
            padded, [offsets], total_L
        )

        # should be equivalent to the original values
        self.assertEqual(values, output_jagged)

        # success case: truncate to max length as needed
        trunc_max_length = max_length - 1
        trunc_padded = torch.ops.aten._jagged_to_padded_dense_forward(
            values, [offsets], [trunc_max_length], padding_value
        )
        self.assertEqual(padded[:, :trunc_max_length, :], trunc_padded)

        # specific to CPU impls
        if device == "cpu":
            # error case: multiple offsets on CPU, since the CPU kernels currently
            # support only a single jagged dim
            with self.assertRaisesRegex(
                RuntimeError, "only a single jagged dim is supported"
            ):
                torch.ops.aten._jagged_to_padded_dense_forward(
                    values, [offsets, offsets], [max_length, max_length], padding_value
                )

            with self.assertRaisesRegex(
                RuntimeError, "only a single jagged dim is supported"
            ):
                torch.ops.aten._padded_dense_to_jagged_forward(
                    padded, [offsets, offsets], total_L
                )

            # error case: > 1D offsets
            offsets2d = offsets.unsqueeze(-1)
            with self.assertRaisesRegex(RuntimeError, "expected 1D offsets"):
                torch.ops.aten._jagged_to_padded_dense_forward(
                    values, [offsets2d], [max_length], padding_value
                )
            with self.assertRaisesRegex(RuntimeError, "expected 1D offsets"):
                torch.ops.aten._padded_dense_to_jagged_forward(
                    padded, [offsets2d], total_L
                )

            # error case: final offset != total_L
            offsets_wrong = offsets.clone().detach()
            offsets_wrong[-1] = total_L + 1
            with self.assertRaisesRegex(
                RuntimeError, "final offset should match total_L value"
            ):
                torch.ops.aten._padded_dense_to_jagged_forward(
                    padded, [offsets_wrong], total_L
                )

            # error case: 1D padded input
            padded_wrong = padded.flatten().clone().detach()
            with self.assertRaisesRegex(RuntimeError, "expected padded dim >= 2"):
                torch.ops.aten._padded_dense_to_jagged_forward(
                    padded_wrong, [offsets], total_L
                )

            # error case: batch item has length > max length
            # max_length is 5 above; this batch has an item of length 7
            offsets_wrong = torch.tensor(
                [0, 1, 8, 9, 10], device=device, dtype=torch.int64
            )
            with self.assertRaisesRegex(RuntimeError, "found batch item of length"):
                torch.ops.aten._padded_dense_to_jagged_forward(
                    padded, [offsets_wrong], total_L
                )

    @dtypes(torch.float32)
    @skipIfTorchDynamo("Test compiles internally")
    @unittest.skipIf(
        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
    )
    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    @skipCUDAIfRocm
    def test_compile_preserves_metadata_cache(self, device, dtype):
        # shape (B, *, C, D)
        nt = random_nt_from_dims(
            [4, None, 3, 16],
            device=device,
            dtype=dtype,
            layout=torch.jagged,
            requires_grad=True,
        )

        # expect min / max seqlen to be stored here
        cache = dict(nt._metadata_cache)

        @torch.compile
        def f(nt):
            q = nt.transpose(-3, -2)
            output = F.scaled_dot_product_attention(q, q, q).transpose(-3, -2)
            return output

        output = f(nt)
        output.backward(torch.ones_like(output))
        self.assertEqual(output._metadata_cache, cache)

    @dtypes(torch.float32)
    @skipIfTorchDynamo("Test compiles internally")
    @unittest.skipIf(
        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
    )
    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    @skipCUDAIfRocm
    def test_compile_with_dynamic_max_seq_len(self, device, dtype):
        # shape (B, *, D)
        # max seq len: 18
        nt = torch.nested.nested_tensor(
            [
                torch.randn(2, 5),
                torch.randn(3, 5),
                torch.randn(18, 5),
            ],
            layout=torch.jagged,
        )

        # max seq len: 19
        nt2 = torch.nested.nested_tensor(
            [
                torch.randn(2, 5),
                torch.randn(3, 5),
                torch.randn(19, 5),
            ],
            layout=torch.jagged,
        )

        def f(nt):
            # TODO: Replace with public API when we can use @properties
            return torch.ones_like(nt) * nt._get_max_seqlen()
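
        # The two inputs above differ only in their max sequence length (18 vs. 19);
        # the check below asserts that this difference alone does not force a
        # recompile for any of the `dynamic` settings passed to torch.compile.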

        for dynamic in [False, True, None]:
            self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))

    @dtypes(torch.float32)
    @skipIfTorchDynamo("Test compiles internally")
    @unittest.skipIf(
        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
    )
    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    @skipCUDAIfRocm
    def test_compile_with_dynamic_min_seq_len(self, device, dtype):
        # shape (B, *, D)
        # min seq len: 7
        nt = torch.nested.nested_tensor(
            [
                torch.randn(7, 5),
                torch.randn(8, 5),
                torch.randn(9, 5),
            ],
            layout=torch.jagged,
        )

        # min seq len: 8
        nt2 = torch.nested.nested_tensor(
            [
                torch.randn(8, 5),
                torch.randn(9, 5),
                torch.randn(10, 5),
            ],
            layout=torch.jagged,
        )

        def f(nt):
            # TODO: Replace with public API when we can use @properties
            return torch.ones_like(nt) * nt._get_min_seqlen()

        for dynamic in [False, True, None]:
            self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))

    @dtypes(torch.float32)
    @skipIfTorchDynamo("Test compiles internally")
    @unittest.skipIf(
        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
    )
    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    @skipCUDAIfRocm
    def test_compile_with_propagated_dynamic_max_seq_len(self, device, dtype):
        # shape (B, *, D)
        # max seq len: 18
        nt = torch.nested.nested_tensor(
            [
                torch.randn(2, 5),
                torch.randn(3, 5),
                torch.randn(18, 5),
            ],
            layout=torch.jagged,
        )

        # max seq len: 19
        nt2 = torch.nested.nested_tensor(
            [
                torch.randn(2, 5),
                torch.randn(3, 5),
                torch.randn(19, 5),
            ],
            layout=torch.jagged,
        )

        def f(nt):
            nt2 = nt.sin() + 1
            # TODO: Replace with public API when we can use @properties
            return torch.ones_like(nt2) * nt2._get_max_seqlen()

        ref = f(nt)
        output = torch.compile(f, fullgraph=True, dynamic=False)(nt)
        self.assertEqual(ref, output)

        for dynamic in [False, True, None]:
            self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))

    @dtypes(torch.float32, torch.double, torch.half)
    def test_unbind_backward(self, device, dtype):
        nt = torch.nested.nested_tensor(
            [
                torch.randn(2, 4, device=device),
                torch.randn(5, 4, device=device),
                torch.randn(3, 4, device=device),
            ],
            layout=torch.jagged,
            requires_grad=True,
        )

        a, b, c = nt.unbind()
        b.sum().backward()

        @torch._dynamo.disable
        def check(nt):
            expected_grad = torch.zeros_like(nt)
            expected_grad.unbind()[1].add_(1.0)
            self.assertEqual(nt.grad, expected_grad)

        check(nt)


FORWARD_FAILURES = {
    # === BEGIN NotImplementedError SECTION ===
    # unary
    "nn.functional.celu",
    "nn.functional.elu",
    "nn.functional.hardshrink",
    "nn.functional.hardsigmoid",
    "nn.functional.hardtanh",
    "nn.functional.logsigmoid",
    "nn.functional.mish",
    "nn.functional.relu6",
    "nn.functional.rrelu",
    "nn.functional.selu",
    "nn.functional.softplus",
    "nn.functional.softshrink",
    "nn.functional.threshold",
    "rad2deg",
    # binary
    "__rsub__",
    "complex",
    "floor_divide",
    "polar",
    "rsub",
    # reduction
    "all",
    "amax",
    "amin",
    "any",
    "argmax",
    "argmin",
    "count_nonzero",
    "linalg.vector_norm",
    "nansum",
    "std",
    "std.unbiased",
    "var",
    "var.unbiased",
    # === BEGIN UNSUPPORTED SECTION ===
    # RuntimeError: mean(): not supported for NestedTensor on dim=1
    "mean",
    # ValueError: expects strided tensor (got torch.jagged tensor)
    "masked.amax",
    "masked.amin",
    "masked.argmax",
    "masked.argmin",
    "masked.logsumexp",
    "masked.mean",
    "masked.norm",
    "masked.prod",
    "masked.std",
    "masked.sum",
    "masked.var",
    # === BEGIN BUG SECTION ===
    # Returns a tuple of Tensors so it doesn't work with NJT's unary pointwise logic
    "frexp",
    # Need to adjust sample input func to pass the right thing
    "nn.functional.prelu",
    # TypeError: fill() received an invalid combination of arguments
    # got (NestedTensor), but expected one of:
    # * (Tensor input, Tensor value)
    # * (Tensor input, Number value)
    "fill",
    # RuntimeError: unsupported tensor layout: Jagged
    "jiterator_binary",
    "jiterator_binary_return_by_ref",
    "jiterator_unary",
    # Bug found: sum() with keepdim=True returns invalid shape
    "sum",
    # RuntimeError: prod(): keepdim=True must be set for NestedTensor
    "prod",
    # RuntimeError: "jagged_to_padded_dense" not implemented for 'Bool'
    "nanmean",
}

BACKWARD_FAILURES = {
    *FORWARD_FAILURES,
    # TODO: categorize these
    "__rpow__",
    "atanh",
    "cdouble",
    "cfloat",
    "chalf",
    "clamp_max",
    "clamp_min",
    "copysign",
    "float_power",
    "max.binary",
    "maximum",
    "min.binary",
    "minimum",
    "pow",
    "sgn",
    "sinc",
    "special.i1",
    "special.i1e",
    # clone() on a "non-contiguous with holes" NJT allocates a new offsets -> new nested int
    # RuntimeError: Function CloneBackward0 returned an invalid gradient at index 0 -
    # got [3, j29, 5] but expected shape compatible with [3, j28, 5]
    "clone",
    # Calling into torch.ops.aten.size directly
    "masked_select",
}

COMPILE_FORWARD_FAILURES = {
    *FORWARD_FAILURES,
    # clone() on non-contiguous NJTs with holes currently uses unbind(), leading to a
    # data-dependent error in torch.compile
    "clone",
}

COMPARE_TENSOR_COMPONENT_EQUALITY = {
    # masked_select is expected to output a different shape
    "masked_select",
}


def withXFails(failure_list):
    return decorateIf(
        unittest.expectedFailure,
        lambda params: params["op"].full_name in failure_list,
    )
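

# withXFails marks a parametrized OpInfo test as an expected failure only when the op
# under test appears in the given failure list. A minimal usage sketch (the same
# pattern as the real tests below):
#
#     @withXFails(FORWARD_FAILURES)
#     @ops([op for op in njt_op_db if op.supports_njt], allowed_dtypes=(torch.float32,))
#     def test_something(self, device, dtype, op):
#         ...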


# OpInfo-based NJT tests. These tests utilize an NJT-specific op_db generated from the standard
# op_db. Note that certain tradeoffs were made wrt coverage vs. time spent running tests:
# * All tests run with dtype=torch.float32 only
class TestNestedTensorOpInfo(NestedTensorTestCase):
    # TODO: move this
    def _gen_grad_outputs(self, out_val):
        if isinstance(out_val, (list, tuple)):
            return tuple(torch.ones_like(c) for c in out_val)
        else:
            return (torch.ones_like(out_val),)

    @withXFails(FORWARD_FAILURES)
    @ops([op for op in njt_op_db if op.supports_njt], allowed_dtypes=(torch.float32,))
    def test_forward(self, device, dtype, op):
        for sample in op.sample_inputs(device=device, dtype=dtype, requires_grad=False):
            # compare to reference, but expect different nested int
            out = op.op(sample.input, *sample.args, **sample.kwargs)
            out_ref = op.ref(op, sample)
            self.assertEqualIgnoringNestedInts(out, out_ref)

    @withXFails(BACKWARD_FAILURES)
    @ops(
        [op for op in njt_op_db if op.supports_njt and op.supports_autograd],
        allowed_dtypes=(torch.float32,),
    )
    def test_backward(self, device, dtype, op):
        for sample in op.sample_inputs(device=device, dtype=dtype, requires_grad=True):
            # compare to reference, but expect different nested int
            out = op.op(sample.input, *sample.args, **sample.kwargs)
            out_ref = op.ref(op, sample)
            self.assertEqualIgnoringNestedInts(out, out_ref)

            inps, _ = tree_flatten((sample.input, sample.args, sample.kwargs))
            g_inps = [
                inp
                for inp in inps
                if isinstance(inp, torch.Tensor) and inp.requires_grad
            ]
            if len(g_inps) > 0:
                grads = torch.autograd.grad(
                    out, inputs=g_inps, grad_outputs=self._gen_grad_outputs(out)
                )

                grads_ref = torch.autograd.grad(
                    out_ref,
                    inputs=g_inps,
                    grad_outputs=self._gen_grad_outputs(out_ref),
                )
7337*da0073e9SAndroid Build Coastguard Worker grad_outputs=self._gen_grad_outputs(out_ref), 7338*da0073e9SAndroid Build Coastguard Worker ) 7339*da0073e9SAndroid Build Coastguard Worker 7340*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grads, grads_ref) 7341*da0073e9SAndroid Build Coastguard Worker 7342*da0073e9SAndroid Build Coastguard Worker @withXFails(COMPILE_FORWARD_FAILURES) 7343*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True) 7344*da0073e9SAndroid Build Coastguard Worker @ops([op for op in njt_op_db if op.supports_njt], allowed_dtypes=(torch.float32,)) 7345*da0073e9SAndroid Build Coastguard Worker def test_compile_forward(self, device, dtype, op): 7346*da0073e9SAndroid Build Coastguard Worker for sample in op.sample_inputs(device=device, dtype=dtype, requires_grad=False): 7347*da0073e9SAndroid Build Coastguard Worker torch.compiler.reset() 7348*da0073e9SAndroid Build Coastguard Worker 7349*da0073e9SAndroid Build Coastguard Worker op_fn = op.op 7350*da0073e9SAndroid Build Coastguard Worker 7351*da0073e9SAndroid Build Coastguard Worker def f(*args, **kwargs): 7352*da0073e9SAndroid Build Coastguard Worker return op_fn(*args, **kwargs) 7353*da0073e9SAndroid Build Coastguard Worker 7354*da0073e9SAndroid Build Coastguard Worker compiled_f = torch.compile( 7355*da0073e9SAndroid Build Coastguard Worker f, fullgraph=True, backend="aot_eager_decomp_partition" 7356*da0073e9SAndroid Build Coastguard Worker ) 7357*da0073e9SAndroid Build Coastguard Worker 7358*da0073e9SAndroid Build Coastguard Worker out_ref = f(sample.input, *sample.args, **sample.kwargs) 7359*da0073e9SAndroid Build Coastguard Worker out_compile = compiled_f(sample.input, *sample.args, **sample.kwargs) 7360*da0073e9SAndroid Build Coastguard Worker 7361*da0073e9SAndroid Build Coastguard Worker if op.full_name in COMPARE_TENSOR_COMPONENT_EQUALITY: 7362*da0073e9SAndroid Build Coastguard Worker self.assertEqualIgnoringNestedInts(out_compile, out_ref) 7363*da0073e9SAndroid Build Coastguard Worker else: 7364*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_compile, out_ref) 7365*da0073e9SAndroid Build Coastguard Worker 7366*da0073e9SAndroid Build Coastguard Worker @withXFails(BACKWARD_FAILURES) 7367*da0073e9SAndroid Build Coastguard Worker @ops( 7368*da0073e9SAndroid Build Coastguard Worker [op for op in njt_op_db if op.supports_njt and op.supports_autograd], 7369*da0073e9SAndroid Build Coastguard Worker allowed_dtypes=(torch.float32,), 7370*da0073e9SAndroid Build Coastguard Worker ) 7371*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True) 7372*da0073e9SAndroid Build Coastguard Worker def test_compile_backward(self, device, dtype, op): 7373*da0073e9SAndroid Build Coastguard Worker for sample in op.sample_inputs(device=device, dtype=dtype, requires_grad=True): 7374*da0073e9SAndroid Build Coastguard Worker torch.compiler.reset() 7375*da0073e9SAndroid Build Coastguard Worker 7376*da0073e9SAndroid Build Coastguard Worker op_fn = op.op 7377*da0073e9SAndroid Build Coastguard Worker 7378*da0073e9SAndroid Build Coastguard Worker def f(*args, **kwargs): 7379*da0073e9SAndroid Build Coastguard Worker return op_fn(*args, **kwargs) 7380*da0073e9SAndroid Build Coastguard Worker 7381*da0073e9SAndroid Build Coastguard Worker compiled_f = torch.compile( 7382*da0073e9SAndroid Build Coastguard Worker f, fullgraph=True, backend="aot_eager_decomp_partition" 7383*da0073e9SAndroid Build Coastguard Worker ) 

            out_ref = f(sample.input, *sample.args, **sample.kwargs)
            out_compile = compiled_f(sample.input, *sample.args, **sample.kwargs)

            self.assertEqual(out_compile, out_ref)

            inps, _ = tree_flatten((sample.input, sample.args, sample.kwargs))
            g_inps = [
                inp
                for inp in inps
                if isinstance(inp, torch.Tensor) and inp.requires_grad
            ]
            if len(g_inps) > 0:
                grads_compile = torch.autograd.grad(
                    out_compile,
                    inputs=g_inps,
                    grad_outputs=self._gen_grad_outputs(out_compile),
                )

                grads_ref = torch.autograd.grad(
                    out_ref, inputs=g_inps, grad_outputs=self._gen_grad_outputs(out_ref)
                )

                self.assertEqual(grads_compile, grads_ref)


instantiate_parametrized_tests(TestNestedTensor)
instantiate_device_type_tests(TestNestedTensorDeviceType, globals())
instantiate_device_type_tests(TestNestedTensorAutograd, globals())
instantiate_device_type_tests(TestNestedTensorSubclass, globals())
instantiate_device_type_tests(TestNestedTensorOpInfo, globals())

if __name__ == "__main__":
    run_tests()