.. role:: hidden
    :class: hidden-section

Tensor Parallelism - torch.distributed.tensor.parallel
======================================================

Tensor Parallelism (TP) is built on top of the PyTorch DistributedTensor
(`DTensor <https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/README.md>`__)
and provides different parallelism styles: Colwise, Rowwise, and Sequence Parallelism.

.. warning ::
    Tensor Parallelism APIs are experimental and subject to change.

The entrypoint to parallelize your ``nn.Module`` using Tensor Parallelism is:

.. automodule:: torch.distributed.tensor.parallel

.. currentmodule:: torch.distributed.tensor.parallel

.. autofunction::  parallelize_module
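
As a quick, hedged illustration (not part of the API reference), the sketch below parallelizes a toy two-layer MLP with ``parallelize_module`` using the parallel styles documented below; the submodule names ``w1``/``w2``, the mesh size, and the device setup are illustrative assumptions.

.. code-block:: python

    # Minimal sketch: column-wise then row-wise shard a toy MLP across 8 GPUs.
    # Assumes the default process group is already initialized with 8 ranks.
    import torch
    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )

    class ToyMLP(nn.Module):
        def __init__(self):
            super().__init__()
            self.w1 = nn.Linear(16, 32)
            self.w2 = nn.Linear(32, 16)

        def forward(self, x):
            return self.w2(torch.relu(self.w1(x)))

    tp_mesh = init_device_mesh("cuda", (8,))

    # The plan maps submodule names (FQNs) to ParallelStyle objects.
    model = parallelize_module(
        ToyMLP().to("cuda"),
        tp_mesh,
        {"w1": ColwiseParallel(), "w2": RowwiseParallel()},
    )
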

Tensor Parallelism supports the following parallel styles:

.. autoclass:: torch.distributed.tensor.parallel.ColwiseParallel
  :members:
  :undoc-members:

.. autoclass:: torch.distributed.tensor.parallel.RowwiseParallel
  :members:
  :undoc-members:

.. autoclass:: torch.distributed.tensor.parallel.SequenceParallel
  :members:
  :undoc-members:

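These styles also accept DTensor placements to customize how a module's inputs and outputs are laid out. As a hedged sketch (the submodule name ``out_proj`` and the chosen layouts are illustrative assumptions), one can ask a column-wise sharded projection to return a replicated, regular tensor to its caller:

.. code-block:: python

    # Sketch: customize a ParallelStyle with an explicit output placement.
    # Placements come from the DTensor package linked above.
    from torch.distributed._tensor import Replicate
    from torch.distributed.tensor.parallel import ColwiseParallel

    plan_fragment = {
        # Shard the weight column-wise, but all-gather the activation so the
        # module hands a replicated plain torch.Tensor back to unsharded code.
        "out_proj": ColwiseParallel(output_layouts=Replicate()),
    }
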
To only configure the ``nn.Module``'s inputs and outputs with DTensor layouts
and perform the necessary layout redistributions, without distributing the module
parameters to DTensors, the following ``ParallelStyle`` s can be used in
the ``parallelize_plan`` when calling ``parallelize_module``:

.. autoclass:: torch.distributed.tensor.parallel.PrepareModuleInput
  :members:
  :undoc-members:

.. autoclass:: torch.distributed.tensor.parallel.PrepareModuleOutput
  :members:
  :undoc-members:

.. note:: When using ``Shard(dim)`` as the input/output layouts for the above
  ``ParallelStyle`` s, we assume the input/output activation tensors are evenly sharded on
  the tensor dimension ``dim`` on the ``DeviceMesh`` that TP operates on. For instance,
  since ``RowwiseParallel`` accepts input that is sharded on the last dimension, it assumes
  the input tensor has already been evenly sharded on the last dimension. For unevenly
  sharded activation tensors, one can pass DTensors directly to the partitioned modules
  and use ``use_local_output=False`` to return a DTensor after each ``ParallelStyle``, so
  that the DTensor can track the uneven sharding information.
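
For example, a hedged sketch of ``PrepareModuleInput`` (the layouts and the idea of redistributing a sequence-sharded activation to a replicated one are illustrative assumptions, not a required recipe) could look like:

.. code-block:: python

    # Sketch: declare how a block's input arrives and how it should be laid out
    # before the block's TP submodules run.
    from torch.distributed._tensor import Replicate, Shard
    from torch.distributed.tensor.parallel import PrepareModuleInput

    prepare_block_input = PrepareModuleInput(
        # The incoming activation is sharded on the sequence dimension (dim 1)...
        input_layouts=(Shard(1),),
        # ...and gets redistributed to be replicated before entering the block.
        desired_input_layouts=(Replicate(),),
    )

    # Used as a value in a parallelize_plan, keyed by the block's module path.
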

For models like the Transformer, we recommend using ``ColwiseParallel``
and ``RowwiseParallel`` together in the ``parallelize_plan`` to achieve the desired
sharding for the entire model (i.e., the Attention and MLP layers).
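
As a hedged sketch of what such a plan can look like for a single Transformer block (the submodule names such as ``attention.wq`` and ``feed_forward.w1`` are hypothetical and must match your model's actual module paths):

.. code-block:: python

    # Sketch: pair Colwise and Rowwise styles so intermediate activations stay
    # sharded and only the row-wise projections require a reduction.
    from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

    block_plan = {
        # Attention: q/k/v projections column-wise, output projection row-wise.
        "attention.wq": ColwiseParallel(),
        "attention.wk": ColwiseParallel(),
        "attention.wv": ColwiseParallel(),
        "attention.wo": RowwiseParallel(),
        # MLP: the same column-wise/row-wise pairing for up/down projections.
        "feed_forward.w1": ColwiseParallel(),
        "feed_forward.w2": RowwiseParallel(),
    }

    # block = parallelize_module(block, tp_mesh, block_plan)
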

Parallelized cross-entropy loss computation (loss parallelism) is supported via the following context manager:

.. autofunction:: torch.distributed.tensor.parallel.loss_parallel
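
A hedged usage sketch (here ``logits`` is assumed to be a DTensor sharded on the class dimension, e.g. the output of a ``ColwiseParallel`` layer kept as a DTensor via ``use_local_output=False``, and ``labels`` is an ordinary tensor):

.. code-block:: python

    # Sketch: compute the sharded cross-entropy loss and run backward under
    # the same context manager.
    import torch.nn.functional as F
    from torch.distributed.tensor.parallel import loss_parallel

    with loss_parallel():
        loss = F.cross_entropy(logits, labels, reduction="mean")
        loss.backward()
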

.. warning ::
    The loss_parallel API is experimental and subject to change.
68