# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

from numpy.testing import assert_array_equal

import torch
import torch.nn.functional as F
from torch.distributed._functional_collectives import AsyncCollectiveTensor

from torch.distributed._tensor import (
    DeviceMesh,
    distribute_tensor,
    DTensor,
    init_device_mesh,
)
from torch.distributed._tensor.debug import CommDebugMode
from torch.distributed._tensor.placement_types import (
    DTensorSpec,
    Partial,
    Replicate,
    Shard,
    TensorMeta,
)
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)

from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)


c10d_functional = torch.ops.c10d_functional


class DummyMLP(torch.nn.Module):
    def __init__(self, device):
        super().__init__()
        self.net1 = torch.nn.Linear(5, 1024, device=device)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(1024, 4, device=device)

    def forward(self, x):
        return self.net2(F.relu(self.net1(x)))

    def reset_parameters(self, *args, **kwargs):
        with torch.no_grad():
            self.net1.weight.fill_(0.5)
            self.net2.weight.fill_(1)
            self.net1.bias.fill_(1.5)
            self.net2.bias.fill_(1.2)


class DTensorTest(DTensorTestBase):
    @with_comms
    def test_dtensor_constructor(self):
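        # Construct a DTensor directly from a local tensor plus an explicit
        # DTensorSpec and check the inferred global size; constructing with a
        # requires_grad flag that disagrees with the local tensor should warn.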
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = [Shard(0)]
        local_tensor = torch.randn(3, 3, requires_grad=True)

        spec = DTensorSpec(
            device_mesh,
            tuple(placements),
            tensor_meta=TensorMeta(
                torch.Size([self.world_size * 3, 3]),
                local_tensor.stride(),
                local_tensor.dtype,
            ),
        )

        dist_tensor = DTensor(
            local_tensor,
            spec,
            requires_grad=True,
        )
        self.assertEqual(dist_tensor.size(), torch.Size((self.world_size * 3, 3)))

        with self.assertWarnsRegex(UserWarning, "To construct"):
            DTensor(
                local_tensor,
                spec,
                requires_grad=False,
            )

    @with_comms
    def test_meta_dtensor(self):
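        # distribute_tensor and from_local should keep meta tensors on the meta
        # device; materializing with empty_like on the real device and then
        # initializing should produce the expected values.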
        device_mesh = self.build_device_mesh()
        dist_specs = [[Shard(0)], [Replicate()]]
        meta_tensor = torch.randn(1024, 2048, device="meta")
        for dist_spec in dist_specs:
            # Test distribute_tensor on meta tensor
            meta_dtensor = distribute_tensor(meta_tensor, device_mesh, dist_spec)
            self.assertTrue(meta_dtensor.is_meta)
            meta_dtensor = torch.empty_like(meta_dtensor, device=self.device_type)
            torch.nn.init.constant_(meta_dtensor, 1.2)
            value_tensor = torch.empty_like(meta_dtensor.to_local()).fill_(1.2)
            self.assertFalse(meta_dtensor.is_meta)
            self.assertEqual(meta_dtensor.device.type, self.device_type)
            self.assertEqual(meta_dtensor.to_local(), value_tensor)
            # Test from_local on meta tensor
            meta_dtensor = DTensor.from_local(meta_tensor, device_mesh, dist_spec)
            meta_dtensor = torch.empty_like(meta_dtensor, device=self.device_type)
            torch.nn.init.constant_(meta_dtensor, 1.5)
            self.assertEqual(meta_dtensor.device.type, self.device_type)
            value_tensor = torch.empty_like(meta_dtensor.to_local()).fill_(1.5)
            self.assertEqual(meta_dtensor.to_local(), value_tensor)

    @with_comms
    def test_modules_w_meta_dtensor(self):
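        # Parallelize a meta-device module, materialize it with to_empty and
        # reset_parameters, and check that it matches a module that was
        # initialized directly on the target device.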
        model = DummyMLP("meta")
        device_mesh = self.build_device_mesh()
        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
        }
        model_tp = parallelize_module(model, device_mesh, parallelize_plan)
        model_tp.to_empty(device=self.device_type)
        model_tp.reset_parameters()
        optim = torch.optim.SGD(model_tp.parameters(), lr=0.1)
        model_regular = DummyMLP(self.device_type)
        model_regular_tp = parallelize_module(
            model_regular, device_mesh, parallelize_plan
        )
        optim_regular = torch.optim.SGD(model_regular_tp.parameters(), lr=0.1)
        model_regular_tp.reset_parameters()
        torch.manual_seed(0)
        inp = torch.randn(20, 5, device=self.device_type)

        output = model_tp(inp)
        output_regular = model_regular_tp(inp)
        self.assertEqual(output, output_regular)

        output.sum().backward()
        output_regular.sum().backward()

        optim.step()
        optim_regular.step()

        torch.manual_seed(1)
        inp = torch.randn(20, 5, device=self.device_type)
        self.assertEqual(model_tp(inp), model_regular_tp(inp))

    @with_comms
    def test_dtensor_stride(self):
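        # The global stride depends on which dim is sharded: Shard(0) keeps the
        # local stride, while Shard(1) scales the dim-0 stride by the mesh size.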
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard0_spec = [Shard(0)]
        local_tensor = torch.randn(4, 8)
        global_shape = torch.Size([self.world_size * 4, 8])
        dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard0_spec)
        # sharding on dim 0 won't affect the stride
        self.assertEqual(dist_tensor.stride(), (8, 1))

        shard1_spec = [Shard(1)]
        local_tensor = torch.randn(8, 4)
        global_shape = torch.Size([8, self.world_size * 4])
        dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard1_spec)
        # sharding on dim 1 affects the stride once the DTensor is initialized
        self.assertEqual(dist_tensor.stride(), (4 * self.world_size, 1))

        # if initialized from a transposed matrix
        local_tensor = torch.randn(8, 4, 8)
        local_tensor_t = local_tensor.permute(1, 2, 0)
        global_shape = torch.Size([4, self.world_size * 8, 8])
        self.assertEqual(local_tensor_t.stride(), (8, 1, 32))
        dist_tensor = DTensor.from_local(local_tensor_t, device_mesh, shard1_spec)
        global_stride = (8 * self.world_size, 1, 32 * self.world_size)
        self.assertEqual(dist_tensor.stride(), global_stride)

    @with_comms
    def test_from_local(self):
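        # from_local infers the global size from the placement: Shard(0)
        # multiplies dim 0 by the mesh size, while Replicate and Partial keep
        # the local size; gradients should flow back to the local tensor.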
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = [Shard(0)]
        local_tensor = torch.randn(3, 3)
        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
        self.assertEqual(sharded_tensor.size(), torch.Size([self.world_size * 3, 3]))

        replica_spec = [Replicate()]
        ddp_tensor = DTensor.from_local(local_tensor, device_mesh, replica_spec)
        self.assertEqual(ddp_tensor.size(), local_tensor.size())

        partial_spec = [Partial()]
        partial_tensor = DTensor.from_local(local_tensor, device_mesh, partial_spec)
        self.assertEqual(partial_tensor.size(), local_tensor.size())

        # test that dist tensor works with torch.Tensor during backward
        local_tensor_with_grad = torch.randn(3, 3, requires_grad=True)
        # do some operations on the local tensor
        local_tensor_temp = local_tensor_with_grad * 3
        # create the dist tensor from a non-leaf local tensor; the dist tensor
        # created should also be a non-leaf node
        dist_tensor = DTensor.from_local(local_tensor_temp, device_mesh, placements)
        self.assertFalse(dist_tensor.is_leaf)
        # do some random operations on the dist tensor
        output = dist_tensor * 3
        self.assertIsInstance(output, DTensor)
        # trigger .backward() on the dist tensor directly
        local_grad = torch.ones(3, 3)
        grad_output = DTensor.from_local(local_grad, device_mesh, placements)
        # run backward directly on the dist tensor
        output.backward(grad_output)
        # check that gradients flow back to the original torch.Tensor
        self.assertIsNotNone(local_tensor_with_grad.grad)
        expected_grad = torch.ones(3, 3) * 9
        self.assertEqual(local_tensor_with_grad.grad, expected_grad)

    @with_comms
    def test_from_local_uneven_sharding(self):
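        # With uneven sharding the global shape and stride cannot be inferred
        # from the local shard alone, so they are passed to from_local explicitly.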
        mesh_shape = (self.world_size,)
        device_mesh = init_device_mesh(self.device_type, mesh_shape)

        uneven_dim0_size = self.world_size + 1
        global_tensor = torch.randn(uneven_dim0_size, 2)
        shard_placement = Shard(0)
        tensor_list, _ = shard_placement._split_tensor(
            global_tensor,
            device_mesh.size(mesh_dim=0),
            with_padding=False,
            contiguous=True,
        )

        dtensor = DTensor.from_local(
            tensor_list[self.rank],
            device_mesh,
            (Shard(0),),
            shape=global_tensor.size(),
            stride=global_tensor.stride(),
        )

        self.assertEqual(dtensor.size(), global_tensor.size())
        self.assertEqual(dtensor.stride(), global_tensor.stride())

    @with_comms
    def test_from_local_uneven_sharding_raise_error(self):
        mesh_shape = (self.world_size,)
        device_mesh = init_device_mesh(self.device_type, mesh_shape)

        uneven_dim0_size = self.world_size + 1
        global_tensor = torch.randn(uneven_dim0_size, 2)
        shard_placement = Shard(0)
        tensor_list, _ = shard_placement._split_tensor(
            global_tensor,
            device_mesh.size(mesh_dim=0),
            with_padding=False,
            contiguous=True,
        )

        with self.assertRaisesRegex(
            RuntimeError, "Please pass both shape and stride at the same time."
        ):
            dtensor = DTensor.from_local(
                tensor_list[self.rank],
                device_mesh,
                (Shard(0),),
                shape=global_tensor.size(),
            )

        with self.assertRaisesRegex(
            RuntimeError, "Please pass both shape and stride at the same time."
        ):
            dtensor = DTensor.from_local(
                tensor_list[self.rank],
                device_mesh,
                (Shard(0),),
                stride=global_tensor.stride(),
            )

    @with_comms
    def test_from_local_negative_dim(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = [Shard(-1)]
        local_tensor = torch.randn(3, 3)
        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
        self.assertEqual(sharded_tensor.placements[0].dim, 1)

    @with_comms
    def test_to_local(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = (Shard(0),)
        local_tensor_with_grad = torch.randn(
            3, 3, device=self.device_type, requires_grad=True
        )
        dist_tensor_shape = torch.Size([self.world_size * 3, 3])
        spec = DTensorSpec(
            mesh=device_mesh,
            placements=placements,
            tensor_meta=TensorMeta(
                dist_tensor_shape,
                local_tensor_with_grad.stride(),
                local_tensor_with_grad.dtype,
            ),
        )
        sharded_tensor = DTensor(
            local_tensor_with_grad,
            spec,
            requires_grad=True,
        )
        self.assertEqual(sharded_tensor.size(), dist_tensor_shape)
        self.assertEqual(sharded_tensor.to_local(), local_tensor_with_grad)

        # test that dist tensor works with torch.Tensor during backward:
        # the dist tensor created is a leaf node, so do some operation on it
        temp_st = sharded_tensor * 3

        # do some operation on the local tensor of the dist tensor
        new_tensor_with_grad = torch.randn(
            3, 3, device=self.device_type, requires_grad=True
        )
        res = temp_st.to_local() + new_tensor_with_grad
        # call backward directly on torch.Tensor, and see if it works by
        # propagating through the dist tensor
        res.sum().backward()
        self.assertIsNotNone(sharded_tensor.grad)

        self.assertEqual(sharded_tensor.grad.to_local(), torch.ones(3, 3) * 3)

        # test the case when the grad stride is different from the fwd input.
        res = sharded_tensor.to_local()
        model = torch.nn.ReLU()
        res.register_hook(lambda grad: grad.t())
        target = torch.randn(3, 3, device=self.device_type)
        mae_loss = torch.nn.L1Loss()
        output = mae_loss(model(res), target)
        # The manual change to the grad stride leads to a failure in the copy op
        # afterwards, so we need a try-except here.
        try:
            output.backward()
        except RuntimeError:
            self.assertEqual(sharded_tensor.grad.stride(), [1, 3 * self.world_size])

        # test that under no_grad we directly return the local tensor
        with torch.no_grad():
            local_no_grad = sharded_tensor.to_local()
            assert local_no_grad is sharded_tensor._local_tensor

    @with_comms
    def test_to_local_grad_hint(self):
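        # grad_placements tells to_local how the local gradient will be laid
        # out; here the forward redistribute to Replicate issues an all_gather,
        # and the Partial grad hint makes backward issue a reduce_scatter,
        # which is checked via CommDebugMode below.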
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = (Shard(0),)
        global_tensor = torch.ones(8, 3, requires_grad=True)

        sharded_dtensor = distribute_tensor(global_tensor, device_mesh, placements)
        comm_mode = CommDebugMode()

        with comm_mode:
            local_out = sharded_dtensor.redistribute(placements=[Replicate()]).to_local(
                grad_placements=[Partial()]
            )
            local_out.backward(torch.ones_like(local_out))

        self.assertEqual(
            comm_mode.comm_counts[c10d_functional.all_gather_into_tensor], 1
        )
        self.assertEqual(
            comm_mode.comm_counts[c10d_functional.reduce_scatter_tensor], 1
        )

        replica_grad = sharded_dtensor.grad.full_tensor()
        self.assertEqual(replica_grad, global_tensor * self.world_size)

    @with_comms
    def test_full_tensor_sync(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = (Shard(0),)
        global_tensor = torch.ones(8, 3, requires_grad=True)

        sharded_dtensor = distribute_tensor(global_tensor, device_mesh, placements)
        full_out = sharded_dtensor.full_tensor()
        self.assertFalse(isinstance(full_out, AsyncCollectiveTensor))
        self.assertEqual(full_out, global_tensor)

    @with_comms
    def test_full_tensor_grad_hint(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = (Shard(0),)
        global_tensor = torch.ones(8, 3, requires_grad=True)

        sharded_dtensor = distribute_tensor(global_tensor, device_mesh, placements)
        local_out = sharded_dtensor.full_tensor(grad_placements=[Partial()])
        local_out.sum().backward()

        replica_grad = sharded_dtensor.grad.full_tensor()
        self.assertEqual(replica_grad, global_tensor * self.world_size)

    @with_comms
    def test_dtensor_new_empty_strided(self):
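        # new_empty_strided on a DTensor should produce a DTensor with the same
        # global shape, and its backward should respect the original sharding.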
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        local_tensor = torch.randn(8, 8, requires_grad=True, device=self.device_type)
        my_dtensor = distribute_tensor(local_tensor, device_mesh, [Shard(0)])
        new_strided_dtensor = my_dtensor.new_empty_strided(
            (8, 8), (8, 1), requires_grad=True
        )
        # test that the op produces a new dtensor and autograd works
        self.assertEqual(new_strided_dtensor.shape, my_dtensor.shape)
        new_strided_dtensor.sum().backward()
        self.assertIsNotNone(new_strided_dtensor.grad)
        self.assertIsInstance(new_strided_dtensor.grad, DTensor)

        # test that backward of new_empty_strided with sharding works correctly
        my_dtensor.to_local().sum().backward()
        local_tensor.sum().backward()
        self.assertEqual(my_dtensor.grad, new_strided_dtensor.grad)
        self.assertEqual(
            my_dtensor.grad.redistribute(placements=[Replicate()]).to_local(),
            local_tensor.grad,
        )

    @with_comms
    def test_dtensor_async_output(self):
        # Tests that if the output of some dtensor operations isn't used in any compute,
        # the output should be an AsyncCollectiveTensor (representing the fact that
        # we haven't synced the collective yet).
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        def fn(dt):
            dt_out_redistribute = dt.redistribute(mesh, [Replicate()], async_op=True)
            # Make sure we haven't synced yet
            # TODO: figure out why this is returning None
            # self.assertTrue(_tensor_needs_wait(dt_out_redistribute))
            dt_out_redistribute_view = dt_out_redistribute.view(
                dt_out_redistribute.shape
            )
            local_tensor = dt_out_redistribute_view.to_local()
            return local_tensor

        x = torch.ones((4, 2), device=self.device_type)
        dt = distribute_tensor(x, mesh, [Shard(0)])
        out = fn(dt)
        # Make sure we haven't synced yet
        self.assertEqual(type(out), AsyncCollectiveTensor)
        self.assertFalse(out.completed)
        out_view = out.view(-1)

        # Assert that the output is an `AsyncCollectiveTensor`
        self.assertEqual(type(out_view), AsyncCollectiveTensor)
        self.assertFalse(out.completed)

        # Use the data, which requires a sync
        ref = torch.ones((4, 2), device=self.device_type) + 1
        ref = ref.view(-1)
        out_data = out_view + 1
        self.assertEqual(type(out_data), torch.Tensor)
        self.assertEqual(out_data, ref)

        # test the async_op=False default
        sync_out = dt.redistribute(mesh, [Replicate()])
        self.assertFalse(isinstance(sync_out, AsyncCollectiveTensor))
        self.assertEqual(sync_out.to_local(), x)

    @with_comms
    def test_from_local_then_to_local(self):
        # this test ensures the end-to-end torch.Tensor -> dist tensor -> torch.Tensor flow works
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = [Shard(0)]

        # step 1. construct a local tensor
        local_tensor_with_grad = torch.randn(
            3, 3, device=self.device_type, requires_grad=True
        )
        # do some operations on the local tensor
        local_tensor_temp = local_tensor_with_grad + 8
        # step 2. create the dist tensor from the non-leaf local tensor; the dist
        # tensor created should also be a non-leaf node
        dist_tensor = DTensor.from_local(local_tensor_temp, device_mesh, placements)
        self.assertFalse(dist_tensor.is_leaf)
        # do some random operations on the dist tensor
        output = dist_tensor * 6
        self.assertIsInstance(output, DTensor)

        # step 3. do some operation on the local tensor of the dist tensor
        new_tensor_with_grad = torch.randn(
            3, 3, device=self.device_type, requires_grad=True
        )
        res = output.to_local() + new_tensor_with_grad
        # call backward directly on torch.Tensor, and see if it works by
        # propagating all the way back to the original torch.Tensor
        res.sum().backward()
        self.assertIsNotNone(local_tensor_with_grad.grad)

        expected_grad = torch.ones(3, 3) * 6
        self.assertEqual(local_tensor_with_grad.grad, expected_grad)

    @with_comms
    def test_dtensor_spec_read_only_after_set(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = [Shard(0)]
        local_tensor = torch.randn(3, 3)
        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)

        # modify placements; the dist_tensor's spec should not be changed
        placements[0] = Replicate()
        self.assertTrue(sharded_tensor.placements is not placements)
        self.assertNotEqual(sharded_tensor.placements, placements)

    @with_comms
    def test_dtensor_spec_hash(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = [Shard(0)]
        local_tensor = torch.randn(3, 3)
        local_tensor2 = torch.randn(3, 3)
        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
        sharded_tensor2 = DTensor.from_local(local_tensor2, device_mesh, placements)
        # note that DTensorSpec carries no real tensor data, so the hash is the same
        # as long as the mesh, placements, and tensor properties are the same
        self.assertEqual(hash(sharded_tensor._spec), hash(sharded_tensor2._spec))

        # changing the placements changes the hash
        local_tensor3 = torch.ones(3, 3)
        replica_spec = [Replicate()]
        replica_tensor = DTensor.from_local(
            local_tensor3, device_mesh, replica_spec, run_check=False
        )
        self.assertNotEqual(hash(sharded_tensor._spec), hash(replica_tensor._spec))

    @with_comms
    def test_dtensor_properties(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = [Shard(0)]
        local_tensor = torch.randn(3, 3)
        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
        self.assertEqual(sharded_tensor.device.type, self.device_type)

    @with_comms
    def test_dtensor_save_load(self):
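        # round-trip a DTensor through torch.save/torch.load, including the
        # weights_only path, which requires registering the DTensor-related
        # classes as safe globals.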
        import io

        device_mesh = self.build_device_mesh()
        placements = [Shard(0)]
        local_tensor = torch.randn(3, 3)
        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
        buffer = io.BytesIO()
        torch.save(sharded_tensor, buffer)
        buffer.seek(0)
        reloaded_st = torch.load(buffer)
        self.assertEqual(sharded_tensor, reloaded_st)
        # Test weights_only load
        try:
            torch.serialization.add_safe_globals(
                [DTensor, DeviceMesh, Shard, DTensorSpec, TensorMeta]
            )
            buffer.seek(0)
            reloaded_st = torch.load(buffer, weights_only=True)
            self.assertEqual(sharded_tensor, reloaded_st)
        finally:
            torch.serialization.clear_safe_globals()


class DTensorMeshTest(DTensorTestBase):
    @property
    def world_size(self):
        return 8

    def sub_mesh_assert_equal(self, mesh, exp_in_mesh, exp_out_of_mesh, tensor):
        if self.rank in mesh:
            self.assertEqual(tensor, exp_in_mesh)
        else:
            self.assertEqual(tensor, exp_out_of_mesh)

    @with_comms
    def test_dtensor_device_mesh_device_conversion(self):
        # construct a cuda device mesh
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        # construct from a cpu local tensor with cuda device mesh
        # should automatically convert the dist tensor to cuda
        placements = [Shard(0)]
        local_tensor = torch.randn(3, 3)
        dist_tensor = DTensor.from_local(local_tensor, mesh, placements)
        self.assertEqual(dist_tensor.device.type, self.device_type)
        self.assertEqual(dist_tensor.to_local().device.type, self.device_type)

    @with_comms
    def test_dtensor_api_device_mesh_context_manager(self):
        with DeviceMesh(self.device_type, list(range(self.world_size))) as mesh:
            placements = [Shard(0)]
            local_tensor = torch.randn(3, 3)
            sharded_tensor = DTensor.from_local(
                local_tensor, device_mesh=mesh, placements=placements
            )

        with DeviceMesh(self.device_type, list(range(self.world_size))):
            placements = [Shard(0)]
            local_tensor = torch.randn(3, 3)
            sharded_tensor = DTensor.from_local(local_tensor, placements=placements)
            replica_spec = [Replicate()]
            replica_tensor = sharded_tensor.redistribute(placements=replica_spec)
            self.assertEqual(
                replica_tensor.size(), torch.Size([3 * self.world_size, 3])
            )

        with DeviceMesh(self.device_type, torch.arange(self.world_size)):
            placements = [Shard(0)]
            global_shape = torch.Size([3 * self.world_size, 3])
            global_tensor = torch.randn(global_shape)
            sharded_tensor = distribute_tensor(global_tensor, placements=placements)
            self.assertEqual(sharded_tensor.to_local().shape, torch.Size([3, 3]))

            mesh_2d = DeviceMesh(
                self.device_type, torch.arange(self.world_size).reshape(2, 4)
            )

            with mesh_2d:
                shard_2d_spec = [Shard(0), Replicate()]
                tensor_2d = distribute_tensor(global_tensor, placements=shard_2d_spec)

                self.assertEqual(tensor_2d.to_local().shape, torch.Size([3 * 4, 3]))

            sharded_after_2d = distribute_tensor(global_tensor, placements=placements)
            self.assertEqual(sharded_after_2d.to_local().shape, torch.Size([3, 3]))

    @with_comms
    def test_dtensor_2d_mesh(self):
        mesh_tensor = torch.arange(self.world_size).reshape(2, 4)
        # construct a cuda device mesh
        mesh = DeviceMesh(self.device_type, mesh_tensor)

        # construct a dist tensor on the 2d device mesh and test that it works
        placements = [Shard(0), Shard(1)]
        local_tensor = torch.randn(3, 3)
        dist_tensor = DTensor.from_local(local_tensor, mesh, placements)
        self.assertEqual(
            dist_tensor.size(), torch.Size([3 * mesh.size(0), 3 * mesh.size(1)])
        )
        self.assertEqual(dist_tensor.device.type, self.device_type)
        self.assertEqual(dist_tensor.to_local().device.type, self.device_type)

        # if we shard on the same tensor dimension,
        # we should still correctly construct the global tensor size
        shard_same_dim_spec = [Shard(0), Shard(0)]
        local_tensor = torch.randn(3, 3)
        dist_tensor = DTensor.from_local(local_tensor, mesh, shard_same_dim_spec)
        self.assertEqual(dist_tensor.size(), torch.Size([3 * self.world_size, 3]))

    @with_comms
    def test_device_mesh_nd(self):
        # construct a cuda device mesh
        mesh_tensor = torch.arange(self.world_size).reshape(2, 2, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor)
        # construct a dist tensor on the 3d device mesh and test that it works
        placements = [Shard(0), Shard(1), Shard(2)]
        local_tensor = torch.randn(3, 3, 3)
        dist_tensor = DTensor.from_local(local_tensor, mesh, placements)
        self.assertEqual(dist_tensor.size(), torch.Size([6, 6, 6]))
        self.assertEqual(dist_tensor.device.type, self.device_type)
        self.assertEqual(dist_tensor.to_local().device.type, self.device_type)

        # construct a dist tensor on the 3d device mesh with some shards on the same dim
        placements = [Shard(0), Shard(0), Shard(2)]
        local_tensor = torch.randn(3, 3, 3)
        dist_tensor = DTensor.from_local(local_tensor, mesh, placements)
        self.assertEqual(dist_tensor.size(), torch.Size([12, 3, 6]))
        self.assertEqual(dist_tensor.device.type, self.device_type)
        self.assertEqual(dist_tensor.to_local().device.type, self.device_type)

    @with_comms
    def test_dtensor_spec_local_shard_offset(self):
        device_mesh = DeviceMesh(
            self.device_type, torch.arange(self.world_size).reshape(2, 4)
        )
        tensor_shape = (3 * self.world_size, 3 * self.world_size)
        # sharding specs and their corresponding local shard offsets
        shard_spec_and_offsets = [
            (
                [Shard(0), Replicate()],
                (3 * (self.world_size // 2) * (self.rank // 4), 0),
            ),
            (
                [Shard(1), Replicate()],
                (0, 3 * (self.world_size // 2) * (self.rank // 4)),
            ),
            (
                [Replicate(), Shard(0)],
                (3 * (self.world_size // 4) * (self.rank % 4), 0),
            ),
            (
                [Replicate(), Shard(1)],
                (0, 3 * (self.world_size // 4) * (self.rank % 4)),
            ),
        ]

        from torch.distributed._tensor._utils import (
            compute_local_shape_and_global_offset,
        )

        # loop through all sharding specs and check local shard offsets
        logical_tensor = torch.randn(tensor_shape)
        for placements, expected_shard_offsets in shard_spec_and_offsets:
            dtensor = distribute_tensor(logical_tensor, device_mesh, placements)
            _, offset = compute_local_shape_and_global_offset(
                dtensor.shape, device_mesh, dtensor.placements
            )
            self.assertEqual(expected_shard_offsets, offset)

    @with_comms
    def test_from_local_sub_mesh(self):
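        # ranks outside the [0, 2] sub-mesh hold an empty local tensor, while
        # ranks inside the sub-mesh hold the real local shard.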
        mesh = DeviceMesh(self.device_type, [0, 2])
        local_tensor = torch.ones(3, 4)

        dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0)])
        self.assertEqual(dtensor.size(), torch.Size([6, 4]))

        self.sub_mesh_assert_equal(
            mesh.mesh,
            torch.ones(3, 4),
            torch.tensor([]),
            dtensor.to_local(),
        )

        # for a dtensor created on a sub-mesh, the operation should only be
        # applied to the local shard inside the mesh, not the whole world, so
        # only ranks 0 and 2 actually run the computation
        dtensor = dtensor + 2

        self.sub_mesh_assert_equal(
            mesh.mesh,
            torch.ones(3, 4) + 2,
            torch.tensor([]),
            dtensor.to_local(),
        )

    @with_comms
    def test_default_value_sub_mesh(self):
        mesh = DeviceMesh(self.device_type, [0, 2])

        # test scalar return value
        local_tensor1 = torch.ones(4, 3)
        local_tensor2 = torch.ones(4, 3)
        dtensor1 = DTensor.from_local(local_tensor1, mesh, [Shard(0)])
        dtensor2 = DTensor.from_local(local_tensor2, mesh, [Shard(0)])
        local_res = dtensor1.equal(dtensor2)  # equal returns local result
        self.sub_mesh_assert_equal(
            mesh.mesh,
            True,
            True,
            local_res,
        )

        # test 0-d tensor return value
        local_tensor = torch.ones(4, 3)
        dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0)]).sum()
        self.sub_mesh_assert_equal(
            mesh.mesh,
            torch.tensor(12.0),
            torch.tensor(0.0),
            dtensor.to_local(),
        )

        # test List[torch.Tensor] return value
        local_tensor = torch.ones(3, 4)
        dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0)])
        dtensor_list = dtensor.split([2, 2], dim=1)
        self.sub_mesh_assert_equal(
            mesh.mesh,
            [torch.ones(3, 2)] * 2,
            [torch.tensor([])] * 2,
            [dt.to_local() for dt in dtensor_list],
        )

    @with_comms
    def test_redistribute_sub_mesh(self):
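        # redistributing within a sub-mesh should only involve the ranks that
        # are part of the mesh; ranks outside keep an empty local tensor.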
        mesh = DeviceMesh(self.device_type, [0, 2])

        # test redistribute on a sub-mesh
        local_tensor1 = torch.ones(4, 3)
        sharded_dtensor = DTensor.from_local(local_tensor1, mesh, [Shard(0)])
        replicated_dtensor = sharded_dtensor.redistribute(placements=[Replicate()])
        self.sub_mesh_assert_equal(
            mesh.mesh, torch.ones(8, 3), torch.tensor([]), replicated_dtensor.to_local()
        )
        sharded_again = replicated_dtensor.redistribute(placements=[Shard(0)])
        self.sub_mesh_assert_equal(
            mesh.mesh, torch.ones(4, 3), torch.tensor([]), sharded_again.to_local()
        )

    @with_comms
    def test_implicit_replication(self):
        mesh = init_device_mesh(self.device_type, (self.world_size,))
        local_tensor1 = torch.ones(4, 3)
        sharded_dtensor = DTensor.from_local(local_tensor1, mesh, [Shard(0)])

        from torch.distributed._tensor.experimental import implicit_replication

        with implicit_replication():
            out_dt = sharded_dtensor + torch.ones(3, device=self.device_type)
            self.assertEqual(out_dt.placements, [Shard(0)])
            self.assertEqual(out_dt.shape, (4 * self.world_size, 3))
            local_shard = out_dt.to_local()
            self.assertEqual(local_shard.shape, (4, 3))
            self.assertEqual(local_shard, torch.ones(4, 3) + torch.ones(3))

    @with_comms
    def test_auto_implicit_replication(self):
        mesh = init_device_mesh(self.device_type, (self.world_size,))

        local_tensor = torch.ones(self.world_size, 3, device=self.device_type)
        sharded_dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0)])

        # a tensor with ndim == 0 and numel == 1 is automatically treated as a replicated DTensor
        ndim_0_tensor = torch.tensor(1, device=self.device_type)

        def add_scalar_tensor_with_dtensor():
            return sharded_dtensor + ndim_0_tensor

        result = add_scalar_tensor_with_dtensor().to_local()
        self.assertEqual(result, local_tensor + ndim_0_tensor)
        self.assertNotWarn(
            add_scalar_tensor_with_dtensor,
            "Found a non-scalar tensor with numel=1 and ndim!=0",
        )

        # a tensor with ndim == 1 and numel == 1 is also automatically treated as a replicated DTensor
        numel_1_tensor = torch.tensor([1], device=self.device_type)
        self.assertEqual(
            (sharded_dtensor + numel_1_tensor).to_local(), local_tensor + numel_1_tensor
        )


class TestDTensorPlacementTypes(DTensorTestBase):
    @property
    def world_size(self):
        return 8

    def _create_tensor(self, size):
        # Keep everything deterministic.
        torch.manual_seed(0)
        tensor = torch.rand(size)
        if self.device_type == "cuda":
            return tensor.cuda()
        else:
            return tensor

    @with_comms
    def test_split_tensor_1D(self) -> None:
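        # Shard._split_tensor with padding returns one chunk per mesh rank plus
        # per-rank pad sizes; tensors shorter than the mesh yield empty shards
        # once the padding is removed.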
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        shard_placement = Shard(0)

        for size in range(8):
            tensor = self._create_tensor(size)
            splitted_tensor_list, pad_sizes = shard_placement._split_tensor(
                tensor,
                mesh.size(),
                with_padding=True,
                contiguous=True,
            )
            if size == 0:
                # when the tensor size is 0, no padding is needed on any rank.
                expected_pad_sizes = []
                assert_array_equal(expected_pad_sizes, pad_sizes)

                is_tensor_empty = [
                    False if splitted_tensor.numel() > 0 else True
                    for splitted_tensor in splitted_tensor_list
                ]
                expected_is_tensor_empty = [True] * self.world_size
                assert_array_equal(expected_is_tensor_empty, is_tensor_empty)
            else:
                expected_pad_sizes = [
                    0 if idx < size else 1
                    for idx, _ in enumerate(range(self.world_size))
                ]
                assert_array_equal(expected_pad_sizes, pad_sizes)

                from torch.distributed._tensor._collective_utils import unpad_tensor

                unpadded_list = [
                    unpad_tensor(tensor, shard_placement.dim, pad_sizes[i])
                    if pad_sizes[i] > 0
                    else tensor
                    for i, tensor in enumerate(splitted_tensor_list)
                ]
                expected_is_tensor_empty = [
                    False if idx < size else True
                    for idx, _ in enumerate(range(self.world_size))
                ]
                is_tensor_empty = [
                    False if unpadded_tensor.numel() > 0 else True
                    for unpadded_tensor in unpadded_list
                ]
                assert_array_equal(expected_is_tensor_empty, is_tensor_empty)


if __name__ == "__main__":
    run_tests()