# xref: /aosp_15_r20/external/pytorch/test/distributed/_tensor/test_dtensor_ops.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
3
4import unittest
5import warnings
6
7import torch
8import torch.distributed as dist
9import torch.testing._internal.common_methods_invocations as common_ops
10from torch.distributed._tensor import DeviceMesh, DTensor
11from torch.overrides import resolve_name
12from torch.testing._internal.common_device_type import (
13    instantiate_device_type_tests,
14    ops,
15)
16from torch.testing._internal.common_methods_invocations import DecorateInfo, op_db
17from torch.testing._internal.common_utils import (
18    run_tests,
19    suppress_warnings,
20    TEST_WITH_ASAN,
21)
22from torch.testing._internal.distributed._tensor.common_dtensor import (
23    DTensorConverter,
24    DTensorOpTestBase,
25)
26from torch.utils import _pytree as pytree
27from torch.utils._pytree import tree_map
28
29
30# rewrite common size variables to sth can be sharded evenly
31# we can enable uneven shards later, but need to adjust more on
32# sample inputs (i.e. view/reshape need to adjust shape size as well)
33common_ops.L = 24
34common_ops.M = 12
35common_ops.S = 4
36common_ops.XS = 2
37
38
# Copied from functorch
def xfail(op_name, variant_name="", *, device_type=None, dtypes=None):
    """Build a ``to_skip`` entry marking an OpInfo test as an expected failure.

    The result is the 5-tuple
    ``(op_name, variant_name, device_type, dtypes, True)``; the trailing
    ``True`` is the ``expected_failure`` flag unpacked by ``skipOps``.
    """
    expected_failure = True
    return (op_name, variant_name, device_type, dtypes, expected_failure)
42
43
def skip(op_name, variant_name="", *, device_type=None, dtypes=None):
    """Build a ``to_skip`` entry that unconditionally skips an OpInfo test.

    Same 5-tuple layout as ``xfail``, but the trailing ``expected_failure``
    flag is ``False``.
    """
    expected_failure = False
    return op_name, variant_name, device_type, dtypes, expected_failure
46
47
48def skipOps(test_case_name, base_test_name, to_skip):
49    all_opinfos = op_db
50    for xfail in to_skip:
51        op_name, variant_name, device_type, dtypes, expected_failure = xfail
52        matching_opinfos = [
53            o
54            for o in all_opinfos
55            if o.name == op_name and o.variant_test_name == variant_name
56        ]
57        assert len(matching_opinfos) >= 1, f"Couldn't find OpInfo for {xfail}"
58        for opinfo in matching_opinfos:
59            decorators = list(opinfo.decorators)
60            if expected_failure:
61                decorator = DecorateInfo(
62                    unittest.expectedFailure,
63                    test_case_name,
64                    base_test_name,
65                    device_type=device_type,
66                    dtypes=dtypes,
67                )
68                decorators.append(decorator)
69            else:
70                decorator = DecorateInfo(
71                    unittest.skip("Skipped!"),
72                    test_case_name,
73                    base_test_name,
74                    device_type=device_type,
75                    dtypes=dtypes,
76                )
77                decorators.append(decorator)
78            opinfo.decorators = tuple(decorators)
79
80    # This decorator doesn't modify fn in any way
81    def wrapped(fn):
82        return fn
83
84    return wrapped
85
86
# Known DTensor failures for test_dtensor_op_db. To regenerate this list, pass
# dry_run=True to check_dtensor_func (i.e. check_dtensor_func(test, op,
# dry_run=True)) inside test_dtensor_op_db, then run something like:
#   python test/distributed/_tensor/test_dtensor_ops.py > failed.expect
# Rank 0 prints one xfail(...) line per failing OpInfo.
dtensor_fails = {
    # These sometimes pass and sometimes fail. Many of them should be removed
    # from the list once the ops get full support for varying sharding specs.
    xfail("__getitem__"),
    xfail("__rsub__"),
    xfail("_chunk_cat"),
    xfail("_native_batch_norm_legit"),
    xfail("_upsample_bilinear2d_aa"),
    xfail("addbmm"),
    xfail("addmv"),
    xfail("addr"),
    xfail("all"),
    xfail("allclose"),
    xfail("alias_copy"),
    xfail("amax"),
    xfail("amin"),
    xfail("aminmax"),
    xfail("any"),
    xfail("arange"),
    xfail("argmax"),
    xfail("argmin"),
    xfail("argsort"),
    xfail("as_strided"),
    xfail("as_strided", "partial_views"),
    xfail("as_strided_copy"),
    xfail("as_strided_scatter"),
    xfail("bernoulli"),
    xfail("_batch_norm_with_update"),
    xfail("block_diag"),
    xfail("broadcast_shapes"),
    xfail("cauchy"),
    xfail("cdist"),
    xfail("cholesky"),
    xfail("cholesky_inverse"),
    xfail("cholesky_solve"),
    xfail("chunk"),
    xfail("clamp"),
    xfail("clamp_max"),
    xfail("clamp_min"),
    xfail("combinations"),
    xfail("complex"),
    xfail("constant_pad_nd"),
    xfail("count_nonzero"),
    xfail("cross"),
    xfail("cummax"),
    xfail("cummin"),
    xfail("cumsum"),
    xfail("cumulative_trapezoid"),
    xfail("diagonal_scatter"),
    xfail("dist"),
    xfail("dot"),
    xfail("empty"),
    xfail("empty_strided"),
    xfail("empty_like"),
    xfail("empty_permuted"),
    xfail("expand_copy"),
    xfail("exponential"),
    xfail("equal"),
    xfail("eye"),
    xfail("fft.fft2"),
    xfail("fft.fft"),
    xfail("fft.fftn"),
    xfail("fft.fftshift"),
    xfail("fft.ifft2"),
    xfail("fft.ifft"),
    xfail("fft.ifftshift"),
    xfail("fft.ihfft2"),
    xfail("fft.ihfft"),
    xfail("fft.ihfftn"),
    xfail("fft.irfft2"),
    xfail("fft.irfftn"),
    xfail("fft.rfft2"),
    xfail("fft.rfft"),
    xfail("fft.rfftn"),
    xfail("fill"),
    xfail("flip"),
    xfail("fliplr"),
    xfail("flipud"),
    xfail("floor_divide"),
    xfail("fmax"),
    xfail("fmin"),
    xfail("frexp"),
    xfail("full"),
    xfail("full_like"),
    xfail("gather"),
    xfail("geometric"),
    xfail("geqrf"),
    xfail("grid_sampler_2d"),
    xfail("gradient"),
    xfail("heaviside"),
    xfail("histc"),
    xfail("histogram"),
    xfail("histogramdd"),
    xfail("index_add"),
    xfail("index_copy"),
    xfail("index_fill"),
    xfail("index_put"),
    xfail("index_reduce", "prod"),
    xfail("index_reduce", "mean"),
    xfail("index_reduce", "amax"),
    xfail("index_reduce", "amin"),
    xfail("index_select"),
    xfail("isin"),
    xfail("kthvalue"),
    xfail("linalg.cholesky"),
    xfail("linalg.cholesky_ex"),
    xfail("linalg.cross"),
    xfail("linalg.det"),
    xfail("linalg.det", "singular"),
    xfail("linalg.eig"),
    xfail("linalg.eigvals"),
    xfail("linalg.householder_product"),
    xfail("linalg.inv"),
    xfail("linalg.inv_ex"),
    xfail("linalg.ldl_factor"),
    xfail("linalg.ldl_factor_ex"),
    xfail("linalg.ldl_solve"),
    xfail("linalg.lstsq"),
    xfail("linalg.lstsq", "grad_oriented"),
    xfail("linalg.lu"),
    xfail("linalg.lu_factor"),
    xfail("linalg.lu_factor_ex"),
    xfail("linalg.lu_solve"),
    xfail("linalg.matrix_norm"),
    xfail("linalg.matrix_power"),
    xfail("linalg.matrix_rank"),
    xfail("linalg.matrix_rank", "hermitian"),
    xfail("linalg.multi_dot"),
    xfail("linalg.norm"),
    xfail("linalg.norm", "subgradients_at_zero"),
    xfail("linalg.pinv"),
    xfail("linalg.pinv", "hermitian"),
    xfail("linalg.slogdet"),
    xfail("linalg.solve"),
    xfail("linalg.solve_ex"),
    xfail("linalg.solve_triangular"),
    xfail("linalg.tensorinv"),
    xfail("linalg.tensorsolve"),
    xfail("linalg.vander"),
    xfail("linalg.vecdot"),
    xfail("linspace"),
    xfail("linspace", "tensor_overload"),
    xfail("log_normal"),
    xfail("logcumsumexp"),
    xfail("logdet"),
    xfail("logspace"),
    xfail("logspace", "tensor_overload"),
    xfail("logsumexp"),
    xfail("lu"),
    xfail("lu_solve"),
    xfail("lu_unpack"),
    xfail("masked_fill"),
    xfail("masked_scatter"),
    xfail("masked_select"),
    xfail("masked.amax"),
    xfail("masked.amin"),
    xfail("masked.argmax"),
    xfail("masked.argmin"),
    xfail("masked.cumprod"),
    xfail("masked.cumsum"),
    xfail("masked.logsumexp"),
    xfail("masked.median"),
    xfail("matrix_exp"),
    xfail("max", "binary"),
    xfail("max", "reduction_with_dim"),
    xfail("maximum"),
    xfail("median"),
    xfail("min", "binary"),
    xfail("min", "reduction_with_dim"),
    xfail("minimum"),
    xfail("mode"),
    xfail("msort"),
    xfail("multinomial"),
    xfail("mv"),
    xfail("max_pool2d_with_indices_backward", ""),
    xfail("nanmean"),
    xfail("nanmedian"),
    xfail("nanquantile"),
    xfail("nansum"),
    xfail("native_batch_norm"),
    xfail("native_dropout_backward"),
    xfail("narrow_copy"),
    xfail("ne"),
    xfail("new_empty"),
    xfail("new_empty_strided"),
    xfail("transpose"),
    xfail("nn.functional.adaptive_avg_pool1d"),
    xfail("nn.functional.adaptive_avg_pool2d"),
    xfail("nn.functional.adaptive_avg_pool3d"),
    xfail("nn.functional.adaptive_max_pool1d"),
    xfail("nn.functional.adaptive_max_pool2d"),
    xfail("nn.functional.adaptive_max_pool3d"),
    xfail("nn.functional.alpha_dropout"),
    xfail("nn.functional.avg_pool1d"),
    xfail("nn.functional.avg_pool2d"),
    xfail("nn.functional.avg_pool3d"),
    xfail("nn.functional.batch_norm"),
    xfail("nn.functional.batch_norm", "without_cudnn"),
    xfail("nn.functional.bilinear"),
    xfail("nn.functional.binary_cross_entropy"),
    xfail("nn.functional.binary_cross_entropy_with_logits"),
    xfail("nn.functional.celu"),
    xfail("nn.functional.conv1d"),
    xfail("nn.functional.conv2d"),
    xfail("nn.functional.conv3d"),
    xfail("nn.functional.conv_transpose1d"),
    xfail("nn.functional.conv_transpose2d"),
    xfail("nn.functional.conv_transpose3d"),
    xfail("nn.functional.cosine_similarity"),
    xfail("nn.functional.ctc_loss"),
    xfail("nn.functional.dropout"),
    xfail("nn.functional.dropout2d"),
    xfail("nn.functional.dropout3d"),
    xfail("nn.functional.elu"),
    xfail("nn.functional.fractional_max_pool2d"),
    xfail("nn.functional.fractional_max_pool3d"),
    xfail("nn.functional.glu"),
    xfail("nn.functional.grid_sample"),
    xfail("nn.functional.group_norm"),
    xfail("nn.functional.hardshrink"),
    xfail("nn.functional.hardsigmoid"),
    xfail("nn.functional.hardswish"),
    xfail("nn.functional.hardtanh"),
    xfail("nn.functional.huber_loss"),
    xfail("nn.functional.instance_norm"),
    xfail("nn.functional.interpolate", "area"),
    xfail("nn.functional.interpolate", "bicubic"),
    xfail("nn.functional.interpolate", "bilinear"),
    xfail("nn.functional.interpolate", "linear"),
    xfail("nn.functional.interpolate", "nearest"),
    xfail("nn.functional.interpolate", "nearest-exact"),
    xfail("nn.functional.interpolate", "trilinear"),
    xfail("nn.functional.leaky_relu"),
    xfail("nn.functional.linear"),
    xfail("nn.functional.local_response_norm"),
    xfail("nn.functional.logsigmoid"),
    xfail("nn.functional.margin_ranking_loss"),
    xfail("nn.functional.max_pool1d"),
    xfail("nn.functional.max_pool2d"),
    xfail("nn.functional.max_pool3d"),
    xfail("nn.functional.max_unpool1d"),
    xfail("nn.functional.max_unpool1d", "grad"),
    xfail("nn.functional.max_unpool2d"),
    xfail("nn.functional.max_unpool2d", "grad"),
    xfail("nn.functional.max_unpool3d"),
    xfail("nn.functional.max_unpool3d", "grad"),
    xfail("nn.functional.mish"),
    xfail("nn.functional.mse_loss"),
    xfail("nn.functional.multi_margin_loss"),
    xfail("nn.functional.multi_head_attention_forward"),
    xfail("nn.functional.multilabel_margin_loss"),
    xfail("nn.functional.multilabel_soft_margin_loss"),
    xfail("nn.functional.normalize"),
    xfail("nn.functional.pad", "constant"),
    xfail("nn.functional.pad", "reflect"),
    xfail("nn.functional.pad", "replicate"),
    xfail("nn.functional.pad", "replicate_negative"),
    xfail("nn.functional.pairwise_distance"),
    xfail("nn.functional.pdist"),
    xfail("nn.functional.pixel_shuffle"),
    xfail("nn.functional.pixel_unshuffle"),
    xfail("nn.functional.prelu"),
    xfail("nn.functional.relu6"),
    xfail("nn.functional.rrelu"),
    xfail("nn.functional.selu"),
    xfail("nn.functional.smooth_l1_loss"),
    xfail("nn.functional.soft_margin_loss"),
    xfail("nn.functional.softplus"),
    xfail("nn.functional.softshrink"),
    xfail("nn.functional.threshold"),
    xfail("nn.functional.triplet_margin_loss"),
    xfail("nn.functional.triplet_margin_with_distance_loss"),
    xfail("nn.functional.unfold"),
    xfail("nn.functional.upsample_bilinear"),
    xfail("nn.functional.upsample_nearest"),
    xfail("nonzero"),
    xfail("normal"),
    xfail("normal", "number_mean"),
    xfail("normal", "in_place"),
    xfail("ormqr"),
    xfail("ones"),
    xfail("pca_lowrank"),
    xfail("pinverse"),
    xfail("polar"),
    xfail("put"),
    xfail("quantile"),
    xfail("rand_like"),
    xfail("randint_like"),
    xfail("randint"),
    xfail("randn"),
    xfail("randn_like"),
    xfail("renorm"),
    xfail("repeat_interleave"),
    xfail("resize_"),
    xfail("resize_as_"),
    xfail("roll"),
    xfail("rot90"),
    xfail("rsub"),
    xfail("scalar_tensor"),
    xfail("scatter_add"),
    xfail("scatter_reduce", "amax"),
    xfail("scatter_reduce", "amin"),
    xfail("scatter_reduce", "mean"),
    xfail("scatter_reduce", "prod"),
    xfail("scatter_reduce", "sum"),
    xfail("searchsorted"),
    xfail("select"),
    xfail("select_scatter"),
    xfail("sort"),
    xfail("sparse.sampled_addmm"),
    xfail("sparse.mm", "reduce"),
    xfail("special.airy_ai"),
    xfail("special.bessel_j0"),
    xfail("special.bessel_j1"),
    xfail("special.bessel_y0"),
    xfail("special.bessel_y1"),
    xfail("special.chebyshev_polynomial_t"),
    xfail("special.chebyshev_polynomial_u"),
    xfail("special.entr"),
    xfail("special.erfcx"),
    xfail("special.hermite_polynomial_h"),
    xfail("special.hermite_polynomial_he"),
    xfail("special.i0e"),
    xfail("special.i1"),
    xfail("special.i1e"),
    xfail("special.laguerre_polynomial_l"),
    xfail("special.log_ndtr"),
    xfail("special.modified_bessel_i0"),
    xfail("special.modified_bessel_i1"),
    xfail("special.modified_bessel_k0"),
    xfail("special.modified_bessel_k1"),
    xfail("special.ndtri"),
    xfail("special.scaled_modified_bessel_k0"),
    xfail("special.scaled_modified_bessel_k1"),
    xfail("special.spherical_bessel_j0"),
    xfail("special.xlog1py"),
    xfail("special.zeta"),
    xfail("squeeze", "multiple"),
    xfail("signal.windows.bartlett"),
    xfail("signal.windows.blackman"),
    xfail("signal.windows.cosine"),
    xfail("signal.windows.exponential"),
    xfail("signal.windows.gaussian"),
    xfail("signal.windows.general_cosine"),
    xfail("signal.windows.general_hamming"),
    xfail("signal.windows.hamming"),
    xfail("signal.windows.hann"),
    xfail("signal.windows.nuttall"),
    xfail("signal.windows.kaiser"),
    xfail("stack"),
    xfail("std"),
    xfail("std", "unbiased"),
    xfail("std_mean"),
    xfail("std_mean", "unbiased"),
    xfail("stft"),
    xfail("svd_lowrank"),
    xfail("t_copy"),
    xfail("take"),
    xfail("tensor_split"),
    xfail("to_sparse"),
    xfail("trace"),
    xfail("trapezoid"),
    xfail("trapz"),
    xfail("triangular_solve"),
    xfail("unbind"),
    xfail("unfold"),
    xfail("unfold_copy"),
    xfail("uniform"),
    xfail("unflatten"),
    xfail("unique_consecutive"),
    xfail("unique"),
    xfail("unsafe_split"),
    xfail("unsafe_chunk"),
    xfail("_unsafe_masked_index"),
    xfail("_unsafe_masked_index_put_accumulate"),
    xfail("var_mean"),
    xfail("var_mean", "unbiased"),
    xfail("vdot"),
    xfail("view_copy"),
    xfail("zeros"),
    # Ops below may fail even without DTensor: rescaling the op db's common
    # size factors (L, M, S, XS) at the top of this file makes some sample
    # input generation invalid, so the original (non-DTensor) op run fails.
    # They are skipped for now but should be re-enabled later.
    # TODO: clean up this list and remove all of these cases.
    skip("argwhere"),
    skip("cumprod"),
    skip("__rmatmul__"),
    skip("meshgrid", "list_of_tensors"),
    skip("meshgrid", "variadic_tensors"),
    skip("nn.functional.scaled_dot_product_attention"),
    skip("nn.functional.softmin"),
    skip("nn.functional.embedding"),
    skip("nn.functional.embedding_bag"),
    skip("nn.functional.feature_alpha_dropout", "with_train"),
    skip("nn.functional.feature_alpha_dropout", "without_train"),
    skip("nn.functional.hinge_embedding_loss"),
    skip("nn.functional.cosine_embedding_loss"),
    skip("fft.hfft"),
    skip("fft.hfft2"),
    skip("fft.hfftn"),
    skip("fft.ifftn"),
    skip("fft.irfft"),
    skip("istft"),
    skip("isclose"),
    skip("isreal"),
    skip("matmul"),
    skip("masked.mean"),
    skip("masked.var"),
    skip("masked.std"),
    skip("masked.normalize"),
    skip("prod"),
    skip("_segment_reduce", "lengths"),
    skip("_segment_reduce", "offsets"),
    # TODO: fix the following ops
    skip("squeeze"),
}
509
510
# Ops whose backward pass currently fails. run_dtensor_crossref skips the BW
# step for any op whose resolve_name() result appears in this list.
skip_bw = [
    None,  # corresponds to the transpose ops 'H' and 'T'
    *(
        f"torch.{op_name}"
        for op_name in ("bucketize", "conj_physical", "eq", "isfinite", "isnan")
    ),
]


# World size used by TestDTensorOps, and therefore the DeviceMesh size.
OP_DB_WORLD_SIZE = 4
# Intended selection once CUDA is re-enabled:
# DEVICE_TYPE = "cuda" if torch.cuda.is_available() and torch.cuda.device_count() >= OP_DB_WORLD_SIZE else "cpu"
# TODO: debug cuda illegal memory access issue and re-enable cuda tests
DEVICE_TYPE = "cpu"
526
527
class TestDTensorOps(DTensorOpTestBase):
    """Cross-check every OpInfo in ``op_db`` against its DTensor execution.

    For each sample input the op is run on plain tensors, then on every DTensor
    conversion produced by ``DTensorConverter``. The DTensor results are
    gathered back to full tensors and compared with the plain-tensor results.
    """

    @property
    def world_size(self) -> int:
        return OP_DB_WORLD_SIZE

    # Only the float dtype is allowed for now. This constraint can be relaxed
    # later if needed (e.g. when adding quantization support).
    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @suppress_warnings
    @ops(op_db, allowed_dtypes=(torch.float,))
    @skipOps("TestDTensorOps", "test_dtensor_op_db", dtensor_fails)
    def test_dtensor_op_db(self, dtype, op):
        self.mesh = DeviceMesh(DEVICE_TYPE, torch.arange(self.world_size))

        # Run each op with DTensor inputs and with normal inputs, then compare.
        def test():
            samples = op.sample_inputs(DEVICE_TYPE, dtype, requires_grad=True)
            for sample_input in samples:
                args = [sample_input.input] + list(sample_input.args)
                kwargs = sample_input.kwargs

                self.run_dtensor_crossref(op.op, args, kwargs)
                # TODO: also test the out= variant. It is tricky because the
                # DTensor ``out`` must be pre-allocated, and for some ops
                # (e.g. mm.out) that requires knowing the sharding placements
                # upfront.
                # if isinstance(expected, torch.Tensor) and op.supports_out:
                #     func(*args, **kwargs, out=expected)

        self.check_dtensor_func(test, op)

    def assert_ref_dtensor_equal(self, dtensor_rs, rs):
        """Assert that gathered DTensor results match the reference leaf by leaf.

        Both arguments are flattened with pytree and must have the same number
        of leaves. Non-tensor leaves of the reference are ignored. Tensor leaves
        must agree in shape, ``requires_grad`` and value.
        """
        flat_dtensor_rs = pytree.tree_leaves(dtensor_rs)
        flat_rs = pytree.tree_leaves(rs)
        self.assertEqual(len(flat_dtensor_rs), len(flat_rs))
        for dtensor_r, r in zip(flat_dtensor_rs, flat_rs):
            if not isinstance(r, torch.Tensor):
                continue

            self.assertIsInstance(dtensor_r, torch.Tensor)
            self.assertEqualOnRank(
                dtensor_r.shape,
                r.shape,
                f"Shape mismatch! original shape:{r.shape}, dtensor shape: {dtensor_r.shape}",
            )
            self.assertEqualOnRank(
                dtensor_r.requires_grad,
                r.requires_grad,
                # Trailing space: the adjacent literals are concatenated.
                "op result requires_grad mismatch! "
                f"original requires_grad: {r.requires_grad}, "
                f"dtensor requires_grad: {dtensor_r.requires_grad}",
            )

            self.assertEqualOnRank(dtensor_r, r)

    def run_dtensor_crossref(self, func, args, kwargs):
        """Run ``func`` on plain tensors and on every DTensor conversion of its inputs.

        Returns:
            The reference (plain-tensor) result.

        Raises:
            RuntimeError: chained to the original error, if converting the
                inputs, running ``func`` on DTensors, or comparing the results
                fails. Errors from the reference run itself propagate unchanged.
        """
        to_dtensor = DTensorConverter(self.mesh, args, kwargs)
        func_name = resolve_name(func)

        def concat_res_if_necessary(res: object) -> object:
            # Concatenate the outputs of split-like ops back along the split
            # dim, so that backward can be called on a single tensor.
            if func_name is not None and "split" in func_name:
                # The split dim may be passed positionally (3rd arg) or as a kwarg.
                dim = args[2] if len(args) == 3 else kwargs.get("dim", 0)
                return torch.cat(res, dim=dim)
            return res

        # TODO: also handle cases where func raises an exception
        rs = func(*args, **kwargs)
        rs = concat_res_if_necessary(rs)

        def to_replicate(e: object) -> object:
            return e.full_tensor() if isinstance(e, DTensor) else e

        try:
            # Suppress warnings. This doesn't matter for test_meta.py, but it
            # does matter for cross-ref testing, where some tests may inspect
            # the warnings/errors that are raised.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                # For every combination of sharding choices, check that the op works.
                for dtensor_args, dtensor_kwargs in to_dtensor:
                    # Only proceed if every tensor was converted to DTensor. A mix
                    # of Tensor and DTensor inputs is not allowed in DTensor.
                    if not to_dtensor.successful():
                        raise RuntimeError(
                            f"failed to convert args to DTensor; "
                            f"originally (*{args}, **{kwargs})"
                        )

                    dtensor_rs = func(*dtensor_args, **dtensor_kwargs)

                    # Skip results containing zero-element tensors for now.
                    # See issue: https://github.com/pytorch/tau/issues/470
                    # TODO: remove this once the issue above is fixed.
                    flat_dtensor_rs = pytree.tree_leaves(dtensor_rs)
                    if any(
                        isinstance(leaf, torch.Tensor) and leaf.numel() == 0
                        for leaf in flat_dtensor_rs
                    ):
                        continue

                    # Redistribute/all-gather the results so they can be compared
                    # with the plain-tensor output.
                    dtensor_rs = tree_map(to_replicate, dtensor_rs)
                    dtensor_rs = concat_res_if_necessary(dtensor_rs)
                    try:
                        if func_name not in skip_bw:
                            # NOTE(review): at this point to_replicate has already
                            # turned DTensors into plain tensors. The DTensor branch
                            # therefore appears unreachable, and the tuple branch
                            # fails on ``.to_local()``, so backward seems not to be
                            # exercised here. Confirm before relying on this check;
                            # moving it before replication would enable it.
                            if isinstance(dtensor_rs, DTensor):
                                dtensor_rs.to_local().sum().backward()
                            elif isinstance(dtensor_rs, tuple):
                                dtensor_rs[0].to_local().sum().backward()
                    except Exception as e:
                        # TODO(anj): Remove this guard exception after gaining more confidence.
                        if torch.distributed.get_rank() == 0:
                            print(f"failed to run BW: {func_name}, {func}, {str(e)}")
                    self.assert_ref_dtensor_equal(dtensor_rs, rs)
        except Exception as e:
            raise RuntimeError(
                f"failed to run: {func_name}, with (*{args}, **{kwargs})"
            ) from e

        return rs

    def check_dtensor_func(self, test_func, opinfo, dry_run=False):
        """Run ``test_func``; in ``dry_run`` mode, print an xfail entry on failure.

        Without ``dry_run`` any exception is re-raised. With ``dry_run`` the
        exception is swallowed and rank 0 prints an ``xfail(...)`` line for the
        failing OpInfo, which is how ``dtensor_fails`` is regenerated.
        """
        try:
            test_func()
        except Exception:
            if not dry_run:
                raise
            if dist.get_rank() == 0:
                if opinfo.variant_test_name:
                    print(f"xfail('{opinfo.name}', '{opinfo.variant_test_name}'),")
                else:
                    print(f"xfail('{opinfo.name}'),")
671
672
# Generate the concrete test classes for DEVICE_TYPE only (CPU or GPU, not both).
instantiate_device_type_tests(
    TestDTensorOps,
    globals(),
    only_for=(DEVICE_TYPE,),
)


if __name__ == "__main__":
    # Hand off to PyTorch's common test runner when executed as a script.
    run_tests()
679