# mypy: allow-untyped-defs

# Copyright (c) Meta Platforms, Inc. and affiliates

import itertools
import sys
from dataclasses import dataclass
from functools import wraps
from typing import Any, Callable, cast, Dict, Iterator, List, Sequence, Tuple, TypeVar

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard
from torch.distributed._tensor.placement_types import Placement
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
)
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    MultiThreadedTestCase,
    run_subtests,
    skip_if_lt_x_gpu,
    TEST_SKIPS,
)

from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec

DEVICE_TYPE = (
    "cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else "cpu"
)

NUM_DEVICES = 4

# We use this as a proxy for "multiple GPUs exist"
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    # when we actually have multiple GPUs, relax the requirement to smaller counts.
    NUM_DEVICES = min(NUM_DEVICES, torch.cuda.device_count())

T = TypeVar("T")


# simple RMSNorm layer for testing
class RMSNormPython(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x)
        return output * self.weight


class MLPModule(nn.Module):
    def __init__(self, device, bias: bool = True):
        super().__init__()
        torch.manual_seed(5)
        self.net1 = nn.Linear(10, 16, bias=bias, device=device)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 10, bias=bias, device=device)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

    def reset_parameters(self):
        self.net1.reset_parameters()
        self.net2.reset_parameters()


class MLPStacked(nn.Module):
    def __init__(self, device, n_layers: int = 2):
        super().__init__()
        self.layers = nn.ModuleList([MLPModule(device) for _ in range(n_layers)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
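

# Example: tensor-parallelizing MLPModule (a minimal sketch, not used by the tests
# themselves; assumes an already-initialized process group, e.g. inside a
# DTensorTestBase test below, and a 1-D mesh over all ranks; MLPModule seeds its
# own init, so its weights start out identical on every rank):
#
#   mesh = DeviceMesh(DEVICE_TYPE, list(range(NUM_DEVICES)))
#   mlp = MLPModule(DEVICE_TYPE)
#   parallelize_module(
#       mlp, mesh, {"net1": ColwiseParallel(), "net2": RowwiseParallel()}
#   )
#   torch.manual_seed(0)  # same input on every rank
#   out = mlp(torch.rand(4, 10, device=DEVICE_TYPE))  # output replicated across ranks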


@dataclass
class ModelArgs:
    n_layers: int = 2
    vocab_size: int = 8
    max_seq_len: int = 16
    dim: int = 16
    n_heads: int = 4
    dropout_p: float = 0.1
    use_attn_mask: bool = True
    weight_tying: bool = True
    checkpoint_activations: bool = False


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        assert args.dim % args.n_heads == 0
        self.head_dim = args.dim // args.n_heads
        self.n_heads = args.n_heads
        self.dropout_p = args.dropout_p
        self.resid_dropout = nn.Dropout(args.dropout_p)
        self.use_attn_mask = args.use_attn_mask

        self.wq = nn.Linear(args.dim, args.dim, bias=False)
        self.wk = nn.Linear(args.dim, args.dim, bias=False)
        self.wv = nn.Linear(args.dim, args.dim, bias=False)
        self.wo = nn.Linear(args.dim, args.dim, bias=False)

    def forward(self, x):
        bsz, seq_len, _ = x.size()
        queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
        queries = queries.view(bsz, seq_len, self.n_heads, self.head_dim)
        keys = keys.view(bsz, seq_len, self.n_heads, self.head_dim)
        values = values.view(bsz, seq_len, self.n_heads, self.head_dim)

        queries = queries.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
        keys = keys.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
        values = values.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)

        output = F.scaled_dot_product_attention(
            queries,
            keys,
            values,
            None,
            self.dropout_p if self.training else 0,
            self.use_attn_mask,
        )
        output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
        return self.resid_dropout(self.wo(output))


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout_p):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim)
        self.gelu = nn.GELU()
        self.w2 = nn.Linear(hidden_dim, dim)
        self.resid_dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        return self.resid_dropout(self.w2(self.gelu(self.w1(x))))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.attention_norm = nn.LayerNorm(args.dim)
        self.attention = Attention(args)
        self.ffn_norm = nn.LayerNorm(args.dim)
        self.feed_forward = FeedForward(
            args.dim, hidden_dim=4 * args.dim, dropout_p=args.dropout_p
        )

    def forward(self, x):
        h = x + self.attention(self.attention_norm(x))
        out = h + self.feed_forward(self.ffn_norm(h))
        return out
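

# Shape walkthrough for the toy blocks (a minimal sketch; single process, no TP):
# the pre-norm residual structure keeps the (bsz, seq_len, dim) shape end to end.
#
#   args = ModelArgs()
#   block = TransformerBlock(args)
#   x = torch.randn(2, args.max_seq_len, args.dim)
#   assert block(x).shape == (2, args.max_seq_len, args.dim)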


# A toy transformer model, partly inspired by the nanoGPT model:
# https://github.com/karpathy/nanoGPT.
class Transformer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        assert args.vocab_size is not None
        assert args.max_seq_len is not None
        self.model_args = args
        self.max_seq_len = args.max_seq_len
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.pos_embeddings = nn.Embedding(args.max_seq_len, args.dim)
        self.dropout = nn.Dropout(args.dropout_p)
        self.layers = nn.ModuleList()
        for _ in range(args.n_layers):
            self.layers.append(TransformerBlock(args))
        self.norm = nn.LayerNorm(args.dim)
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
        if args.weight_tying:
            self.output.weight = self.tok_embeddings.weight
        self.checkpoint_activations = args.checkpoint_activations

    def forward(self, tokens):
        _bsz, seq_len = tokens.size()
        assert seq_len <= self.max_seq_len
        h = self.tok_embeddings(tokens)
        pos = torch.arange(0, seq_len, device=tokens.device)
        p = self.pos_embeddings(pos)  # positional embeddings of shape (seq_len, dim)
        h = h + p
        h = self.dropout(h)
        for layer in self.layers:
            if self.checkpoint_activations:
                h = torch.utils.checkpoint.checkpoint(layer, h, use_reentrant=False)
            else:
                h = layer(h)
        h = self.norm(h)
        output = self.output(h).float()
        return output

    @staticmethod
    def parallelize(
        module: "Transformer",
        device_mesh: DeviceMesh,
        use_seq_parallel: bool,
        local_output_for_attn: bool = False,
    ) -> nn.Module:
        assert isinstance(
            module, Transformer
        ), f"Requires Transformer but got {module}"
        # Parallelize the root submodules.
        if use_seq_parallel:
            root_plan = {
                "tok_embeddings": RowwiseParallel(
                    input_layouts=Replicate(), output_layouts=Shard(1)
                ),
                "pos_embeddings": RowwiseParallel(
                    input_layouts=Replicate(), output_layouts=Shard(0)
                ),
                "norm": SequenceParallel(),
            }
        else:
            root_plan = {
                "tok_embeddings": RowwiseParallel(
                    input_layouts=Replicate(), output_layouts=Replicate()
                ),
                "pos_embeddings": RowwiseParallel(
                    input_layouts=Replicate(), output_layouts=Replicate()
                ),
            }

        module_tp = parallelize_module(module, device_mesh, root_plan)
        # Parallelize the attention and feed forward submodules.
        for layer in module_tp.layers:
            layer_parallelize_plan = {}
            if use_seq_parallel:
                layer_parallelize_plan["attention"] = PrepareModuleInput(
                    input_layouts=Shard(1),
                    desired_input_layouts=Replicate(),
                )
                # apply sequence parallelism to the norm layers (nn.LayerNorm here)
                layer_parallelize_plan["attention_norm"] = SequenceParallel()
                layer_parallelize_plan["ffn_norm"] = SequenceParallel()
            layer_parallelize_plan["attention.wq"] = ColwiseParallel(
                use_local_output=local_output_for_attn
            )
            layer_parallelize_plan["attention.wk"] = ColwiseParallel(
                use_local_output=local_output_for_attn
            )
            layer_parallelize_plan["attention.wv"] = ColwiseParallel(
                use_local_output=local_output_for_attn
            )
            layer_parallelize_plan["attention.wo"] = (
                RowwiseParallel(output_layouts=Shard(1))
                if use_seq_parallel
                else RowwiseParallel()
            )

            layer_parallelize_plan["feed_forward.w1"] = (
                ColwiseParallel(input_layouts=Shard(1))
                if use_seq_parallel
                else ColwiseParallel()
            )
            layer_parallelize_plan["feed_forward.w2"] = (
                RowwiseParallel(output_layouts=Shard(1))
                if use_seq_parallel
                else RowwiseParallel()
            )

            parallelize_module(layer, device_mesh, layer_parallelize_plan)

        # Parallelize the output submodule. If weight tying is enabled, we need to
        # make sure output.weight is sharded consistently with tok_embeddings.weight,
        # at the cost of the all_reduce operation using RowwiseParallel.
        output_parallelize_plan = (
            ColwiseParallel(
                input_layouts=Shard(1),
                output_layouts=Replicate(),
            )
            if use_seq_parallel
            else ColwiseParallel(output_layouts=Replicate())
        )
        parallelize_module(module_tp.output, device_mesh, output_parallelize_plan)

        if local_output_for_attn:
            for layer in module_tp.layers:
                layer.attention.n_heads = (
                    module_tp.model_args.n_heads // device_mesh.size()
                )

        # Manually set output.weight so that parameters and gradients are shared.
        if module_tp.model_args.weight_tying:
            module_tp.output.weight = module_tp.tok_embeddings.weight

        return module_tp
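

# Example: parallelizing the toy Transformer (a minimal sketch, not executed here;
# assumes an already-initialized process group, e.g. inside a DTensorTestBase test
# below, and a 1-D device mesh spanning all ranks):
#
#   torch.manual_seed(0)  # keep weights and inputs identical across ranks
#   mesh = DeviceMesh(DEVICE_TYPE, list(range(NUM_DEVICES)))
#   model = Transformer(ModelArgs(dropout_p=0.0)).to(DEVICE_TYPE)
#   tp_model = Transformer.parallelize(model, mesh, use_seq_parallel=True)
#   tokens = torch.randint(0, model.model_args.vocab_size, (2, 8), device=DEVICE_TYPE)
#   logits = tp_model(tokens)  # (2, 8, vocab_size), replicated on every rank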


def skip_unless_torch_gpu(method: T) -> T:
    """
    Test decorator which skips the test unless there's a GPU available to torch.

    >>> # xdoctest: +SKIP
    >>> @skip_unless_torch_gpu
    >>> def test_some_method(self) -> None:
    >>>     ...
    """
    # The builtin @skip_if_no_gpu relies on os.environ['WORLD_SIZE'] being set.
    return cast(T, skip_if_lt_x_gpu(NUM_DEVICES)(method))


class DTensorTestBase(MultiProcessTestCase):
    @property
    def world_size(self) -> int:
        return NUM_DEVICES

    @property
    def backend(self) -> str:
        backend = "nccl" if self.device_type == "cuda" else "gloo"
        return backend

    def build_device_mesh(self) -> DeviceMesh:
        return DeviceMesh(self.device_type, list(range(self.world_size)))

    def init_pg(self) -> None:
        if "nccl" in self.backend and torch.cuda.device_count() < self.world_size:
            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)

        if self.backend not in ["nccl", "gloo", "mpi", "cpu:gloo,cuda:nccl"]:
            raise RuntimeError(f"Backend {self.backend} not supported!")

        dist.init_process_group(
            backend=self.backend,
            world_size=self.world_size,
            rank=self.rank,  # pyre-ignore[16]
            init_method=f"file://{self.file_name}",  # pyre-ignore[16]
        )

        # set device for nccl pg for collectives
        if "nccl" in self.backend:
            torch.cuda.set_device(self.rank)

    def destroy_pg(self) -> None:
        # Wait for all ranks to reach here before starting shutdown.
        # FIXME dist.barrier deadlocks with multiple threads and NCCL: https://github.com/pytorch/pytorch/issues/95895
        # dist.all_reduce(torch.zeros((1,), device="cuda" if torch.cuda.is_available() else "cpu"))
        # FIXME can't use the above all_reduce as it causes hangs on bionic and focal. It hangs:
        # test_dtensor.py -- DTensorMeshTest.test_dtensor_device_mesh_device_conversion
        dist.barrier()
        dist.destroy_process_group()

    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    # pyre-ignore[2]:
    def _test_op(self, mesh: DeviceMesh, op_call, *args, **kwargs) -> None:
        out = op_call(*args, **kwargs)
        dtc = DTensorConverter(mesh, args, kwargs)
        for d_args, d_kwargs in dtc:
            # pyre can't find assertTrue anymore?
            self.assertEqual(dtc.successful(), True)
            d_out = op_call(*d_args, **d_kwargs)
            self.assertEqual(d_out.full_tensor(), out)

    def run_subtests(self, *args, **kwargs):
        return run_subtests(self, *args, **kwargs)


TestFunc = Callable[[object], object]


# wrapper to initialize comms (process group)
def with_comms(func: TestFunc) -> TestFunc:
    assert func is not None

    @wraps(func)  # pyre-ignore[6]
    def wrapper(
        self, *args: Tuple[object], **kwargs: Dict[str, Any]  # type: ignore[misc]
    ) -> None:
        # if there are enough GPUs we use them, otherwise we fall back to CPU
        if (
            not torch.cuda.is_available()
            or torch.cuda.device_count() < self.world_size
        ):
            self.device_type = "cpu"
        else:
            self.device_type = DEVICE_TYPE

        self.init_pg()

        try:
            func(self, *args, **kwargs)  # type: ignore[misc]
        except Exception as e:
            dist.destroy_process_group()
            raise e

        self.destroy_pg()

    return wrapper
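

# Example: a test built on DTensorTestBase (a minimal sketch with hypothetical
# names; @with_comms picks the device type, calls init_pg() before the test body
# and destroy_pg() after it):
#
#   class MyDTensorTest(DTensorTestBase):
#       @with_comms
#       def test_mul(self):
#           mesh = self.build_device_mesh()
#           torch.manual_seed(0)  # identical inputs on every rank
#           x = torch.randn(4, 8, device=self.device_type)
#           y = torch.randn(4, 8, device=self.device_type)
#           self._test_op(mesh, torch.mul, x, y)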


class DTensorOpTestBase(MultiThreadedTestCase):
    @property
    def world_size(self) -> int:
        return NUM_DEVICES

    @property
    def device_type(self) -> str:
        return DEVICE_TYPE

    def build_device_mesh(self):
        return DeviceMesh(self.device_type, list(range(self.world_size)))

    def setUp(self) -> None:
        super().setUp()
        self._spawn_threads()


# This is a class for converting args/kwargs of an op into distributed args/kwargs
class DTensorConverter:
    def __init__(
        self,
        mesh: DeviceMesh,
        args: Tuple[object, ...],
        kwargs: Dict[str, object],
    ) -> None:
        self.hit = 0
        self.miss = 0
        self.mesh = mesh
        self.args = args
        self.kwargs = kwargs
        flatten_args, flatten_args_spec = tree_flatten(args)
        flatten_kwargs, flatten_kwargs_spec = tree_flatten(kwargs)

        self.flatten_args: List[object] = flatten_args
        self.flatten_args_spec: TreeSpec = flatten_args_spec
        self.flatten_kwargs: List[object] = flatten_kwargs
        self.flatten_kwargs_spec: TreeSpec = flatten_kwargs_spec

        choices_for_args = []
        for arg in self.flatten_args:
            if isinstance(arg, torch.Tensor):
                choices_for_args.append(self.gen_sharding_choices_for_arg(arg))

        for arg in self.flatten_kwargs:
            if isinstance(arg, torch.Tensor):
                choices_for_args.append(self.gen_sharding_choices_for_arg(arg))

        self.sharding_combs: Iterator[Sequence[Placement]] = iter(
            itertools.product(*choices_for_args)
        )

    def successful(self) -> bool:
        return self.hit > 0 and self.miss == 0

    def is_supported_tensor(self, t: torch.Tensor) -> bool:
        # TODO: dist tensor needs to support quantized and sparse
        # tensors; quantized tensors might be relatively easy, but
        # sparse tensors have special layouts that we need to possibly
        # deal with. Until we are clear about them, we don't officially
        # support them.
        return not any(
            [
                t.is_sparse_csr,
                t.is_sparse,
                t.is_mkldnn,
                t.is_quantized,
                t.is_nested,
                torch._is_functional_tensor(t),
                t.is_neg(),
                t.is_conj(),
                t.device.type in ("lazy", "meta"),
                # We need a way to test if a tensor is batched but there
                # is no official API to do it
                # torch._C._is_batched(t),
            ]
        )

    def gen_sharding_choices_for_arg(self, arg: torch.Tensor) -> Sequence[Placement]:
        mesh_size = self.mesh.size()
        sharding_choices: List[Placement] = [Replicate()]
        # c10d collectives do not support bool tensors, so
        # for bool tensors we treat them as replicated
        if arg.dtype != torch.bool:
            # only generating choices with: replicate, or sharding
            # evenly on a dimension that could be sharded
            sharding_choices = sharding_choices + [
                Shard(i)
                for i, s in enumerate(arg.shape)
                if s > 1 and s % mesh_size == 0
            ]
        # TODO: add multi mesh choices
        # all_choices = itertools.product(
        #     *(self.mesh.ndim * [sharding_choices])
        # )
        return sharding_choices

    def __iter__(self) -> "DTensorConverter":
        return self

    def __next__(self) -> Tuple[Tuple[object, ...], Dict[str, object]]:
        try:
            next_sharding_choices = next(self.sharding_combs)
            idx = 0

            new_args: List[object] = []
            for arg in self.flatten_args:
                if isinstance(arg, torch.Tensor):
                    new_args.append(
                        self.to_dist_tensor(
                            arg, self.mesh, [next_sharding_choices[idx]]
                        )
                    )
                    idx += 1
                else:
                    new_args.append(arg)

            new_kwargs: List[object] = []
            for arg in self.flatten_kwargs:
                if isinstance(arg, torch.Tensor):
                    new_kwargs.append(
                        self.to_dist_tensor(
                            arg, self.mesh, [next_sharding_choices[idx]]
                        )
                    )
                    idx += 1
                else:
                    new_kwargs.append(arg)

            return (
                tree_unflatten(new_args, self.flatten_args_spec),
                tree_unflatten(new_kwargs, self.flatten_kwargs_spec),
            )
        except StopIteration as e:
            raise StopIteration from e

    def to_dist_tensor(
        self, t: torch.Tensor, mesh: DeviceMesh, placements: List[Placement]
    ) -> torch.Tensor:
        if type(t) is torch.Tensor or type(t) is nn.Parameter:
            if self.is_supported_tensor(t):
                self.hit += 1
                if t.ndim == 0:
                    # scalar tensors by default will be replicated
                    r = distribute_tensor(t, mesh, [Replicate()] * mesh.ndim)
                else:
                    # distribute non-scalar tensors
                    r = distribute_tensor(t, mesh, placements)
                if type(t) is nn.Parameter:
                    r = nn.Parameter(  # type: ignore[assignment]
                        r, requires_grad=r.requires_grad
                    )
                return r
            else:
                self.miss += 1
                return t
        elif torch.overrides.is_tensor_like(t):
            # Blindly converting tensor subclasses to dist tensor can cause
            # unpredictable problems, so we explicitly disable this conversion
            # for now (i.e. we don't support DTensor holding tensor subclasses
            # until there's a strong reason later).
            self.miss += 1
            return t
        else:
            raise RuntimeError(f"Trying to convert to DTensor, but got {type(t)}")
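

# Example: driving DTensorConverter directly (a minimal sketch, assuming an
# already-initialized process group; _test_op above wraps this same pattern):
#
#   torch.manual_seed(0)  # identical inputs on every rank
#   mesh = DeviceMesh(DEVICE_TYPE, list(range(NUM_DEVICES)))
#   args, kwargs = (torch.randn(4, 8, device=DEVICE_TYPE),
#                   torch.randn(4, 8, device=DEVICE_TYPE)), {}
#   dtc = DTensorConverter(mesh, args, kwargs)
#   for d_args, d_kwargs in dtc:
#       # every sharding combination should reproduce the eager result
#       d_out = torch.add(*d_args, **d_kwargs)
#       assert torch.equal(d_out.full_tensor(), torch.add(*args, **kwargs))
#   assert dtc.successful()  # every tensor arg was convertible to DTensor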