# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from abc import ABC, abstractmethod
from typing import Optional, Union, Tuple, Dict, Any
from functools import partial

import torch
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh, DTensor, Placement, Replicate, Shard, distribute_tensor, distribute_module


__all__ = [
    "ParallelStyle",
    "RowwiseParallel",
    "SequenceParallel",
    "ColwiseParallel",
    "PrepareModuleInput",
    "PrepareModuleOutput",
]


class ParallelStyle(ABC):
    """
    The parallel style contract defines how the module or submodule should be parallelized.

    It only defines the ``_apply`` method for ``parallelize_module`` to use, which allows maximum
    flexibility for different kinds of style implementations.
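
    For illustration only, a minimal custom style could replicate every parameter of the module it
    is applied to (``MyReplicateStyle`` below is a hypothetical name, not part of this module)::
        >>> # xdoctest: +SKIP(illustrative sketch)
        >>> class MyReplicateStyle(ParallelStyle):
        >>>     def _apply(self, module, device_mesh):
        >>>         def replicate_params(name, mod, mesh):
        >>>             # shard nothing; just turn every local parameter into a replicated DTensor
        >>>             for p_name, param in mod.named_parameters(recurse=False):
        >>>                 mod.register_parameter(
        >>>                     p_name, nn.Parameter(distribute_tensor(param, mesh, [Replicate()]))
        >>>                 )
        >>>         return distribute_module(module, device_mesh, replicate_params)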
28    """
29
30    @abstractmethod
31    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
32        ...
33
34
class ColwiseParallel(ParallelStyle):
    """
    Partition a compatible nn.Module in a column-wise fashion. Currently supports nn.Linear and nn.Embedding.
    Users can compose it together with RowwiseParallel to achieve the sharding of more complicated modules
    (e.g. MLP, Attention).

    Keyword Args:
        input_layouts (Placement, optional):
            The DTensor layout of the input tensor for the nn.Module. This is used to annotate the input tensor so
            that it becomes a DTensor. If not specified, we assume the input tensor is replicated.
        output_layouts (Placement, optional):
            The DTensor layout of the output of the nn.Module. This is used to ensure the output of the nn.Module
            has the user-desired layout. If not specified, the output tensor is sharded on the last dimension.
        use_local_output (bool, optional):
            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True.
    Returns:
        A :class:`ParallelStyle` object that represents Colwise sharding of the nn.Module.

    Example::
        >>> # xdoctest: +SKIP(failing)
        >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>> ...
        >>> m = Model(...)  # m is a nn.Module that contains a "w1" nn.Linear submodule
        >>> tp_mesh = init_device_mesh("cuda", (8,))
        >>>
        >>> # By default, the input of the "w1" Linear will be converted to a Replicated DTensor
        >>> # and the output of "w1" will be a :class:`torch.Tensor` sharded on the last dim.
        >>>
        >>> sharded_mod = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel()})
        >>> ...

    .. note:: By default the ``ColwiseParallel`` output is sharded on the last dimension if ``output_layouts`` is not
        specified. If there are operators that require a specific tensor shape (i.e. before the paired ``RowwiseParallel``),
        keep in mind that if the output is sharded, the operator might need to be adjusted to the sharded size.
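
    A common composition applies ``ColwiseParallel`` to the first linear of an MLP and
    ``RowwiseParallel`` to the second, so the intermediate activation stays sharded on the last
    dimension between the two matmuls. A sketch, assuming a hypothetical ``FeedForward`` module
    with ``w1``/``w2`` nn.Linear submodules::
        >>> # xdoctest: +SKIP(illustrative sketch)
        >>> mlp = FeedForward(...)  # hypothetical module with "w1" and "w2" nn.Linear submodules
        >>> sharded_mlp = parallelize_module(
        >>>     mlp,
        >>>     tp_mesh,
        >>>     {"w1": ColwiseParallel(), "w2": RowwiseParallel()},
        >>> )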
70    """
71
    def __init__(
        self,
        *,
        input_layouts: Optional[Placement] = None,
        output_layouts: Optional[Placement] = None,
        use_local_output: bool = True
    ):
        super().__init__()
        self.input_layouts = (input_layouts or Replicate(), )
        self.output_layouts = (output_layouts or Shard(-1), )
        # colwise linear runtime sharding (desired sharding):
        # 1. requires replicate input
        # 2. shard output on last dim
        self.desired_input_layouts = (Replicate(), )
        self.use_local_output = use_local_output

    @staticmethod
    def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
        # TODO: figure out dynamo support for instance method and switch this to instance method

        # annotate module input placements/sharding with input_layouts
        input_tensor = inputs[0]
        if not isinstance(input_tensor, DTensor):
            input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)

        # transform the input layouts to the desired layouts of ColwiseParallel
        if input_layouts != desired_input_layouts:
            input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
        return input_tensor

    def _partition_linear_fn(self, name, module, device_mesh):
        # Colwise shards weight/bias to Shard(0). Since nn.Linear computes
        # input @ weight^T + bias, sharding weight on dim 0 means weight^T is
        # effectively sharded on dim 1, i.e. column-wise.
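        # For example (hypothetical shapes): nn.Linear(4, 8) has a weight of shape
        # (8, 4); Shard(0) on a 2-device mesh leaves each rank a (4, 4) local weight,
        # i.e. half of the output features, so the module output is Shard(-1).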
        for name, param in module.named_parameters():
            dist_param = nn.Parameter(
                distribute_tensor(param, device_mesh, [Shard(0)])
            )
            module.register_parameter(name, dist_param)

    def _partition_embedding_fn(self, name, module, device_mesh):
        # colwise sharding of embedding.weight is straightforward: Shard(1)
        for name, param in module.named_parameters():
            dist_param = nn.Parameter(
                distribute_tensor(param, device_mesh, [Shard(1)])
            )
            module.register_parameter(name, dist_param)

    @staticmethod
    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
        # outputs is a DTensor sharded on the last dimension, i.e. Shard(-1)
        if outputs.placements != output_layouts:
            outputs = outputs.redistribute(placements=output_layouts, async_op=True)
        # back to local tensor
        return outputs.to_local() if use_local_output else outputs

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        if isinstance(module, nn.Linear):
            partition_fn = self._partition_linear_fn
        elif isinstance(module, nn.Embedding):
            partition_fn = self._partition_embedding_fn
        else:
            raise NotImplementedError("ColwiseParallel currently only supports nn.Linear and nn.Embedding!")

        return distribute_module(
            module,
            device_mesh,
            partition_fn,
            partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts),
            partial(self._prepare_output_fn, self.output_layouts, self.use_local_output),
        )


class RowwiseParallel(ParallelStyle):
    """
    Partition a compatible nn.Module in a row-wise fashion. Currently supports nn.Linear and nn.Embedding.
    Users can compose it with ColwiseParallel to achieve the sharding of more complicated modules
    (e.g. MLP, Attention).

    Keyword Args:
        input_layouts (Placement, optional):
            The DTensor layout of the input tensor for the nn.Module. This is used to annotate the input tensor so
            that it becomes a DTensor. If not specified, we assume the input tensor is sharded on the last dimension.
        output_layouts (Placement, optional):
            The DTensor layout of the output of the nn.Module. This is used to ensure the output of the nn.Module
            has the user-desired layout. If not specified, the output tensor is replicated.
        use_local_output (bool, optional):
            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True.
    Returns:
        A :class:`ParallelStyle` object that represents Rowwise sharding of the nn.Module.

    Example::
        >>> # xdoctest: +SKIP(failing)
        >>> from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>> ...
        >>> m = Model(...)  # m is a nn.Module that contains a "w2" nn.Linear submodule
        >>> tp_mesh = init_device_mesh("cuda", (8,))
        >>>
        >>> # By default, the input of the "w2" Linear will be converted to a DTensor sharded on the last dim
        >>> # and the output of "w2" will return a replicated :class:`torch.Tensor`.
        >>>
        >>> sharded_mod = parallelize_module(m, tp_mesh, {"w2": RowwiseParallel()})
        >>> ...
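
    ``RowwiseParallel`` also accepts an ``nn.Embedding``, in which case the weight is sharded on
    dim 0 (the vocabulary dimension) and the input indices are expected to be replicated. A sketch,
    assuming a hypothetical "tok_embeddings" nn.Embedding submodule::
        >>> # xdoctest: +SKIP(illustrative sketch)
        >>> sharded_mod = parallelize_module(
        >>>     m,
        >>>     tp_mesh,
        >>>     {"tok_embeddings": RowwiseParallel(input_layouts=Replicate())},
        >>> )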
176    """
177
    def __init__(
        self,
        *,
        input_layouts: Optional[Placement] = None,
        output_layouts: Optional[Placement] = None,
        use_local_output: bool = True
    ):
        super().__init__()
        self.input_layouts = (input_layouts or Shard(-1), )
        self.output_layouts = (output_layouts or Replicate(), )
        self.use_local_output = use_local_output

    @staticmethod
    def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
        input_tensor = inputs[0]
        if not isinstance(input_tensor, DTensor):
            input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)

        if input_layouts != desired_input_layouts:
            input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
        return input_tensor

    def _partition_linear_fn(self, name, module, device_mesh):
        # Rowwise shards weight to Shard(1) and bias to Replicate(). Since nn.Linear
        # computes input @ weight^T + bias, sharding weight on dim 1 means weight^T is
        # effectively sharded on dim 0, i.e. row-wise.
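        # For example (hypothetical shapes): nn.Linear(4, 8) has a weight of shape
        # (8, 4); Shard(1) on a 2-device mesh leaves each rank an (8, 2) local weight,
        # i.e. half of the input features, so each rank produces a partial sum.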
        module.register_parameter("weight", nn.Parameter(
            distribute_tensor(module.weight, device_mesh, [Shard(1)])
        ))
        if module.bias is not None:
            module.register_parameter("bias", nn.Parameter(
                distribute_tensor(module.bias, device_mesh, [Replicate()])
            ))

    def _partition_embedding_fn(self, name, module, device_mesh):
        # rowwise sharding of embedding.weight is Shard(0)
        for name, param in module.named_parameters():
            dist_param = nn.Parameter(
                distribute_tensor(param, device_mesh, [Shard(0)])
            )
            module.register_parameter(name, dist_param)

    @staticmethod
    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
        # Rowwise sharding produces partial output, depending on output layouts:
        # 1. to replicate -> allreduce
        # 2. to shard -> reduce_scatter
        if outputs.placements != output_layouts:
            outputs = outputs.redistribute(placements=output_layouts, async_op=True)
        # back to local tensor if use_local_output is True
        return outputs.to_local() if use_local_output else outputs

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        if isinstance(module, nn.Linear):
            partition_fn = self._partition_linear_fn
            # rowwise linear runtime sharding requires input tensor shard on last dim
            self.desired_input_layouts: Tuple[Placement, ...] = (Shard(-1), )
        elif isinstance(module, nn.Embedding):
            partition_fn = self._partition_embedding_fn
            # rowwise embedding runtime sharding requires input tensor replicated
            self.desired_input_layouts = (Replicate(), )
        else:
            raise NotImplementedError("RowwiseParallel currently only supports nn.Linear and nn.Embedding!")

        return distribute_module(
            module,
            device_mesh,
            partition_fn,
            partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts),
            partial(self._prepare_output_fn, self.output_layouts, self.use_local_output),
        )


class SequenceParallel(ParallelStyle):
    """
    SequenceParallel replicates a compatible ``nn.Module``'s parameters and runs the sharded computation with
    the input sharded on the sequence dimension. This currently supports ``nn.LayerNorm``, ``nn.Dropout``, and the
    `RMSNorm python implementation <https://github.com/facebookresearch/llama/blob/main/llama/model.py#L34>`__.

    This style implements the operation that is described in the paper
    `Reducing Activation Recomputation in Large Transformer Models <https://arxiv.org/abs/2205.05198>`__.

    Both the input and output of the ``nn.Module`` will be sharded on the sequence dimension.

    Keyword Args:
        sequence_dim (int, optional):
            The sequence dimension of the input tensor for the ``nn.Module``. This is used to annotate the input tensor
            so that it becomes a DTensor sharded on the sequence dimension, default: 1.
        use_local_output (bool, optional):
            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: False.
    Returns:
        A :class:`ParallelStyle` object that represents Sequence Parallel of the ``nn.Module``.

    Example::
        >>> # xdoctest: +SKIP(failing)
        >>> from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>> ...
        >>> m = Model(...)  # m is a nn.Module that contains a "norm" nn.LayerNorm submodule
        >>> tp_mesh = init_device_mesh("cuda", (8,))
        >>>
        >>> # By default, the input of the "norm" will be converted to a DTensor sharded on the sequence dim
        >>> # and the output of "norm" will be a :class:`DTensor` sharded on the sequence dimension.
        >>>
        >>> sharded_mod = parallelize_module(m, tp_mesh, {"norm": SequenceParallel()})
        >>> ...

    .. note:: The SequenceParallel style assumes ones initialization if there are weights in the nn.Module (e.g.
        ``nn.LayerNorm`` or ``RMSNorm``, which have ones initialization by default). If you have custom
        initializations for the weights of those modules, you need to broadcast the weights before/after parallelizing
        to ensure that they are replicated.
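
    Since the output stays sharded on the sequence dimension, a tensor parallel layer consuming it
    typically declares ``input_layouts=Shard(1)`` so that the sequence-sharded activation is annotated
    correctly before being redistributed (all-gathered) to replicate. A sketch, assuming a hypothetical
    block with an "attention_norm" norm followed by an "attention.wq" nn.Linear::
        >>> # xdoctest: +SKIP(illustrative sketch)
        >>> parallelize_module(
        >>>     block,
        >>>     tp_mesh,
        >>>     {
        >>>         "attention_norm": SequenceParallel(),
        >>>         "attention.wq": ColwiseParallel(input_layouts=Shard(1)),
        >>>     },
        >>> )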
289    """
290    def __init__(
291        self,
292        *,
293        sequence_dim: int = 1,
294        use_local_output: bool = False
295    ):
296        super().__init__()
297        self.sequence_dim = sequence_dim
298        self.use_local_output = use_local_output
299
    def _replicate_module_fn(self, name: str, module: nn.Module, device_mesh: DeviceMesh):
        for p_name, param in module.named_parameters():
            # simple replication with the fixed ones_ init from LayerNorm/RMSNorm, which
            # allows us to simply use from_local
            replicated_param = torch.nn.Parameter(
                DTensor.from_local(param, device_mesh, [Replicate()], run_check=False)
            )
            module.register_parameter(p_name, replicated_param)

    @staticmethod
    def _prepare_input_fn(sequence_dim, mod, inputs, device_mesh):
        input_tensor = inputs[0]
        if isinstance(input_tensor, DTensor):
            return inputs
        elif isinstance(input_tensor, torch.Tensor):
            return DTensor.from_local(input_tensor, device_mesh, [Shard(sequence_dim)], run_check=False)
        else:
            raise ValueError(f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}")

    @staticmethod
    def _prepare_output_fn(use_local_output, mod, outputs, device_mesh):
        return outputs.to_local() if use_local_output else outputs

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            self._replicate_module_fn,
            partial(self._prepare_input_fn, self.sequence_dim),
            partial(self._prepare_output_fn, self.use_local_output),
        )


class PrepareModuleInput(ParallelStyle):
    """
    Configure the nn.Module's inputs to convert the input tensors of the nn.Module to DTensors at runtime according to
    ``input_layouts``, and perform layout redistribution according to the ``desired_input_layouts``.

    Keyword Args:
        input_layouts (Union[Placement, Tuple[Optional[Placement]]]):
            The DTensor layouts of the input tensors for the nn.Module. This is used to convert the input tensors to
            DTensors. If some inputs are not torch.Tensor or do not need to be converted to DTensors, ``None`` needs to
            be specified as a placeholder. default: None.
        desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]):
            The desired DTensor layouts of the input tensors for the nn.Module. This is used to ensure the inputs of the
            nn.Module have the desired DTensor layouts. This argument needs to have the same length as ``input_layouts``.
            default: None.
        input_kwarg_layouts (Dict[str, Placement]):
            The DTensor layouts of the input kwargs for the nn.Module. This is used to convert the input kwarg tensors
            to DTensors. default: None.
        desired_input_kwarg_layouts (Dict[str, Placement]):
            The desired DTensor layouts of the input kwargs for the nn.Module. This is used to ensure the inputs of the
            nn.Module have the desired DTensor layouts. default: None.
        use_local_output (bool, optional):
            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module inputs, default: False.
    Returns:
        A :class:`ParallelStyle` object that prepares the sharding layouts of the nn.Module's inputs.

    Example::
        >>> # xdoctest: +SKIP(failing)
        >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>> ...
        >>> block = TransformerBlock(...)  # block is a nn.Module that contains an "attn" Attention submodule
        >>> tp_mesh = init_device_mesh("cuda", (8,))
        >>>
        >>> # According to the style specified below, the first input of attn will be annotated as a Sharded DTensor
        >>> # and then redistributed to a Replicated DTensor.
        >>> parallelize_module(
        >>>     block, # this can be a submodule or module
        >>>     tp_mesh,
        >>>     parallelize_plan={
        >>>         "attn": PrepareModuleInput(
        >>>             input_layouts=(Shard(0), None, None, ...),
        >>>             desired_input_layouts=(Replicate(), None, None, ...)
        >>>         ),
        >>>     }
        >>> )
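
    Keyword arguments of the module's ``forward`` can be prepared the same way through
    ``input_kwarg_layouts`` / ``desired_input_kwarg_layouts``. A sketch, assuming the hypothetical
    "attn" submodule takes a ``mask`` keyword argument that should stay replicated::
        >>> # xdoctest: +SKIP(illustrative sketch)
        >>> parallelize_module(
        >>>     block,
        >>>     tp_mesh,
        >>>     parallelize_plan={
        >>>         "attn": PrepareModuleInput(
        >>>             input_kwarg_layouts={"mask": Replicate()},
        >>>             desired_input_kwarg_layouts={"mask": Replicate()},
        >>>         ),
        >>>     }
        >>> )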
377    """
378
    def __init__(
        self,
        *,
        input_layouts: Optional[Union[Placement, Tuple[Optional[Placement]]]] = None,
        desired_input_layouts: Optional[Union[Placement, Tuple[Optional[Placement]]]] = None,
        input_kwarg_layouts: Optional[Dict[str, Placement]] = None,
        desired_input_kwarg_layouts: Optional[Dict[str, Placement]] = None,
        use_local_output: bool = False
    ):
        self.input_layouts = (input_layouts,) if isinstance(input_layouts, Placement) else input_layouts
        self.desired_input_layouts = \
            (desired_input_layouts,) if isinstance(desired_input_layouts, Placement) else desired_input_layouts
        self.use_local_output = use_local_output
        if self.input_layouts is not None:
            assert self.desired_input_layouts is not None, "desired module inputs should not be None!"
            assert len(self.input_layouts) == len(self.desired_input_layouts), \
                "input_layouts and desired_input_layouts should have same length!"
        self.with_kwargs = input_kwarg_layouts is not None
        self.input_kwarg_layouts = input_kwarg_layouts or {}
        self.desired_input_kwarg_layouts = desired_input_kwarg_layouts or {}
        if self.with_kwargs:
            assert len(self.input_kwarg_layouts) == len(self.desired_input_kwarg_layouts), \
                "input_kwarg_layouts and desired_input_kwarg_layouts should have same length!"

    def _prepare_input_arg(
        self,
        input: Any,
        mesh: DeviceMesh,
        input_layout: Optional[Placement],
        desired_layout: Optional[Placement]
    ):
        if input_layout is not None:
            if isinstance(input, DTensor):
                # TODO: re-enable the check once we fix the compile path
                # assert inp.placements[0] == input_layout
                dt_inp = input
            else:
                assert isinstance(input, torch.Tensor), "expecting input to be a torch.Tensor!"
                dt_inp = DTensor.from_local(input, mesh, (input_layout,), run_check=False)

            if desired_layout is not None and input_layout != desired_layout:
                dt_inp = dt_inp.redistribute(placements=(desired_layout,))

            return dt_inp.to_local() if self.use_local_output else dt_inp
        else:
            return input

    def _prepare_input_fn(self, inputs, device_mesh):
        if self.input_layouts is None:
            return inputs
        prepared_inputs = []
        if not isinstance(inputs, tuple):
            inputs = (inputs,)
        if len(inputs) != len(self.input_layouts):
            raise ValueError("module inputs and input_layouts should have same length!")

        assert self.desired_input_layouts is not None, "desired module inputs should not be None!"
        for inp, input_layout, desired_layout in zip(inputs, self.input_layouts, self.desired_input_layouts):
            prepared_inputs.append(self._prepare_input_arg(inp, device_mesh, input_layout, desired_layout))
        return tuple(prepared_inputs)

    def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh):
        prepared_arg_inputs = self._prepare_input_fn(inputs, device_mesh)
        prepared_kwarg_inputs = {}
        for kwarg_key in kwarg_inputs.keys():
            kwarg_val = kwarg_inputs[kwarg_key]
            input_layout = self.input_kwarg_layouts.get(kwarg_key)
            desired_input_layout = self.desired_input_kwarg_layouts.get(kwarg_key)

            prepared_kwarg_inputs[kwarg_key] = self._prepare_input_arg(kwarg_val, device_mesh, input_layout, desired_input_layout)

        return (prepared_arg_inputs, prepared_kwarg_inputs)

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        if self.with_kwargs:
            module.register_forward_pre_hook(
                lambda _, inputs, kwargs: self._prepare_input_kwarg_fn(inputs, kwargs, device_mesh),
                with_kwargs=True
            )  # type: ignore[misc]
        else:
            module.register_forward_pre_hook(lambda _, inputs: self._prepare_input_fn(inputs, device_mesh))  # type: ignore[misc, call-arg]
        return module


class PrepareModuleOutput(ParallelStyle):
    """
    Configure the nn.Module's outputs to convert the output tensors of the nn.Module to DTensors at runtime according to
    ``output_layouts``, and perform layout redistribution according to the ``desired_output_layouts``.

    Keyword Args:
        output_layouts (Union[Placement, Tuple[Placement]]):
            The DTensor layouts of the output tensors for the nn.Module. This is used to convert the output tensors to
            DTensors if they are :class:`torch.Tensor`. If some outputs are not torch.Tensor or do not need to be
            converted to DTensors, ``None`` needs to be specified as a placeholder.
        desired_output_layouts (Union[Placement, Tuple[Placement]]):
            The desired DTensor layouts of the output tensors for the nn.Module. This is used to ensure the outputs of
            the nn.Module have the desired DTensor layouts.
        use_local_output (bool, optional):
            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module outputs, default: True.
    Returns:
        A :class:`ParallelStyle` object that prepares the sharding layouts of the nn.Module's outputs.

    Example::
        >>> # xdoctest: +SKIP(failing)
        >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleOutput
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>> ...
        >>> block = TransformerBlock(...)  # block is a nn.Module that contains an "attn" Attention submodule
        >>> tp_mesh = init_device_mesh("cuda", (8,))
        >>>
        >>> # According to the style specified below, the output of the TransformerBlock will be converted to Replicated DTensor
        >>> # and then redistributed to Sharded DTensor.
        >>> parallelize_module(
        >>>     block, # this can be a submodule or module
        >>>     tp_mesh,
        >>>     parallelize_plan = PrepareModuleOutput(
        >>>         output_layouts=Replicate(),
        >>>         desired_output_layouts=Shard(0)
        >>>     )
        >>> )
    """
    def __init__(
        self,
        *,
        output_layouts: Union[Placement, Tuple[Placement]],
        desired_output_layouts: Union[Placement, Tuple[Placement]],
        use_local_output: bool = True
    ):
        self.output_layouts = (output_layouts,) if isinstance(output_layouts, Placement) else output_layouts
        self.desired_output_layouts = \
            (desired_output_layouts,) if isinstance(desired_output_layouts, Placement) else desired_output_layouts
        self.use_local_output = use_local_output
        assert len(self.output_layouts) == len(self.desired_output_layouts), \
            "output_layouts and desired_output_layouts should have same length!"

    def _prepare_out_fn(self, outputs, device_mesh):
        prepared_outputs = []
        if not isinstance(outputs, tuple):
            outputs = (outputs,)
        if len(outputs) != len(self.output_layouts):
            raise ValueError("module outputs and output_layouts should have same length!")
        for out, out_layout, desired_out_layout in zip(outputs, self.output_layouts, self.desired_output_layouts):
            if out_layout is not None:
                if isinstance(out, DTensor):
                    # TODO: re-enable the check once we fix the compile path
                    # assert out.placements[0] == out_layout
                    dt_out = out
                else:
                    dt_out = DTensor.from_local(out, device_mesh, (out_layout,), run_check=False)

                if out_layout != desired_out_layout:
                    dt_out = dt_out.redistribute(placements=(desired_out_layout,))
                prepared_outputs.append(dt_out.to_local() if self.use_local_output else dt_out)
            else:
                prepared_outputs.append(out)
        if len(prepared_outputs) == 1:
            return prepared_outputs[0]
        else:
            return tuple(prepared_outputs)

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        module.register_forward_hook(lambda _, inputs, outputs: self._prepare_out_fn(outputs, device_mesh))  # type: ignore[misc, call-arg]
        return module