1# mypy: allow-untyped-defs 2# Copyright (c) Meta Platforms, Inc. and affiliates 3from abc import ABC, abstractmethod 4from typing import Optional, Union, Tuple, Dict, Any 5from functools import partial 6 7import torch 8import torch.nn as nn 9from torch.distributed._tensor import DeviceMesh, DTensor, Placement, Replicate, Shard, distribute_tensor, distribute_module 10 11 12__all__ = [ 13 "ParallelStyle", 14 "RowwiseParallel", 15 "SequenceParallel", 16 "ColwiseParallel", 17 "PrepareModuleInput", 18 "PrepareModuleOutput", 19] 20 21 22class ParallelStyle(ABC): 23 """ 24 The parallel style contract defines how the module or submodule should be parallelized. 25 26 It only defines the ``apply`` method for ``parallelize_module`` to use, this allows maximum 27 flexibility for different kind of style implementations. 28 """ 29 30 @abstractmethod 31 def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: 32 ... 33 34 35class ColwiseParallel(ParallelStyle): 36 """ 37 Partition a compatible nn.Module in a column-wise fashion. Currently supports nn.Linear and nn.Embedding. 38 Users can compose it together with RowwiseParallel to achieve the sharding of more complicated modules. 39 (i.e. MLP, Attention) 40 41 Keyword Args: 42 input_layouts (Placement, optional): 43 The DTensor layout of input tensor for the nn.Module, this is used to annotate the input tensor to 44 become a DTensor. If not specified, we assume the input tensor to be replicated. 45 output_layouts (Placement, optional): 46 The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module 47 with the user desired layout. If not specified, the output tensor is sharded on the last dimension. 48 use_local_output (bool, optional): 49 Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True. 50 Returns: 51 A :class:`ParallelStyle` object that represents Colwise sharding of the nn.Module. 52 53 Example:: 54 >>> # xdoctest: +SKIP(failing) 55 >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel 56 >>> from torch.distributed.device_mesh import init_device_mesh 57 >>> ... 58 >>> m = Model(...) # m is a nn.Module that contains a "w1" nn.Linear submodule 59 >>> tp_mesh = init_device_mesh("cuda", (8,)) 60 >>> 61 >>> # By default, the input of the "w1" Linear will be converted to Replicated DTensor 62 >>> # and the output of "w1" will return :class:`torch.Tensor` that shards on the last dim. 63 >>> 64 >>> sharded_mod = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel()}) 65 >>> ... 66 67 .. note:: By default ``ColwiseParallel`` output is sharded on the last dimension if the ``output_layouts`` not 68 specified, if there're operators that require specific tensor shape (i.e. before the paired ``RowwiseParallel``), 69 keep in mind that if the output is sharded the operator might need to be adjusted to the sharded size. 70 """ 71 72 def __init__( 73 self, 74 *, 75 input_layouts: Optional[Placement] = None, 76 output_layouts: Optional[Placement] = None, 77 use_local_output: bool = True 78 ): 79 super().__init__() 80 self.input_layouts = (input_layouts or Replicate(), ) 81 self.output_layouts = (output_layouts or Shard(-1), ) 82 # colwise linear runtime sharding (desired sharding): 83 # 1. requires replicate input 84 # 2. shard output on last dim 85 self.desired_input_layouts = (Replicate(), ) 86 self.use_local_output = use_local_output 87 88 @staticmethod 89 def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): 90 # TODO: figure out dynamo support for instance method and switch this to instance method 91 92 # annotate module input placements/sharding with input_layouts 93 input_tensor = inputs[0] 94 if not isinstance(input_tensor, DTensor): 95 input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False) 96 97 # transform the input layouts to the desired layouts of ColwiseParallel 98 if input_layouts != desired_input_layouts: 99 input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True) 100 return input_tensor 101 102 def _partition_linear_fn(self, name, module, device_mesh): 103 # colwise shard weight/bias to Shard(0), weight be Shard(0) 104 # means Colwise as Linear is input * weight^T + bias, where 105 # weight would become Shard(1) 106 for name, param in module.named_parameters(): 107 dist_param = nn.Parameter( 108 distribute_tensor(param, device_mesh, [Shard(0)]) 109 ) 110 module.register_parameter(name, dist_param) 111 112 def _partition_embedding_fn(self, name, module, device_mesh): 113 # colwise shard embedding.weight is straight forward as Shard(1) 114 for name, param in module.named_parameters(): 115 dist_param = nn.Parameter( 116 distribute_tensor(param, device_mesh, [Shard(1)]) 117 ) 118 module.register_parameter(name, dist_param) 119 120 @staticmethod 121 def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): 122 # outputs is a shard on last dimension DTensor, i.e. Shard(-1) 123 if outputs.placements != output_layouts: 124 outputs = outputs.redistribute(placements=output_layouts, async_op=True) 125 # back to local tensor 126 return outputs.to_local() if use_local_output else outputs 127 128 def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: 129 if isinstance(module, nn.Linear): 130 partition_fn = self._partition_linear_fn 131 elif isinstance(module, nn.Embedding): 132 partition_fn = self._partition_embedding_fn 133 else: 134 raise NotImplementedError("ColwiseParallel currently only support nn.Linear and nn.Embedding!") 135 136 return distribute_module( 137 module, 138 device_mesh, 139 partition_fn, 140 partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts), 141 partial(self._prepare_output_fn, self.output_layouts, self.use_local_output), 142 ) 143 144 145class RowwiseParallel(ParallelStyle): 146 """ 147 Partition a compatible nn.Module in a row-wise fashion. Currently supports nn.Linear and nn.Embedding. 148 Users can compose it with ColwiseParallel to achieve the sharding of more complicated modules. 149 (i.e. MLP, Attention) 150 151 Keyword Args: 152 input_layouts (Placement, optional): 153 The DTensor layout of input tensor for the nn.Module, this is used to annotate the input tensor to 154 become a DTensor. If not specified, we assume the input tensor to be sharded on the last dimension. 155 output_layouts (Placement, optional): 156 The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module 157 with the user desired layout. If not specified, the output tensor is replicated. 158 use_local_output (bool, optional): 159 Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True. 160 Returns: 161 A :class:`ParallelStyle` object that represents Rowwise sharding of the nn.Module. 162 163 Example:: 164 >>> # xdoctest: +SKIP(failing) 165 >>> from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel 166 >>> from torch.distributed.device_mesh import init_device_mesh 167 >>> ... 168 >>> m = Model(...) # m is a nn.Module that contains a "w2" nn.Linear submodule 169 >>> tp_mesh = init_device_mesh("cuda", (8,)) 170 >>> 171 >>> # By default, the input of the "w2" Linear will be converted to DTensor that shards on the last dim 172 >>> # and the output of "w2" will return a replicated :class:`torch.Tensor`. 173 >>> 174 >>> sharded_mod = parallelize_module(m, tp_mesh, {"w2": RowwiseParallel()}), 175 >>> ... 176 """ 177 178 def __init__( 179 self, 180 *, 181 input_layouts: Optional[Placement] = None, 182 output_layouts: Optional[Placement] = None, 183 use_local_output: bool = True 184 ): 185 super().__init__() 186 self.input_layouts = (input_layouts or Shard(-1), ) 187 self.output_layouts = (output_layouts or Replicate(), ) 188 self.use_local_output = use_local_output 189 190 @staticmethod 191 def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): 192 input_tensor = inputs[0] 193 if not isinstance(input_tensor, DTensor): 194 input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False) 195 196 if input_layouts != desired_input_layouts: 197 input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True) 198 return input_tensor 199 200 def _partition_linear_fn(self, name, module, device_mesh): 201 # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1) 202 # means Rowwise as nn.Linear is input * weight^T + bias, where 203 # weight would become Shard(0) 204 module.register_parameter("weight", nn.Parameter( 205 distribute_tensor(module.weight, device_mesh, [Shard(1)]) 206 )) 207 if module.bias is not None: 208 module.register_parameter("bias", nn.Parameter( 209 distribute_tensor(module.bias, device_mesh, [Replicate()]) 210 )) 211 212 def _partition_embedding_fn(self, name, module, device_mesh): 213 # rowwise shard embedding.weight is Shard(0) 214 for name, param in module.named_parameters(): 215 dist_param = nn.Parameter( 216 distribute_tensor(param, device_mesh, [Shard(0)]) 217 ) 218 module.register_parameter(name, dist_param) 219 220 @staticmethod 221 def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): 222 # Rowwise sharding produces partial output, depending on output layouts: 223 # 1. to replicate -> allreduce 224 # 2. to shard -> reduce_scatter 225 if outputs.placements != output_layouts: 226 outputs = outputs.redistribute(placements=output_layouts, async_op=True) 227 # back to local tensor if use_local_output is True 228 return outputs.to_local() if use_local_output else outputs 229 230 def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: 231 if isinstance(module, nn.Linear): 232 partition_fn = self._partition_linear_fn 233 # rowwise linear runtime sharding requires input tensor shard on last dim 234 self.desired_input_layouts: Tuple[Placement, ...] = (Shard(-1), ) 235 elif isinstance(module, nn.Embedding): 236 partition_fn = self._partition_embedding_fn 237 # rowwise embedding runtime sharding requires input tensor replicated 238 self.desired_input_layouts = (Replicate(), ) 239 else: 240 raise NotImplementedError("RowwiseParallel currently only support nn.Linear and nn.Embedding!") 241 242 return distribute_module( 243 module, 244 device_mesh, 245 partition_fn, 246 partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts), 247 partial(self._prepare_output_fn, self.output_layouts, self.use_local_output), 248 ) 249 250 251class SequenceParallel(ParallelStyle): 252 """ 253 SequenceParallel replicates a compatible ``nn.Module`` parameters and runs the sharded computation with 254 input sharded on the sequence dimension. This currently supports ``nn.LayerNorm``, ``nn.Dropout``, and the 255 `RMSNorm python implementation <https://github.com/facebookresearch/llama/blob/main/llama/model.py#L34>`__ 256 257 This style implements the operation that is described in the paper 258 `Reducing Activation Recomputation in Large Transformer Models <https://arxiv.org/abs/2205.05198>`__ 259 260 Both the input and output of the ``nn.Module`` will be sharded on the sequence dimension. 261 262 Keyword Args: 263 sequence_dim (int, optional): 264 The sequence dimension of the input tensor for the ``nn.Module``, this is used to annotate the input tensor to 265 become a DTensor that is sharded on the sequence dimension, default: 1. 266 use_local_output (bool, optional): 267 Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: False. 268 Returns: 269 A :class:`ParallelStyle` object that represents Sequence Parallel of the ``nn.Module``. 270 271 Example:: 272 >>> # xdoctest: +SKIP(failing) 273 >>> from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel 274 >>> from torch.distributed.device_mesh import init_device_mesh 275 >>> ... 276 >>> m = Model(...) # m is a nn.Module that contains a "norm" nn.LayerNorm submodule 277 >>> tp_mesh = init_device_mesh("cuda", (8,)) 278 >>> 279 >>> # By default, the input of the "norm" will be converted to DTensor that shards on the sequence dim 280 >>> # and the output of "norm" will return a sharded on sequence dimension :class:`DTensor`. 281 >>> 282 >>> sharded_mod = parallelize_module(m, tp_mesh, {"norm": SequenceParallel()}), 283 >>> ... 284 285 .. note:: SequenceParallel style assumes ones initialization if there are weights in the nn.Module (i.e. 286 ``nn.LayerNorm`` or ``RMSNorm``, and they by default have ones initialization). If you have custom 287 inits for the weights on those modules, you need to broadcast the weights before/after parallelizing 288 to ensure that they are replicated. 289 """ 290 def __init__( 291 self, 292 *, 293 sequence_dim: int = 1, 294 use_local_output: bool = False 295 ): 296 super().__init__() 297 self.sequence_dim = sequence_dim 298 self.use_local_output = use_local_output 299 300 def _replicate_module_fn(self, name: str, module: nn.Module, device_mesh: DeviceMesh): 301 for p_name, param in module.named_parameters(): 302 # simple replication with fixed ones_ init from LayerNorm/RMSNorm, which allow 303 # us to simply just use from_local 304 replicated_param = torch.nn.Parameter( 305 DTensor.from_local(param, device_mesh, [Replicate()], run_check=False) 306 ) 307 module.register_parameter(p_name, replicated_param) 308 309 @staticmethod 310 def _prepare_input_fn(sequence_dim, mod, inputs, device_mesh): 311 input_tensor = inputs[0] 312 if isinstance(input_tensor, DTensor): 313 return inputs 314 elif isinstance(input_tensor, torch.Tensor): 315 return DTensor.from_local(input_tensor, device_mesh, [Shard(sequence_dim)], run_check=False) 316 else: 317 raise ValueError(f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}") 318 319 @staticmethod 320 def _prepare_output_fn(use_local_output, mod, outputs, device_mesh): 321 return outputs.to_local() if use_local_output else outputs 322 323 def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: 324 return distribute_module( 325 module, 326 device_mesh, 327 self._replicate_module_fn, 328 partial(self._prepare_input_fn, self.sequence_dim), 329 partial(self._prepare_output_fn, self.use_local_output), 330 ) 331 332 333class PrepareModuleInput(ParallelStyle): 334 """ 335 Configure the nn.Module's inputs to convert the input tensors of the nn.Module to DTensors at runtime according to 336 ``input_layouts``, and perform layout redistribution according to the ``desired_input_layouts``. 337 338 Keyword Args: 339 input_layouts (Union[Placement, Tuple[Optional[Placement]]]): 340 The DTensor layouts of input tensors for the nn.Module, this is used to convert the input tensors to 341 DTensors. If some inputs are not torch.Tensor or no need to convert to DTensors, ``None`` need to be specified 342 as a placeholder. default: None. 343 desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]): 344 The desired DTensor layout of input tensors for the nn.Module, this is used to ensure the inputs of the nn.Module 345 have the desired DTensor layouts. This argument needs to have the same length with ``input_layouts``. default: None. 346 input_kwarg_layouts (Dict[str, Placement]): 347 The DTensor layouts of input kwargs for the nn.Module, this is used to convert the input kwarg tensors to DTensors. 348 default: None 349 desired_input_kwarg_layouts: (Dict[str, Placement]): 350 The desired DTensor layout of input kwargs for the nn.Module, this is used to ensure the inputs of the nn.Module 351 have the desired DTensor layouts. default: None. 352 use_local_output (bool, optional): 353 Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module inputs, default: False. 354 Returns: 355 A :class:`ParallelStyle` object that prepares the sharding layouts of the nn.Module's inputs. 356 357 Example:: 358 >>> # xdoctest: +SKIP(failing) 359 >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput 360 >>> from torch.distributed.device_mesh import init_device_mesh 361 >>> ... 362 >>> block = TransformerBlock(...) # block is a nn.Module that contains an "attn" Attention submodule 363 >>> tp_mesh = init_device_mesh("cuda", (8,)) 364 >>> 365 >>> # According to the style specified below, the first input of attn will be annotated to Sharded DTensor 366 >>> # and then redistributed to Replicated DTensor. 367 >>> parallelize_module( 368 >>> block, # this can be a submodule or module 369 >>> tp_mesh, 370 >>> parallelize_plan={ 371 >>> "attn": PrepareModuleInput( 372 >>> input_layouts=(Shard(0), None, None, ...), 373 >>> desired_input_layouts=(Replicate(), None, None, ...) 374 >>> ), 375 >>> } 376 >>> ) 377 """ 378 379 def __init__( 380 self, 381 *, 382 input_layouts: Optional[Union[Placement, Tuple[Optional[Placement]]]] = None, 383 desired_input_layouts: Optional[Union[Placement, Tuple[Optional[Placement]]]] = None, 384 input_kwarg_layouts: Optional[Dict[str, Placement]] = None, 385 desired_input_kwarg_layouts: Optional[Dict[str, Placement]] = None, 386 use_local_output: bool = False 387 ): 388 self.input_layouts = (input_layouts,) if isinstance(input_layouts, Placement) else input_layouts 389 self.desired_input_layouts = \ 390 (desired_input_layouts,) if isinstance(desired_input_layouts, Placement) else desired_input_layouts 391 self.use_local_output = use_local_output 392 if self.input_layouts is not None: 393 assert self.desired_input_layouts is not None, "desired module inputs should not be None!" 394 assert len(self.input_layouts) == len(self.desired_input_layouts), \ 395 "input_layouts and desired_input_layouts should have same length!" 396 self.with_kwargs = input_kwarg_layouts is not None 397 self.input_kwarg_layouts = input_kwarg_layouts or {} 398 self.desired_input_kwarg_layouts = desired_input_kwarg_layouts or {} 399 if self.with_kwargs: 400 assert len(self.input_kwarg_layouts) == len(self.desired_input_kwarg_layouts), \ 401 "input_kwarg_layouts and desired_input_kwarg_layouts should have same length!" 402 403 def _prepare_input_arg( 404 self, 405 input: Any, 406 mesh: DeviceMesh, 407 input_layout: Optional[Placement], 408 desired_layout: Optional[Placement] 409 ): 410 if input_layout is not None: 411 if isinstance(input, DTensor): 412 # TODO: re-enable the check once we fix the compile path 413 # assert inp.placements[0] == input_layout 414 dt_inp = input 415 else: 416 assert isinstance(input, torch.Tensor), "expecting input to be a torch.Tensor!" 417 dt_inp = DTensor.from_local(input, mesh, (input_layout,), run_check=False) 418 419 if desired_layout is not None and input_layout != desired_layout: 420 dt_inp = dt_inp.redistribute(placements=(desired_layout,)) 421 422 return dt_inp.to_local() if self.use_local_output else dt_inp 423 else: 424 return input 425 426 def _prepare_input_fn(self, inputs, device_mesh): 427 if self.input_layouts is None: 428 return inputs 429 prepared_inputs = [] 430 if not isinstance(inputs, tuple): 431 inputs = (inputs,) 432 if len(inputs) != len(self.input_layouts): 433 raise ValueError("module inputs and input_layouts should have same length!") 434 435 assert self.desired_input_layouts is not None, "desired module inputs should not be None!" 436 for inp, input_layout, desired_layout in zip(inputs, self.input_layouts, self.desired_input_layouts): 437 prepared_inputs.append(self._prepare_input_arg(inp, device_mesh, input_layout, desired_layout)) 438 return tuple(prepared_inputs) 439 440 def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh): 441 prepared_arg_inputs = self._prepare_input_fn(inputs, device_mesh) 442 prepared_kwarg_inputs = {} 443 for kwarg_key in kwarg_inputs.keys(): 444 kwarg_val = kwarg_inputs[kwarg_key] 445 input_layout = self.input_kwarg_layouts.get(kwarg_key) 446 desired_input_layout = self.desired_input_kwarg_layouts.get(kwarg_key) 447 448 prepared_kwarg_inputs[kwarg_key] = self._prepare_input_arg(kwarg_val, device_mesh, input_layout, desired_input_layout) 449 450 return (prepared_arg_inputs, prepared_kwarg_inputs) 451 452 def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: 453 if self.with_kwargs: 454 module.register_forward_pre_hook( 455 lambda _, inputs, kwargs: self._prepare_input_kwarg_fn(inputs, kwargs, device_mesh), 456 with_kwargs=True 457 ) # type: ignore[misc] 458 else: 459 module.register_forward_pre_hook(lambda _, inputs: self._prepare_input_fn(inputs, device_mesh)) # type: ignore[misc, call-arg] 460 return module 461 462 463class PrepareModuleOutput(ParallelStyle): 464 """ 465 Configure the nn.Module's outputs to convert the output tensors of the nn.Module to DTensors at runtime according to 466 ``output_layouts``, and perform layout redistribution according to the ``desired_output_layouts``. 467 468 Keyword Args: 469 output_layouts (Union[Placement, Tuple[Placement]]): 470 The DTensor layouts of output tensors for the nn.Module, this is used to convert the output tensors to 471 DTensors if they are :class:`torch.Tensor`. If some outputs are not torch.Tensor or no need to convert to DTensors, 472 ``None`` need to be specified as a placeholder. 473 desired_output_layouts (Union[Placement, Tuple[Placement]]): 474 The desired DTensor layouts of output tensors for the nn.Module, this is used to ensure the outputs of the nn.Module 475 have the desired DTensor layouts. 476 use_local_output (bool, optional): 477 Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module outputs, default: True. 478 Returns: 479 A ParallelStyle object that prepares the sharding layouts of the nn.Module's outputs. 480 481 Example:: 482 >>> # xdoctest: +SKIP(failing) 483 >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleOutput 484 >>> from torch.distributed.device_mesh import init_device_mesh 485 >>> ... 486 >>> block = TransformerBlock(...) # block is a nn.Module that contains an "attn" Attention submodule 487 >>> tp_mesh = init_device_mesh("cuda", (8,)) 488 >>> 489 >>> # According to the style specified below, the output of the TransformerBlock will be converted to Replicated DTensor 490 >>> # and then redistributed to Sharded DTensor. 491 >>> parallelize_module( 492 >>> block, # this can be a submodule or module 493 >>> tp_mesh, 494 >>> parallelize_plan = PrepareModuleOutput( 495 >>> output_layouts=Replicate(), 496 >>> desired_output_layouts=Shard(0) 497 >>> ) 498 >>> ) 499 """ 500 def __init__( 501 self, 502 *, 503 output_layouts: Union[Placement, Tuple[Placement]], 504 desired_output_layouts: Union[Placement, Tuple[Placement]], 505 use_local_output: bool = True 506 ): 507 self.output_layouts = (output_layouts,) if isinstance(output_layouts, Placement) else output_layouts 508 self.desired_output_layouts = \ 509 (desired_output_layouts,) if isinstance(desired_output_layouts, Placement) else desired_output_layouts 510 self.use_local_output = use_local_output 511 assert len(self.output_layouts) == len(self.desired_output_layouts), \ 512 "output_layouts and desired_output_layouts should have same length!" 513 514 def _prepare_out_fn(self, outputs, device_mesh): 515 prepared_outputs = [] 516 if not isinstance(outputs, tuple): 517 outputs = (outputs,) 518 if len(outputs) != len(self.output_layouts): 519 raise ValueError("module outputs and output_layouts should have same length!") 520 for out, out_layout, desired_out_layout in zip(outputs, self.output_layouts, self.desired_output_layouts): 521 if out_layout is not None: 522 if isinstance(out, DTensor): 523 # TODO: re-enable the check once we fix the compile path 524 # assert out.placements[0] == out_layout 525 dt_out = out 526 else: 527 dt_out = DTensor.from_local(out, device_mesh, (out_layout,), run_check=False) 528 529 if out_layout != desired_out_layout: 530 dt_out = dt_out.redistribute(placements=(desired_out_layout,)) 531 prepared_outputs.append(dt_out.to_local() if self.use_local_output else dt_out) 532 else: 533 prepared_outputs.append(out) 534 if len(prepared_outputs) == 1: 535 return prepared_outputs[0] 536 else: 537 return tuple(prepared_outputs) 538 539 def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: 540 module.register_forward_hook(lambda _, inputs, outputs: self._prepare_out_fn(outputs, device_mesh)) # type: ignore[misc, call-arg] 541 return module 542