# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import unittest
import warnings

import torch
import torch.distributed as dist
import torch.testing._internal.common_methods_invocations as common_ops
from torch.distributed._tensor import DeviceMesh, DTensor
from torch.overrides import resolve_name
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    ops,
)
from torch.testing._internal.common_methods_invocations import DecorateInfo, op_db
from torch.testing._internal.common_utils import (
    run_tests,
    suppress_warnings,
    TEST_WITH_ASAN,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorConverter,
    DTensorOpTestBase,
)
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_map


# Rewrite the common size variables of the op db to values that can be
# sharded evenly. We can enable uneven shards later, but that needs more
# adjustment of the sample inputs (i.e. view/reshape would need to adjust
# their shape sizes as well).
common_ops.L = 24
common_ops.M = 12
common_ops.S = 4
common_ops.XS = 2


# Copied from functorch
def xfail(op_name, variant_name="", *, device_type=None, dtypes=None):
    """Return a ``skipOps`` entry marking ``op_name`` as an expected failure.

    The entry is the tuple
    ``(op_name, variant_name, device_type, dtypes, True)``; the trailing
    ``True`` tells ``skipOps`` to attach ``unittest.expectedFailure``.
    """
    return (op_name, variant_name, device_type, dtypes, True)


def skip(op_name, variant_name="", *, device_type=None, dtypes=None):
    """Return a ``skipOps`` entry marking ``op_name`` as skipped.

    The entry is the tuple
    ``(op_name, variant_name, device_type, dtypes, False)``; the trailing
    ``False`` tells ``skipOps`` to attach ``unittest.skip("Skipped!")``.
    """
    return (op_name, variant_name, device_type, dtypes, False)


def skipOps(test_case_name, base_test_name, to_skip):
    """Attach expected-failure / skip decorators to matching OpInfos in ``op_db``.

    For every entry produced by ``xfail``/``skip`` in ``to_skip``, all OpInfos
    in ``op_db`` with the same ``name`` and ``variant_test_name`` get a
    ``DecorateInfo`` appended to their ``decorators`` tuple, scoped to
    ``test_case_name``/``base_test_name`` (and optionally device type and
    dtypes). This mutates the shared ``op_db`` entries in place.

    Args:
        test_case_name: name of the test class the decorator applies to.
        base_test_name: name of the (un-instantiated) test method.
        to_skip: iterable of 5-tuples as built by ``xfail``/``skip``.

    Returns:
        A no-op decorator; the effect happens through the OpInfo mutation.

    Raises:
        AssertionError: if an entry matches no OpInfo in ``op_db``.
    """
    all_opinfos = op_db
    # NOTE: the loop variable is deliberately not called ``xfail`` so it does
    # not shadow the module-level helper of that name.
    for entry in to_skip:
        op_name, variant_name, device_type, dtypes, expected_failure = entry
        matching_opinfos = [
            o
            for o in all_opinfos
            if o.name == op_name and o.variant_test_name == variant_name
        ]
        # Explicit raise (instead of ``assert``) so the check survives ``-O``.
        if not matching_opinfos:
            raise AssertionError(f"Couldn't find OpInfo for {entry}")
        for opinfo in matching_opinfos:
            if expected_failure:
                test_decorator = unittest.expectedFailure
            else:
                test_decorator = unittest.skip("Skipped!")
            decorator = DecorateInfo(
                test_decorator,
                test_case_name,
                base_test_name,
                device_type=device_type,
                dtypes=dtypes,
            )
            opinfo.decorators = (*opinfo.decorators, decorator)

    # This decorator doesn't modify fn in any way
    def wrapped(fn):
        return fn

    return wrapped


# To re-generate this failure list, turn on dry_run of
# check_dtensor_func(self, test, op, dry_run=True) below, then run something
# like: python test/distributed/_tensor/test_dtensor_ops.py > failed.expect
dtensor_fails = {
    # these sometimes pass and sometimes fail
    # we need to remove many of them from list once op
    # get full support with varying sharding specs
    xfail("__getitem__"),
    xfail("__rsub__"),
    xfail("_chunk_cat"),
    xfail("_native_batch_norm_legit"),
    xfail("_upsample_bilinear2d_aa"),
    xfail("addbmm"),
    xfail("addmv"),
    xfail("addr"),
    xfail("all"),
    xfail("allclose"),
    xfail("alias_copy"),
    xfail("amax"),
    xfail("amin"),
    xfail("aminmax"),
    xfail("any"),
    xfail("arange"),
    xfail("argmax"),
    xfail("argmin"),
    xfail("argsort"),
    xfail("as_strided"),
    xfail("as_strided", "partial_views"),
    xfail("as_strided_copy"),
    xfail("as_strided_scatter"),
    xfail("bernoulli"),
    xfail("_batch_norm_with_update"),
    xfail("block_diag"),
    xfail("broadcast_shapes"),
    xfail("cauchy"),
    xfail("cdist"),
    xfail("cholesky"),
    xfail("cholesky_inverse"),
    xfail("cholesky_solve"),
    xfail("chunk"),
    xfail("clamp"),
    xfail("clamp_max"),
    xfail("clamp_min"),
    xfail("combinations"),
    xfail("complex"),
    xfail("constant_pad_nd"),
    xfail("count_nonzero"),
    xfail("cross"),
    xfail("cummax"),
    xfail("cummin"),
    xfail("cumsum"),
    xfail("cumulative_trapezoid"),
    xfail("diagonal_scatter"),
    xfail("dist"),
    xfail("dot"),
    xfail("empty"),
    xfail("empty_strided"),
    xfail("empty_like"),
    xfail("empty_permuted"),
    xfail("expand_copy"),
    xfail("exponential"),
    xfail("equal"),
    xfail("eye"),
    xfail("fft.fft2"),
    xfail("fft.fft"),
    xfail("fft.fftn"),
    xfail("fft.fftshift"),
    xfail("fft.ifft2"),
    xfail("fft.ifft"),
    xfail("fft.ifftshift"),
    xfail("fft.ihfft2"),
    xfail("fft.ihfft"),
    xfail("fft.ihfftn"),
    xfail("fft.irfft2"),
    xfail("fft.irfftn"),
    xfail("fft.rfft2"),
    xfail("fft.rfft"),
    xfail("fft.rfftn"),
    xfail("fill"),
    xfail("flip"),
    xfail("fliplr"),
    xfail("flipud"),
    xfail("floor_divide"),
    xfail("fmax"),
    xfail("fmin"),
    xfail("frexp"),
    xfail("full"),
    xfail("full_like"),
    xfail("gather"),
    xfail("geometric"),
    xfail("geqrf"),
    xfail("grid_sampler_2d"),
    xfail("gradient"),
    xfail("heaviside"),
    xfail("histc"),
    xfail("histogram"),
    xfail("histogramdd"),
    xfail("index_add"),
    xfail("index_copy"),
    xfail("index_fill"),
    xfail("index_put"),
    xfail("index_reduce", "prod"),
    xfail("index_reduce", "mean"),
    xfail("index_reduce", "amax"),
    xfail("index_reduce", "amin"),
    xfail("index_select"),
    xfail("isin"),
    xfail("kthvalue"),
    xfail("linalg.cholesky"),
    xfail("linalg.cholesky_ex"),
    xfail("linalg.cross"),
    xfail("linalg.det"),
    xfail("linalg.det", "singular"),
    xfail("linalg.eig"),
    xfail("linalg.eigvals"),
    xfail("linalg.householder_product"),
    xfail("linalg.inv"),
    xfail("linalg.inv_ex"),
    xfail("linalg.ldl_factor"),
    xfail("linalg.ldl_factor_ex"),
    xfail("linalg.ldl_solve"),
    xfail("linalg.lstsq"),
    xfail("linalg.lstsq", "grad_oriented"),
    xfail("linalg.lu"),
    xfail("linalg.lu_factor"),
    xfail("linalg.lu_factor_ex"),
    xfail("linalg.lu_solve"),
    xfail("linalg.matrix_norm"),
    xfail("linalg.matrix_power"),
    xfail("linalg.matrix_rank"),
    xfail("linalg.matrix_rank", "hermitian"),
    xfail("linalg.multi_dot"),
    xfail("linalg.norm"),
    xfail("linalg.norm", "subgradients_at_zero"),
    xfail("linalg.pinv"),
    xfail("linalg.pinv", "hermitian"),
    xfail("linalg.slogdet"),
    xfail("linalg.solve"),
    xfail("linalg.solve_ex"),
    xfail("linalg.solve_triangular"),
    xfail("linalg.tensorinv"),
    xfail("linalg.tensorsolve"),
    xfail("linalg.vander"),
    xfail("linalg.vecdot"),
    xfail("linspace"),
    xfail("linspace", "tensor_overload"),
    xfail("log_normal"),
    xfail("logcumsumexp"),
    xfail("logdet"),
    xfail("logspace"),
    xfail("logspace", "tensor_overload"),
    xfail("logsumexp"),
    xfail("lu"),
    xfail("lu_solve"),
    xfail("lu_unpack"),
    xfail("masked_fill"),
    xfail("masked_scatter"),
    xfail("masked_select"),
    xfail("masked.amax"),
    xfail("masked.amin"),
    xfail("masked.argmax"),
    xfail("masked.argmin"),
    xfail("masked.cumprod"),
    xfail("masked.cumsum"),
    xfail("masked.logsumexp"),
    xfail("masked.median"),
    xfail("matrix_exp"),
    xfail("max", "binary"),
    xfail("max", "reduction_with_dim"),
    xfail("maximum"),
    xfail("median"),
    xfail("min", "binary"),
    xfail("min", "reduction_with_dim"),
    xfail("minimum"),
    xfail("mode"),
    xfail("msort"),
    xfail("multinomial"),
    xfail("mv"),
    xfail("max_pool2d_with_indices_backward", ""),
    xfail("nanmean"),
    xfail("nanmedian"),
    xfail("nanquantile"),
    xfail("nansum"),
    xfail("native_batch_norm"),
    xfail("native_dropout_backward"),
    xfail("narrow_copy"),
    xfail("ne"),
    xfail("new_empty"),
    xfail("new_empty_strided"),
    xfail("transpose"),
    xfail("nn.functional.adaptive_avg_pool1d"),
    xfail("nn.functional.adaptive_avg_pool2d"),
    xfail("nn.functional.adaptive_avg_pool3d"),
    xfail("nn.functional.adaptive_max_pool1d"),
    xfail("nn.functional.adaptive_max_pool2d"),
    xfail("nn.functional.adaptive_max_pool3d"),
    xfail("nn.functional.alpha_dropout"),
    xfail("nn.functional.avg_pool1d"),
    xfail("nn.functional.avg_pool2d"),
    xfail("nn.functional.avg_pool3d"),
    xfail("nn.functional.batch_norm"),
    xfail("nn.functional.batch_norm", "without_cudnn"),
    xfail("nn.functional.bilinear"),
    xfail("nn.functional.binary_cross_entropy"),
    xfail("nn.functional.binary_cross_entropy_with_logits"),
    xfail("nn.functional.celu"),
    xfail("nn.functional.conv1d"),
    xfail("nn.functional.conv2d"),
    xfail("nn.functional.conv3d"),
    xfail("nn.functional.conv_transpose1d"),
    xfail("nn.functional.conv_transpose2d"),
    xfail("nn.functional.conv_transpose3d"),
    xfail("nn.functional.cosine_similarity"),
    xfail("nn.functional.ctc_loss"),
    xfail("nn.functional.dropout"),
    xfail("nn.functional.dropout2d"),
    xfail("nn.functional.dropout3d"),
    xfail("nn.functional.elu"),
    xfail("nn.functional.fractional_max_pool2d"),
    xfail("nn.functional.fractional_max_pool3d"),
    xfail("nn.functional.glu"),
    xfail("nn.functional.grid_sample"),
    xfail("nn.functional.group_norm"),
    xfail("nn.functional.hardshrink"),
    xfail("nn.functional.hardsigmoid"),
    xfail("nn.functional.hardswish"),
    xfail("nn.functional.hardtanh"),
    xfail("nn.functional.huber_loss"),
    xfail("nn.functional.instance_norm"),
    xfail("nn.functional.interpolate", "area"),
    xfail("nn.functional.interpolate", "bicubic"),
    xfail("nn.functional.interpolate", "bilinear"),
    xfail("nn.functional.interpolate", "linear"),
    xfail("nn.functional.interpolate", "nearest"),
    xfail("nn.functional.interpolate", "nearest-exact"),
    xfail("nn.functional.interpolate", "trilinear"),
    xfail("nn.functional.leaky_relu"),
    xfail("nn.functional.linear"),
    xfail("nn.functional.local_response_norm"),
    xfail("nn.functional.logsigmoid"),
    xfail("nn.functional.margin_ranking_loss"),
    xfail("nn.functional.max_pool1d"),
    xfail("nn.functional.max_pool2d"),
    xfail("nn.functional.max_pool3d"),
    xfail("nn.functional.max_unpool1d"),
    xfail("nn.functional.max_unpool1d", "grad"),
    xfail("nn.functional.max_unpool2d"),
    xfail("nn.functional.max_unpool2d", "grad"),
    xfail("nn.functional.max_unpool3d"),
    xfail("nn.functional.max_unpool3d", "grad"),
    xfail("nn.functional.mish"),
    xfail("nn.functional.mse_loss"),
    xfail("nn.functional.multi_margin_loss"),
    xfail("nn.functional.multi_head_attention_forward"),
    xfail("nn.functional.multilabel_margin_loss"),
    xfail("nn.functional.multilabel_soft_margin_loss"),
    xfail("nn.functional.normalize"),
    xfail("nn.functional.pad", "constant"),
    xfail("nn.functional.pad", "reflect"),
    xfail("nn.functional.pad", "replicate"),
    xfail("nn.functional.pad", "replicate_negative"),
    xfail("nn.functional.pairwise_distance"),
    xfail("nn.functional.pdist"),
    xfail("nn.functional.pixel_shuffle"),
    xfail("nn.functional.pixel_unshuffle"),
    xfail("nn.functional.prelu"),
    xfail("nn.functional.relu6"),
    xfail("nn.functional.rrelu"),
    xfail("nn.functional.selu"),
    xfail("nn.functional.smooth_l1_loss"),
    xfail("nn.functional.soft_margin_loss"),
    xfail("nn.functional.softplus"),
    xfail("nn.functional.softshrink"),
    xfail("nn.functional.threshold"),
    xfail("nn.functional.triplet_margin_loss"),
    xfail("nn.functional.triplet_margin_with_distance_loss"),
    xfail("nn.functional.unfold"),
    xfail("nn.functional.upsample_bilinear"),
    xfail("nn.functional.upsample_nearest"),
    xfail("nonzero"),
    xfail("normal"),
    xfail("normal", "number_mean"),
    xfail("normal", "in_place"),
    xfail("ormqr"),
    xfail("ones"),
    xfail("pca_lowrank"),
    xfail("pinverse"),
    xfail("polar"),
    xfail("put"),
    xfail("quantile"),
    xfail("rand_like"),
    xfail("randint_like"),
    xfail("randint"),
    xfail("randn"),
    xfail("randn_like"),
    xfail("renorm"),
    xfail("repeat_interleave"),
    xfail("resize_"),
    xfail("resize_as_"),
    xfail("roll"),
    xfail("rot90"),
    xfail("rsub"),
    xfail("scalar_tensor"),
    xfail("scatter_add"),
    xfail("scatter_reduce", "amax"),
    xfail("scatter_reduce", "amin"),
    xfail("scatter_reduce", "mean"),
    xfail("scatter_reduce", "prod"),
    xfail("scatter_reduce", "sum"),
    xfail("searchsorted"),
    xfail("select"),
    xfail("select_scatter"),
    xfail("sort"),
    xfail("sparse.sampled_addmm"),
    xfail("sparse.mm", "reduce"),
    xfail("special.airy_ai"),
    xfail("special.bessel_j0"),
    xfail("special.bessel_j1"),
    xfail("special.bessel_y0"),
    xfail("special.bessel_y1"),
    xfail("special.chebyshev_polynomial_t"),
    xfail("special.chebyshev_polynomial_u"),
    xfail("special.entr"),
    xfail("special.erfcx"),
    xfail("special.hermite_polynomial_h"),
    xfail("special.hermite_polynomial_he"),
    xfail("special.i0e"),
    xfail("special.i1"),
    xfail("special.i1e"),
    xfail("special.laguerre_polynomial_l"),
    xfail("special.log_ndtr"),
    xfail("special.modified_bessel_i0"),
    xfail("special.modified_bessel_i1"),
    xfail("special.modified_bessel_k0"),
    xfail("special.modified_bessel_k1"),
    xfail("special.ndtri"),
    xfail("special.scaled_modified_bessel_k0"),
    xfail("special.scaled_modified_bessel_k1"),
    xfail("special.spherical_bessel_j0"),
    xfail("special.xlog1py"),
    xfail("special.zeta"),
    xfail("squeeze", "multiple"),
    xfail("signal.windows.bartlett"),
    xfail("signal.windows.blackman"),
    xfail("signal.windows.cosine"),
    xfail("signal.windows.exponential"),
    xfail("signal.windows.gaussian"),
    xfail("signal.windows.general_cosine"),
    xfail("signal.windows.general_hamming"),
    xfail("signal.windows.hamming"),
    xfail("signal.windows.hann"),
    xfail("signal.windows.nuttall"),
    xfail("signal.windows.kaiser"),
    xfail("stack"),
    xfail("std"),
    xfail("std", "unbiased"),
    xfail("std_mean"),
    xfail("std_mean", "unbiased"),
    xfail("stft"),
    xfail("svd_lowrank"),
    xfail("t_copy"),
    xfail("take"),
    xfail("tensor_split"),
    xfail("to_sparse"),
    xfail("trace"),
    xfail("trapezoid"),
    xfail("trapz"),
    xfail("triangular_solve"),
    xfail("unbind"),
    xfail("unfold"),
    xfail("unfold_copy"),
    xfail("uniform"),
    xfail("unflatten"),
    xfail("unique_consecutive"),
    xfail("unique"),
    xfail("unsafe_split"),
    xfail("unsafe_chunk"),
    xfail("_unsafe_masked_index"),
    xfail("_unsafe_masked_index_put_accumulate"),
    xfail("var_mean"),
    xfail("var_mean", "unbiased"),
    xfail("vdot"),
    xfail("view_copy"),
    xfail("zeros"),
    # Ops in this section might fail even without DTensor: we rescale the op
    # db's common size factors (i.e. L, M, S) above, which can make the
    # original sample-input generation wrong. We skip them for now but
    # should enable them later.
    # TODO: need to clean this list and remove all cases
    skip("argwhere"),
    skip("cumprod"),
    skip("__rmatmul__"),
    skip("meshgrid", "list_of_tensors"),
    skip("meshgrid", "variadic_tensors"),
    skip("nn.functional.scaled_dot_product_attention"),
    skip("nn.functional.softmin"),
    skip("nn.functional.embedding"),
    skip("nn.functional.embedding_bag"),
    skip("nn.functional.feature_alpha_dropout", "with_train"),
    skip("nn.functional.feature_alpha_dropout", "without_train"),
    skip("nn.functional.hinge_embedding_loss"),
    skip("nn.functional.cosine_embedding_loss"),
    skip("fft.hfft"),
    skip("fft.hfft2"),
    skip("fft.hfftn"),
    skip("fft.ifftn"),
    skip("fft.irfft"),
    skip("istft"),
    skip("isclose"),
    skip("isreal"),
    skip("matmul"),
    skip("masked.mean"),
    skip("masked.var"),
    skip("masked.std"),
    skip("masked.normalize"),
    skip("prod"),
    skip("_segment_reduce", "lengths"),
    skip("_segment_reduce", "offsets"),
    # TODO: fix the following ops
    skip("squeeze"),
}


# Ops (matched by ``resolve_name(func)``) that are currently failing the
# backward pass; backward is not attempted for them.
skip_bw = [
    None,  # corresponds to the transpose ops 'H' and 'T'
    "torch.bucketize",
    "torch.conj_physical",
    "torch.eq",
    "torch.isfinite",
    "torch.isnan",
]


OP_DB_WORLD_SIZE = 4
# DEVICE_TYPE = "cuda" if torch.cuda.is_available() and torch.cuda.device_count() >= OP_DB_WORLD_SIZE else "cpu"
# TODO: debug cuda illegal memory access issue and re-enable cuda tests
DEVICE_TYPE = "cpu"


class TestDTensorOps(DTensorOpTestBase):
    """Cross-reference every OpInfo in ``op_db`` against its DTensor execution."""

    @property
    def world_size(self) -> int:
        return OP_DB_WORLD_SIZE

    # only allow float dtype for now, we can relax this constraint
    # when we feel it is necessary later (i.e. when adding quantization support).
    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @suppress_warnings
    @ops(op_db, allowed_dtypes=(torch.float,))
    @skipOps("TestDTensorOps", "test_dtensor_op_db", dtensor_fails)
    def test_dtensor_op_db(self, dtype, op):
        """Run every sample input of ``op`` through ``run_dtensor_crossref``."""
        self.mesh = DeviceMesh(DEVICE_TYPE, torch.arange(self.world_size))

        # test each op with dist tensor inputs and normal inputs
        def test():
            samples = op.sample_inputs(DEVICE_TYPE, dtype, requires_grad=True)
            for sample_input in samples:
                args = [sample_input.input] + list(sample_input.args)
                kwargs = sample_input.kwargs

                self.run_dtensor_crossref(op.op, args, kwargs)
                # we need to figure out a way to test the out variant, out variant testing
                # is tricky, as we need to pre allocate the dtensor out, some of them rely
                # on sharding placements to be pre-known (i.e. mm.out)
                # if isinstance(expected, torch.Tensor) and op.supports_out:
                #     func(*args, **kwargs, out=expected)

        self.check_dtensor_func(test, op)

    def assert_ref_dtensor_equal(self, dtensor_rs, rs):
        """Assert DTensor results match the reference results leaf by leaf.

        Both results are flattened with pytree; the leaf counts must match.
        Non-tensor reference leaves are ignored. For tensor leaves, shape,
        ``requires_grad`` and values are compared via ``assertEqualOnRank``.
        """
        flat_dtensor_rs = pytree.tree_leaves(dtensor_rs)
        flat_rs = pytree.tree_leaves(rs)
        self.assertEqual(len(flat_dtensor_rs), len(flat_rs))
        for dtensor_r, r in zip(flat_dtensor_rs, flat_rs):
            if not isinstance(r, torch.Tensor):
                continue

            self.assertIsInstance(dtensor_r, torch.Tensor)
            self.assertEqualOnRank(
                dtensor_r.shape,
                r.shape,
                f"Shape mismatch! original shape:{r.shape}, dtensor shape: {dtensor_r.shape}",
            )
            self.assertEqualOnRank(
                dtensor_r.requires_grad,
                r.requires_grad,
                "op result requires_grad mismatch! "
                f"original requires_grad: {r.requires_grad}, "
                f"dtensor requires_grad: {dtensor_r.requires_grad}",
            )

            self.assertEqualOnRank(dtensor_r, r)

    def run_dtensor_crossref(self, func, args, kwargs):
        """Run ``func`` on plain tensors and on DTensor inputs and compare.

        ``func`` is first run on the plain ``args``/``kwargs``. Then, for every
        sharding combination yielded by ``DTensorConverter``, it is run on the
        converted DTensor inputs. The outputs are gathered with
        ``full_tensor()`` and compared with the reference via
        ``assert_ref_dtensor_equal``. A backward pass is also attempted on a
        best-effort basis: failures are only printed, not raised.

        Returns:
            The reference (non-DTensor) result of ``func``.

        Raises:
            RuntimeError: wrapping any failure of the DTensor run or comparison.
        """
        to_dtensor = DTensorConverter(self.mesh, args, kwargs)

        def concat_res_if_necessary(func, res: object) -> object:
            # concat the result on corresponding dim for ops like
            # split, so that we can call backward on a single tensor
            func_name = resolve_name(func)
            if func_name is not None and "split" in func_name:
                dim = args[2] if len(args) == 3 else 0
                return torch.cat(res, dim=dim)
            return res

        # TODO: also handle cases where func raise an exception
        rs = func(*args, **kwargs)
        rs = concat_res_if_necessary(func, rs)

        def to_replicate(e: object) -> object:
            return e.full_tensor() if isinstance(e, DTensor) else e

        try:
            # Suppress warnings, this doesn't matter for test_meta.py
            # but it does matter if you want to use this decorator
            # for cross-ref testing, as some tests may be looking at
            # errors
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                # for every comb of sharding choices, we test if it works
                for dtensor_args, dtensor_kwargs in to_dtensor:
                    # Only attempt if we managed to convert all tensors to DTensor
                    # (if any of them failed, we're in a mixed tensor situation and
                    # this is not allowed in DTensor)
                    if not to_dtensor.successful():
                        raise RuntimeError(
                            f"failed to convert args to DTensor; "
                            f"originally (*{args}, **{kwargs})"
                        )

                    dtensor_rs = func(*dtensor_args, **dtensor_kwargs)

                    # we need to skip tests containing tensors of zero elements for now.
                    # see issue: https://github.com/pytorch/tau/issues/470
                    # TODO remove this once issue above fixed.
                    flat_dtensor_rs = pytree.tree_leaves(dtensor_rs)
                    if any(
                        isinstance(e, torch.Tensor) and e.numel() == 0
                        for e in flat_dtensor_rs
                    ):
                        continue

                    # redistribute/all_gather the results to compare with normal output
                    dtensor_rs = tree_map(to_replicate, dtensor_rs)
                    dtensor_rs = concat_res_if_necessary(func, dtensor_rs)
                    try:
                        if resolve_name(func) not in skip_bw:
                            if isinstance(dtensor_rs, DTensor):
                                dtensor_rs.to_local().sum().backward()
                            elif isinstance(dtensor_rs, tuple):
                                dtensor_rs[0].to_local().sum().backward()

                    except Exception as e:
                        # TODO(anj): Remove this guard exception after gaining more confidence.
                        # Backward is best-effort: report on rank 0 and keep comparing.
                        if dist.get_rank() == 0:
                            print(
                                f"failed to run BW: {resolve_name(func)}, {func}, {e}"
                            )
                    self.assert_ref_dtensor_equal(dtensor_rs, rs)
        except Exception as e:
            raise RuntimeError(
                f"failed to run: {resolve_name(func)}, with (*{args}, **{kwargs})"
            ) from e

        return rs

    def check_dtensor_func(self, test_func, opinfo, dry_run=False):
        """Run ``test_func``; in dry-run mode report failures instead of raising.

        With ``dry_run=True`` a failing op is not re-raised; instead rank 0
        prints a ready-to-paste ``xfail(...)`` entry for ``dtensor_fails``.
        """
        try:
            test_func()
        except Exception:
            if not dry_run:
                raise
            if dist.get_rank() == 0:
                if opinfo.variant_test_name:
                    print(f"xfail('{opinfo.name}', '{opinfo.variant_test_name}'),")
                else:
                    print(f"xfail('{opinfo.name}'),")


# only instantiate tests for DEVICE_TYPE alone (i.e. either CPU or GPU)
instantiate_device_type_tests(TestDTensorOps, globals(), only_for=(DEVICE_TYPE,))


if __name__ == "__main__":
    run_tests()