# mypy: allow-untyped-defs

# Copyright (c) Meta Platforms, Inc. and affiliates

import itertools
import sys
from dataclasses import dataclass
from functools import wraps
from typing import Any, Callable, cast, Dict, Iterator, List, Sequence, Tuple, TypeVar

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard
from torch.distributed._tensor.placement_types import Placement
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
)
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    MultiThreadedTestCase,
    skip_if_lt_x_gpu,
    run_subtests,
    TEST_SKIPS,
)

from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec

DEVICE_TYPE = (
    "cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else "cpu"
)

NUM_DEVICES = 4

# We use this as a proxy for "multiple GPUs exist"
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    # when we actually have multiple GPUs, relax the requirement to smaller counts.
    NUM_DEVICES = min(NUM_DEVICES, torch.cuda.device_count())

T = TypeVar("T")


# simple RMSNorm layer for testing
class RMSNormPython(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x)
        return output * self.weight


class MLPModule(nn.Module):
    def __init__(self, device, bias: bool = True):
        super().__init__()
        torch.manual_seed(5)
        self.net1 = nn.Linear(10, 16, bias=bias, device=device)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 10, bias=bias, device=device)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

    def reset_parameters(self):
        self.net1.reset_parameters()
        self.net2.reset_parameters()

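
# A minimal usage sketch for MLPModule (an illustrative assumption based on the
# common Megatron-style plan used in tensor parallel tests; not executed here):
#
#   mesh = DeviceMesh(DEVICE_TYPE, list(range(NUM_DEVICES)))
#   tp_mlp = parallelize_module(
#       MLPModule(DEVICE_TYPE),
#       mesh,
#       {"net1": ColwiseParallel(), "net2": RowwiseParallel()},
#   )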

class MLPStacked(nn.Module):
    def __init__(self, device, n_layers: int = 2):
        super().__init__()
        self.layers = nn.ModuleList([MLPModule(device) for _ in range(n_layers)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


@dataclass
class ModelArgs:
    n_layers: int = 2
    vocab_size: int = 8
    max_seq_len: int = 16
    dim: int = 16
    n_heads: int = 4
    dropout_p: float = 0.1
    use_attn_mask: bool = True
    weight_tying: bool = True
    checkpoint_activations: bool = False


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        assert args.dim % args.n_heads == 0
        self.head_dim = args.dim // args.n_heads
        self.n_heads = args.n_heads
        self.dropout_p = args.dropout_p
        self.resid_dropout = nn.Dropout(args.dropout_p)
        self.use_attn_mask = args.use_attn_mask

        self.wq = nn.Linear(args.dim, args.dim, bias=False)
        self.wk = nn.Linear(args.dim, args.dim, bias=False)
        self.wv = nn.Linear(args.dim, args.dim, bias=False)
        self.wo = nn.Linear(args.dim, args.dim, bias=False)

    def forward(self, x):
        bsz, seq_len, _ = x.size()
        queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
        queries = queries.view(bsz, seq_len, self.n_heads, self.head_dim)
        keys = keys.view(bsz, seq_len, self.n_heads, self.head_dim)
        values = values.view(bsz, seq_len, self.n_heads, self.head_dim)

        queries = queries.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
        keys = keys.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
        values = values.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)

        output = F.scaled_dot_product_attention(
            queries,
            keys,
            values,
            None,
            self.dropout_p if self.training else 0,
            self.use_attn_mask,
        )
        output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
        return self.resid_dropout(self.wo(output))


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout_p):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim)
        self.gelu = nn.GELU()
        self.w2 = nn.Linear(hidden_dim, dim)
        self.resid_dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        return self.resid_dropout(self.w2(self.gelu(self.w1(x))))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.attention_norm = nn.LayerNorm(args.dim)
        self.attention = Attention(args)
        self.ffn_norm = nn.LayerNorm(args.dim)
        self.feed_forward = FeedForward(
            args.dim, hidden_dim=4 * args.dim, dropout_p=args.dropout_p
        )

    def forward(self, x):
        h = x + self.attention(self.attention_norm(x))
        out = h + self.feed_forward(self.ffn_norm(h))
        return out


# A toy transformer model, partly inspired by the nanoGPT model:
# https://github.com/karpathy/nanoGPT.
class Transformer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        assert args.vocab_size is not None
        assert args.max_seq_len is not None
        self.model_args = args
        self.max_seq_len = args.max_seq_len
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.pos_embeddings = nn.Embedding(args.max_seq_len, args.dim)
        self.dropout = nn.Dropout(args.dropout_p)
        self.layers = nn.ModuleList()
        for _ in range(args.n_layers):
            self.layers.append(TransformerBlock(args))
        self.norm = nn.LayerNorm(args.dim)
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
        if args.weight_tying:
            self.output.weight = self.tok_embeddings.weight
        self.checkpoint_activations = args.checkpoint_activations

    def forward(self, tokens):
        _bsz, seq_len = tokens.size()
        assert seq_len <= self.max_seq_len
        h = self.tok_embeddings(tokens)
        pos = torch.arange(0, seq_len, device=tokens.device)
        p = self.pos_embeddings(pos)  # positional embeddings of shape (seq_len, dim)
        h = h + p
        h = self.dropout(h)
        for layer in self.layers:
            if self.checkpoint_activations:
                h = torch.utils.checkpoint.checkpoint(layer, h, use_reentrant=False)
            else:
                h = layer(h)
        h = self.norm(h)
        output = self.output(h).float()
        return output

    @staticmethod
    def parallelize(
        module: "Transformer",
        device_mesh: DeviceMesh,
        use_seq_parallel: bool,
        local_output_for_attn: bool = False,
    ) -> nn.Module:
        assert isinstance(module, Transformer), f"Requires Transformer but got {module}"
        # Parallelize the root submodules.
        if use_seq_parallel:
            root_plan = {
                "tok_embeddings": RowwiseParallel(
                    input_layouts=Replicate(), output_layouts=Shard(1)
                ),
                "pos_embeddings": RowwiseParallel(
                    input_layouts=Replicate(), output_layouts=Shard(0)
                ),
                "norm": SequenceParallel(),
            }
        else:
            root_plan = {
                "tok_embeddings": RowwiseParallel(
                    input_layouts=Replicate(), output_layouts=Replicate()
                ),
                "pos_embeddings": RowwiseParallel(
                    input_layouts=Replicate(), output_layouts=Replicate()
                ),
            }

        module_tp = parallelize_module(module, device_mesh, root_plan)
        # Parallelize the attention and feed forward submodules.
        for layer in module_tp.layers:
            layer_parallelize_plan = {}
            if use_seq_parallel:
                layer_parallelize_plan["attention"] = PrepareModuleInput(
                    input_layouts=Shard(1),
                    desired_input_layouts=Replicate(),
                )
                # apply sequence parallelism to the norm layers (LayerNorm here)
                layer_parallelize_plan["attention_norm"] = SequenceParallel()
                layer_parallelize_plan["ffn_norm"] = SequenceParallel()
            layer_parallelize_plan["attention.wq"] = ColwiseParallel(
                use_local_output=local_output_for_attn
            )
            layer_parallelize_plan["attention.wk"] = ColwiseParallel(
                use_local_output=local_output_for_attn
            )
            layer_parallelize_plan["attention.wv"] = ColwiseParallel(
                use_local_output=local_output_for_attn
            )
            layer_parallelize_plan["attention.wo"] = (
                RowwiseParallel(output_layouts=Shard(1))
                if use_seq_parallel
                else RowwiseParallel()
            )

            layer_parallelize_plan["feed_forward.w1"] = (
                ColwiseParallel(input_layouts=Shard(1))
                if use_seq_parallel
                else ColwiseParallel()
            )
            layer_parallelize_plan["feed_forward.w2"] = (
                RowwiseParallel(output_layouts=Shard(1))
                if use_seq_parallel
                else RowwiseParallel()
            )

            parallelize_module(layer, device_mesh, layer_parallelize_plan)

        # Parallelize the output submodule. If weight tying is enabled, we need to
        # make sure output.weight is sharded consistently with tok_embeddings.weight,
        # at the cost of the all_reduce operation used by RowwiseParallel.
        output_parallelize_plan = (
            ColwiseParallel(
                input_layouts=Shard(1),
                output_layouts=Replicate(),
            )
            if use_seq_parallel
            else ColwiseParallel(output_layouts=Replicate())
        )
        parallelize_module(module_tp.output, device_mesh, output_parallelize_plan)

        if local_output_for_attn:
            for layer in module_tp.layers:
                layer.attention.n_heads = (
                    module_tp.model_args.n_heads // device_mesh.size()
                )

        # Manually set output.weight so that parameters and gradients are shared.
        if module_tp.model_args.weight_tying:
            module_tp.output.weight = module_tp.tok_embeddings.weight

        return module_tp

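
# A minimal usage sketch for the toy Transformer above (assumes a 1-D
# DeviceMesh over NUM_DEVICES ranks, mirroring how the tensor parallel tests
# drive it; not executed here):
#
#   mesh = DeviceMesh(DEVICE_TYPE, list(range(NUM_DEVICES)))
#   model = Transformer(ModelArgs())
#   tp_model = Transformer.parallelize(model, mesh, use_seq_parallel=True)
#   tokens = torch.randint(0, model.model_args.vocab_size, (2, model.max_seq_len))
#   logits = tp_model(tokens)  # (2, max_seq_len, vocab_size)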

def skip_unless_torch_gpu(method: T) -> T:
    """
    Test decorator which skips the test unless there's a GPU available to torch.

    >>> # xdoctest: +SKIP
    >>> @skip_unless_torch_gpu
    >>> def test_some_method(self) -> None:
    >>>   ...
    """
    # The builtin @skip_if_no_gpu relies on os.environ['WORLD_SIZE'] being set.
    return cast(T, skip_if_lt_x_gpu(NUM_DEVICES)(method))


class DTensorTestBase(MultiProcessTestCase):
    @property
    def world_size(self) -> int:
        return NUM_DEVICES

    @property
    def backend(self) -> str:
        backend = "nccl" if self.device_type == "cuda" else "gloo"
        return backend

    def build_device_mesh(self) -> DeviceMesh:
        return DeviceMesh(self.device_type, list(range(self.world_size)))

    def init_pg(self) -> None:
        if "nccl" in self.backend and torch.cuda.device_count() < self.world_size:
            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)

        if self.backend not in ["nccl", "gloo", "mpi", "cpu:gloo,cuda:nccl"]:
            raise RuntimeError(f"Backend {self.backend} not supported!")

        dist.init_process_group(
            backend=self.backend,
            world_size=self.world_size,
            rank=self.rank,  # pyre-ignore[16]
            init_method=f"file://{self.file_name}",  # pyre-ignore[16]
        )

        # set device for nccl pg for collectives
        if "nccl" in self.backend:
            torch.cuda.set_device(self.rank)

    def destroy_pg(self) -> None:
        # Wait for all ranks to reach here before starting shutdown.
        # FIXME dist.barrier deadlocks with multiple threads and NCCL: https://github.com/pytorch/pytorch/issues/95895
        # dist.all_reduce(torch.zeros((1,), device="cuda" if torch.cuda.is_available() else "cpu"))
        # FIXME can't use the above all_reduce as it causes hangs on bionic and focal. It hangs:
        #  test_dtensor.py  -- DTensorMeshTest.test_dtensor_device_mesh_device_conversion
        dist.barrier()
        dist.destroy_process_group()

    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    # pyre-ignore[2]:
    def _test_op(self, mesh: DeviceMesh, op_call, *args, **kwargs) -> None:
        out = op_call(*args, **kwargs)
        dtc = DTensorConverter(mesh, args, kwargs)
        for d_args, d_kwargs in dtc:
            # pyre can't find assertTrue anymore?
            self.assertEqual(dtc.successful(), True)
            d_out = op_call(*d_args, **d_kwargs)
            self.assertEqual(d_out.full_tensor(), out)

    def run_subtests(self, *args, **kwargs):
        return run_subtests(self, *args, **kwargs)


TestFunc = Callable[[object], object]

# wrapper to initialize comms (process group)
def with_comms(func: TestFunc) -> TestFunc:
    assert func is not None

    @wraps(func)  # pyre-ignore[6]
    def wrapper(
        self, *args: Tuple[object], **kwargs: Dict[str, Any]  # type: ignore[misc]
    ) -> None:
        # if there are enough GPUs we use them, otherwise we fall back to CPU
        if not torch.cuda.is_available() or torch.cuda.device_count() < self.world_size:
            self.device_type = "cpu"
        else:
            self.device_type = DEVICE_TYPE

        self.init_pg()

        try:
            func(self, *args, **kwargs)  # type: ignore[misc]
        except Exception as e:
            dist.destroy_process_group()
            raise e

        self.destroy_pg()

    return wrapper

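
# A minimal usage sketch combining DTensorTestBase and @with_comms (an
# illustrative test class, not one defined in this file):
#
#   class MyDTensorTest(DTensorTestBase):
#       @with_comms
#       def test_replicate(self):
#           mesh = self.build_device_mesh()
#           dt = distribute_tensor(torch.ones(4, 4), mesh, [Replicate()])
#           self.assertEqual(dt.full_tensor(), torch.ones(4, 4))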

class DTensorOpTestBase(MultiThreadedTestCase):
    @property
    def world_size(self) -> int:
        return NUM_DEVICES

    @property
    def device_type(self) -> str:
        return DEVICE_TYPE

    def build_device_mesh(self):
        return DeviceMesh(self.device_type, list(range(self.world_size)))

    def setUp(self) -> None:
        super().setUp()
        self._spawn_threads()

# Converts an op's args/kwargs into DTensor args/kwargs, iterating over all
# sharding combinations generated for the tensor arguments
class DTensorConverter:
    def __init__(
        self,
        mesh: DeviceMesh,
        args: Tuple[object, ...],
        kwargs: Dict[str, object],
    ) -> None:
        self.hit = 0
        self.miss = 0
        self.mesh = mesh
        self.args = args
        self.kwargs = kwargs
        flatten_args, flatten_args_spec = tree_flatten(args)
        flatten_kwargs, flatten_kwargs_spec = tree_flatten(kwargs)

        self.flatten_args: List[object] = flatten_args
        self.flatten_args_spec: TreeSpec = flatten_args_spec
        self.flatten_kwargs: List[object] = flatten_kwargs
        self.flatten_kwargs_spec: TreeSpec = flatten_kwargs_spec

        choices_for_args = []
        for arg in self.flatten_args:
            if isinstance(arg, torch.Tensor):
                choices_for_args.append(self.gen_sharding_choices_for_arg(arg))

        for arg in self.flatten_kwargs:
            if isinstance(arg, torch.Tensor):
                choices_for_args.append(self.gen_sharding_choices_for_arg(arg))

        self.sharding_combs: Iterator[Sequence[Placement]] = iter(
            itertools.product(*choices_for_args)
        )

    def successful(self) -> bool:
        return self.hit > 0 and self.miss == 0

    def is_supported_tensor(self, t: torch.Tensor) -> bool:
        # TODO: DTensor needs to support quantized and sparse tensors.
        # Quantized tensors might be relatively easy to support, but
        # sparse tensors have special layouts that we may need to handle;
        # until that is clear, we don't officially support them.
        return not any(
            [
                t.is_sparse_csr,
                t.is_sparse,
                t.is_mkldnn,
                t.is_quantized,
                t.is_nested,
                torch._is_functional_tensor(t),
                t.is_neg(),
                t.is_conj(),
                t.device.type in ("lazy", "meta"),
                # We need a way to test if a tensor is batched, but there
                # is no official API to do it.
                # torch._C._is_batched(t),
            ]
        )

    def gen_sharding_choices_for_arg(self, arg: torch.Tensor) -> Sequence[Placement]:
        mesh_size = self.mesh.size()
        sharding_choices: List[Placement] = [Replicate()]
        # c10d collectives do not support bool tensors,
        # so bool tensors are treated as replicated
        if arg.dtype != torch.bool:
            # only generate choices that replicate, or that shard evenly
            # along a dimension whose size is divisible by the mesh size
            sharding_choices = sharding_choices + [
                Shard(i)
                for i, s in enumerate(arg.shape)
                if s > 1 and s % mesh_size == 0
            ]
        # TODO: add multi mesh choices
        # all_choices = itertools.product(
        #     *(self.mesh.ndim * [sharding_choices])
        # )
        return sharding_choices

    def __iter__(self) -> "DTensorConverter":
        return self

    def __next__(self) -> Tuple[Tuple[object, ...], Dict[str, object]]:
        try:
            next_sharding_choices = next(self.sharding_combs)
            idx = 0

            new_args: List[object] = []
            for arg in self.flatten_args:
                if isinstance(arg, torch.Tensor):
                    new_args.append(
                        self.to_dist_tensor(
                            arg, self.mesh, [next_sharding_choices[idx]]
                        )
                    )
                    idx += 1
                else:
                    new_args.append(arg)

            new_kwargs: List[object] = []
            for arg in self.flatten_kwargs:
                if isinstance(arg, torch.Tensor):
                    new_kwargs.append(
                        self.to_dist_tensor(
                            arg, self.mesh, [next_sharding_choices[idx]]
                        )
                    )
                    idx += 1
                else:
                    new_kwargs.append(arg)

            return (
                tree_unflatten(new_args, self.flatten_args_spec),
                tree_unflatten(new_kwargs, self.flatten_kwargs_spec),
            )
        except StopIteration as e:
            raise StopIteration from e

    def to_dist_tensor(
        self, t: torch.Tensor, mesh: DeviceMesh, placements: List[Placement]
    ) -> torch.Tensor:
        if type(t) is torch.Tensor or type(t) is nn.Parameter:
            if self.is_supported_tensor(t):
                self.hit += 1
                if t.ndim == 0:
                    # scalar tensor by default will be replicated
                    r = distribute_tensor(t, mesh, [Replicate()] * mesh.ndim)
                else:
                    # distribute non-scalar tensors
                    r = distribute_tensor(t, mesh, placements)
                if type(t) is nn.Parameter:
                    r = nn.Parameter(  # type: ignore[assignment]
                        r, requires_grad=r.requires_grad
                    )
                return r
            else:
                self.miss += 1
                return t
        elif torch.overrides.is_tensor_like(t):
            # Blindly converting tensor subclasses to DTensor can cause
            # unpredictable problems, so we explicitly disable this conversion
            # for now (i.e. we don't support DTensor holding a tensor subclass
            # until there's a strong reason to later).
            self.miss += 1
            return t
        else:
            raise RuntimeError(f"Trying to convert to DTensor, but got {type(t)}")

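
# A minimal usage sketch for DTensorConverter (mirrors DTensorTestBase._test_op
# above; torch.add is just an illustrative op, and the sketch is not executed
# here):
#
#   dtc = DTensorConverter(mesh, (torch.randn(4, 4), torch.randn(4, 4)), {})
#   for d_args, d_kwargs in dtc:
#       d_out = torch.add(*d_args, **d_kwargs)
#       assert d_out.full_tensor().shape == (4, 4)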