xref: /aosp_15_r20/external/pytorch/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import logging
3import math
4from collections import defaultdict
5from typing import Dict
6
7import torch
8import torch.distributed as dist
9from torch.distributed import distributed_c10d
10
11from . import default_hooks as default
12
13
14__all__ = ["PowerSGDState", "powerSGD_hook", "batched_powerSGD_hook"]
15
16logger = logging.getLogger(__name__)
17
18
19def _orthogonalize(matrices, epsilon=0):
20    """
21    Decide between Gram-Schmidt or QR factorization to orthogonalize a batch of matrices.
22
23    QR factorization doesn't work with half-precision, but it is usually faster with a rank > 2.
24    """
25    assert len(matrices.shape) == 3 and matrices.shape[2] <= matrices.shape[1]
26
27    num_matrices = matrices.shape[0]
28    rank = matrices.shape[2]
29    dtype = matrices.dtype
30    if rank <= 2 or dtype in [torch.float16, torch.bfloat16]:
31        _orthogonalize_gram_schmidt(matrices, epsilon=epsilon)
32    else:
33        torch.linalg.qr(
34            matrices,
35            out=(
36                matrices,
37                torch.empty(
38                    num_matrices, rank, rank, device=matrices.device, dtype=dtype
39                ),
40            ),
41        )
42
43
44def _orthogonalize_gram_schmidt(matrices, epsilon=0):
45    """
46    Apply Gram-Schmidt procedure to orthogonalize a batch of matrices.
47
48    If epsilon is 0, this is equivalent to `torch.qr(matrices, out=(matrices, _))`,
49    """
50    num_cols = matrices.shape[2]
51    for i in range(num_cols):
52        # Normalize the i'th column.
53        col = matrices[:, :, i : i + 1]
54        # If no epsilon is added here, division by zero may be caused by vanishing gradients.
55        # This epsilon is not needed if the input batch of matrices covers the gradients of at least one entire layer
56        # in the neural network.
57        if epsilon == 0:
58            # Note that col ** 2 can underflow/overflow if we use FP16.
59            # May need to consider multiplying a scaling factor and dividing it later, or using bfloat16 instead.
60            try:
61                col /= torch.norm(col, dim=1, keepdim=True)
62            except ZeroDivisionError:
63                logger.error(
64                    "The matrices to be orthogonalized has at least a column of all 0s. Please set a small value such as 1e-8 "
65                    "as `orthogonalization_epsilon` in PowerSGD state."
66                )
67                # Recover the values from NaNs to 0s.
68                col.fill_(0.0)
69        else:
70            col /= torch.norm(col, dim=1, keepdim=True) + epsilon
71        # Project it on the rest and remove it.
72        if i + 1 < num_cols:
73            rest = matrices[:, :, i + 1 :]
74            rest -= torch.sum(col * rest, dim=1, keepdim=True) * col
75
76
77def _should_compress(
78    num_rows, num_cols, matrix_approximation_rank, min_compression_rate
79):
80    """
81    Recommend if tensor given is worth compressing.
82
83    Returns a recommendation as to whether the 2D tensor described by the arguments is worth compressing,
84    including statistics describing the expected savings from compression.  We consider a tensor worth
85    compressing when ``min_compression_rate`` < uncompressed size / compressed size, where
86    uncompressed size = ``num_rows`` * ``num_cols``,
87    and compressed size = (``num_rows`` + ``num_cols``) * ``matrix_approximation_rank``.
88
89    The result of this function is a tuple of the form (compression_recommendation, uncompressed_el_count, compressed_el_count), where:
90
91    compression_recommendation is true if the tensor is worth compressing, and false otherwise (see above);
92
93    uncompressed_el_count is the uncompressed element count, i.e. ``num_rows`` * ``num_cols``; and,
94
95    compress_el_count is the element count after compression, i.e. (``num_rows`` + ``num_cols``) * ``matrix_approximation_rank``.
96    """  # noqa: B950
97    uncompressed_size = num_rows * num_cols
98    compressed_size = (num_rows + num_cols) * matrix_approximation_rank
99    return (
100        compressed_size * min_compression_rate < uncompressed_size,
101        uncompressed_size,
102        compressed_size,
103    )
104
105
def _report_compression_stats(bucket, state):
    """
    Log compression stats every ``compression_stats_logging_frequency`` iterations.

    Nothing happens unless ``bucket`` is the last bucket and ``state.iter`` has reached
    ``state.next_stats_report``; after logging, the next report is scheduled
    ``state.compression_stats_logging_frequency`` iterations later.
    """
    if not bucket.is_last() or state.iter < state.next_stats_report:
        return

    rate, numel_before, numel_after = state.compression_stats()
    logger.info(
        "Compression stats: iter %s, total before compression %s, total after compression %s, "
        "rate %s",
        state.iter,
        numel_before,
        numel_after,
        rate,
    )
    state.next_stats_report = state.iter + state.compression_stats_logging_frequency
119
120
class PowerSGDState:
    r"""
    Store both the algorithm's hyperparameters and internal state for all gradients during training.

    Particularly, ``matrix_approximation_rank`` and ``start_powerSGD_iter`` are the main hyperparameters that should be tuned by the user.
    For performance, we suggest to keep binary hyperparameters ``use_error_feedback`` and ``warm_start`` on.

    1. ``matrix_approximation_rank`` controls the size of compressed low-rank tensors, which determines the compression rate. The lower the rank, the stronger the compression.

        1.1. If ``matrix_approximation_rank`` is too low, the full model quality will need more training steps to reach or will never reach and yield loss in accuracy.

        1.2. The increase of ``matrix_approximation_rank`` can substantially increase the computation costs of the compression, and the accuracy may not be further improved beyond a certain ``matrix_approximation_rank`` threshold.

    To tune ``matrix_approximation_rank``, we suggest to start from 1 and increase by factors of 2 (like an exponential grid search, 1, 2, 4, ...), until a satisfactory accuracy is reached. Typically only a small value 1-4 is used. For some NLP tasks (as shown in Appendix D of the original paper), this value has been increased to 32.

    2. ``start_powerSGD_iter`` defers PowerSGD compression until step ``start_powerSGD_iter``, and vanilla allreduce runs prior to step ``start_powerSGD_iter``. This hybrid scheme of **vanilla allreduce + PowerSGD** can effectively improve the accuracy, even if a relatively small ``matrix_approximation_rank`` is used. This is because the beginning of training phase is usually very sensitive to inaccurate gradients, and compressing gradients too early may make the training quickly take a suboptimal trajectory, which can result in an irrecoverable impact on the accuracy.

    To tune ``start_powerSGD_iter``, we suggest to start with 10% of total training steps, and increase it until a satisfactory accuracy is reached. If there is a warm-up stage in the training, ``start_powerSGD_iter`` typically should be no less than the number of warm-up steps.

    3. ``min_compression_rate`` is the minimum compression rate required when a layer is compressed. Due to the computation overheads incurred by the compression, a tensor is worth compressing only if there can be sufficient saving in bandwidth, where ``(num_rows + num_cols) * matrix_approximation_rank * min_compression_rate < num_rows * num_cols``. If the specified compression rate threshold cannot be satisfied, the tensor will be directly allreduced without compression.

    Compression statistics are logged every ``compression_stats_logging_frequency`` iterations once PowerSGD compression starts.

    4. ``orthogonalization_epsilon`` can be a very small value (e.g., 1e-8) added to every normalized matrix column in orthogonalization step, to prevent div-by-zero error if any column has all 0s. If this can already be prevented (e.g., by batch normalization), an epsilon of 0 is recommended for accuracy.

    5. ``batch_tensors_with_same_shape`` controls whether to compress and decompress tensors with same shape in a batched operation to achieve higher parallelism. Note that you should also increase the bucket size (i.e., ``bucket_cap_mb`` arg in DDP constructor) to make more same-shaped tensors appear in the same bucket, however this may reduce the overlap between computation and communication, and increase the memory footprint due to stacking the tensors of the same shape. Set to ``True`` if the compression / decompression computation is a bottleneck.

    ``random_seed`` seeds the internal ``numpy`` random state stored in ``rng``.

    .. warning ::
        If error feedback or warm-up is enabled, the minimum value of ``start_powerSGD_iter`` allowed in DDP is 2.
        This is because there is another internal optimization that rebuilds buckets at iteration 1 in DDP,
        and this can conflict with any tensor memorized before the rebuild process.
    """  # noqa: B950

    __slots__ = [
        "process_group",
        # The fields below are the hyperparameters that often need to be tuned by the user.
        "matrix_approximation_rank",
        "start_powerSGD_iter",
        # The fields below are the hyperparameters that seldom need be tuned by the user.
        "min_compression_rate",
        "orthogonalization_epsilon",
        # The fields below are the binary hyperparameters recommended to be turned on for performance and accuracy.
        "use_error_feedback",
        "warm_start",
        "batch_tensors_with_same_shape",
        # The fields below are internal state.
        "rng",
        "error_dict",
        "p_memory_dict",
        "q_memory_dict",
        "iter",
        # The fields below are for recording compression stats.
        "total_numel_before_compression",
        "total_numel_after_compression",
        "compression_stats_logging_frequency",
        "next_stats_report",
    ]

    def __init__(
        self,
        process_group,
        matrix_approximation_rank=1,
        start_powerSGD_iter=1_000,
        min_compression_rate=2,
        use_error_feedback=True,
        warm_start=True,
        orthogonalization_epsilon=0,
        random_seed=0,
        compression_stats_logging_frequency=10_000,
        batch_tensors_with_same_shape: bool = False,
    ):
        """
        Record the hyperparameters and reset all internal state.

        Raises:
            ValueError: If ``use_error_feedback`` or ``warm_start`` is enabled while
                ``start_powerSGD_iter`` is not greater than 1.
        """
        logger.info(
            "PowerSGD config: matrix_approximation_rank = %s; start_powerSGD_iter = %s; "
            "min_compression_rate = %s; orthogonalization_epsilon = %s; use_error_feedback = %s; warm_start = %s; "
            "random_seed = %s; compression_stats_logging_frequency = %s; batch_tensors_with_same_shape = %s",
            matrix_approximation_rank,
            start_powerSGD_iter,
            min_compression_rate,
            orthogonalization_epsilon,
            use_error_feedback,
            warm_start,
            random_seed,
            compression_stats_logging_frequency,
            batch_tensors_with_same_shape,
        )

        self.process_group = process_group
        self.matrix_approximation_rank = matrix_approximation_rank
        # Deferring PowerSGD compression until step 'start_powerSGD_iter' can have two advantages:
        # 1) It turns out that PowerSGD may lead to a non-trivial accuracy loss,
        # even if the matrix approximation rank is increased to a large value.
        # To mitigate the accuracy loss, a simple yet effective way is mixing vanilla allreduce
        # (or a more conservative compression such as FP16 compression) with PowerSGD.
        # 2) There is an internal optimization of rebuilding buckets process in DDP,
        # in order to save the memory space.
        # This step takes place after the first iteration.
        # However, this means that the shape of input bucketized tensors is subject to change,
        # which will complicate the implementations of error feedback and warm-up.
        # Running vanilla allreduce in the first few iterations can avoid this complexity.
        if (use_error_feedback or warm_start) and start_powerSGD_iter <= 1:
            raise ValueError(
                "Expect `start_powerSGD_iter` > 1 if `use_error_feedback` or `warm_start` is enabled, "
                "because PowerSGD can only be applied after the first two iterations in DDP."
            )
        self.start_powerSGD_iter = start_powerSGD_iter
        self.min_compression_rate = min_compression_rate
        # Error feedback is usually crucial for both convergence and generalization,
        # because PowerSGD is a biased compressor,
        # i.e., compressing and decompressing a random gradient does not yield the original in expectation.
        # This mechanism requires a temporary copy of the input gradients,
        # so it increases the peak memory consumption by the size of the gradient tensor.
        # However, if the target matrices are known to be exactly low-ranked (instead of just low stable rank),
        # sometimes it is possible to converge to the optima without error feedback.
        # See: http://proceedings.mlr.press/v54/yurtsever17a/yurtsever17a.pdf
        self.use_error_feedback = use_error_feedback
        # Warm-start reuses P(s) and Q(s) from the previous iteration.
        # This can improve the approximation quality and hence improve the accuracy.
        # Additionally, by avoiding the initialization of these low-rank tensors at every step,
        # this can also accelerate training.
        # However, this is at the cost of extra memory.
        self.warm_start = warm_start
        # Can use a very small value to prevent div-by-zero error caused by orthogonalization of vanishing gradients.
        self.orthogonalization_epsilon = orthogonalization_epsilon
        # The purpose of this RNG is to generate different random seeds for initializing Q across iterations,
        # but in the same order for all the DDP replicas.
        # Different random seeds across iterations indicate different 'projections' of the gradients at different SGD steps.
        # If the same random projection is used,
        # there will be differences between the gradients that are never synchronized.
        # numpy is imported here rather than at module level, so it is only needed once a state is built.
        import numpy as np

        self.rng = np.random.RandomState(random_seed)
        # Since there is only a single state instance for all the input buckets,
        # need to maintain a dictionary that maps each bucket index to the local error.
        self.error_dict: Dict[int, torch.Tensor] = {}
        self.p_memory_dict: Dict[int, torch.Tensor] = {}
        self.q_memory_dict: Dict[int, torch.Tensor] = {}
        # Iteration/step in the training loop.
        self.iter = 0
        # Compression stats accumulators
        self.total_numel_before_compression = 0
        self.total_numel_after_compression = 0
        # We'll report compression stats every 'compression_stats_logging_frequency' iterations
        # Note that we always report compression stats at least once.
        self.compression_stats_logging_frequency = max(
            1, compression_stats_logging_frequency
        )
        self.next_stats_report = 0
        # Batching tensors with same shape can increase parallelism in compression / decompression computation.
        # This requires a larger bucket size to make more same-shaped tensor to appear in one bucket, however
        # this may reduce the overlap between computation and communication, and increase the memory footprint
        # due to stacking tensors.
        # Turn on if compression / decompression computation is a bottleneck.
        self.batch_tensors_with_same_shape = batch_tensors_with_same_shape

    def __getstate__(self):
        r"""
        Return a ``Dict[str, Any]`` which will be pickled and saved.

        ``process_group`` is not serializable and excluded from
        a returned state.
        """
        logger.warning(
            "NOTE: Process group is not serializable and excluded from a saved state."
        )
        return {
            slot: getattr(self, slot)
            for slot in self.__slots__
            if slot != "process_group"
        }

    def __setstate__(self, state):
        r"""
        Take a provided ``state`` and set to this ``PowerSGDState`` instance.

        ``process_group`` is set to default.
        """
        self.process_group = distributed_c10d._get_default_group()
        # Adjacent literals are used instead of a backslash continuation inside the string,
        # so that the source indentation does not end up in the logged message.
        logger.warning(
            "NOTE: Process group will be set to a default group (i.e. the world size). "
            "If a different group is desired, please set `self.process_group` after PowerSGD state is loaded."
        )
        for slot, value in state.items():
            setattr(self, slot, value)

    def maybe_increase_iter(self, bucket):
        """Advance the iteration counter on the last bucket and log once when PowerSGD is about to start."""
        # Only the bucket reporting `is_last()` advances `iter`, so the counter moves once per
        # training iteration no matter how many buckets call this method.
        if not bucket.is_last():
            return

        self.iter += 1
        # Checked only right after the increment: `iter` keeps the same value for all the other
        # buckets of the following iteration, and testing it on every call would repeat the message
        # once per bucket.
        if self.iter == self.start_powerSGD_iter:
            logger.info("Start to apply PowerSGD after %s iterations.", self.iter)

    def compression_stats(self):
        r"""
        Return latest compression statistics as tuple.

        Returns tuple of form (compress_rate, numel_before_compression, numel_after_compression) where:

        compress_rate is the effective compression rate i.e. (number of elements before compression) / (number of elements after compression),
        or 0 if no element has been recorded after compression yet;

        numel_before_compression is the total number of elements before compression was applied; and,

        numel_after_compression is the total number of elements after compression was applied.
        """  # noqa: B950
        compress_rate = (
            self.total_numel_before_compression / self.total_numel_after_compression
            if self.total_numel_after_compression > 0
            else 0
        )
        return (
            compress_rate,
            self.total_numel_before_compression,
            self.total_numel_after_compression,
        )
337
338
def powerSGD_hook(
    state: PowerSGDState, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    r"""
    Implement PowerSGD algorithm.

    This DDP communication hook implements PowerSGD gradient compression
    algorithm described in the `paper <https://arxiv.org/abs/1905.13727>`_.
    Once gradient tensors are aggregated across all workers, this hook applies
    compression as follows:

    1. Views the input flattened 1D gradient tensor as a list of per-parameter tensors, and divides all the tensors into two groups:

        1.1 The tensors that should be compressed before allreduce, because the compression can give enough saving in bandwidth.

        1.2 Rest of the tensors will be directly allreduced without compression, including all the vector tensors (for biases).

    2. Handles uncompressed tensors:

        2.1. Allocate contiguous memory for those uncompressed tensors, and allreduces all the uncompressed tensors as a batch, without compression;

        2.2. Copies the individual uncompressed tensors from the contiguous memory back to the input tensor.

    3. Handles the tensors that should be compressed by PowerSGD compression:

        3.1. For each tensor M, creates two low-rank tensors P and Q for decomposing M,
        such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;

        3.2. Computes each P in Ps, which is equal to MQ;

        3.3. Allreduces Ps as a batch;

        3.4. Orthogonalizes each P in Ps;

        3.5. Computes each Q in Qs, which is approximately equal to M^TP;

        3.6. Allreduces Qs as a batch;

        3.7. Computes each M among all the compressed tensors, which is approximately equal to PQ^T.

    Note that this communication hook enforces vanilla allreduce for the first ``state.start_powerSGD_iter`` iterations.
    This not only gives the user more control over the tradeoff between speedup and accuracy,
    but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers.

    Args:
        state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
            To tune the compression configs, mainly need to tune ``matrix_approximation_rank``, ``start_powerSGD_iter``
            and ``min_compression_rate``.
        bucket (dist.GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
            Note that since DDP comm hook only supports single process single device mode,
            only exactly one tensor is stored in this bucket.

    Returns:
        Future handler of the communication, which updates the gradients in place.

    Example::
        >>> # xdoctest: +SKIP
        >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1,
                                  start_powerSGD_iter=10, min_compression_rate=0.5)
        >>> ddp_model.register_comm_hook(state, powerSGD_hook)
    """  # noqa: B950
    process_group = state.process_group
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    world_size = group_to_use.size()

    # The input tensor is a flattened 1D tensor.
    input_tensor = bucket.buffer()

    # Run vanilla allreduce in the first `start_powerSGD_iter` iterations.
    if state.iter < state.start_powerSGD_iter:
        state.maybe_increase_iter(bucket)
        return default._allreduce_fut(group_to_use, input_tensor)

    # Apply PowerSGD after `start_powerSGD_iter` iterations.
    device = input_tensor.device
    dtype = input_tensor.dtype

    # Incorporate the error from the previous state into the gradients.
    bucket_index = bucket.index()
    input_tensor_cp = None
    total_length = input_tensor.shape[0]
    if state.use_error_feedback:
        if bucket_index in state.error_dict:
            input_tensor.add_(state.error_dict[bucket_index])
        else:
            logger.info(
                "A zero tensor of length %s that represents local error is created.",
                total_length,
            )
            state.error_dict[bucket_index] = torch.zeros(
                total_length, device=device, dtype=dtype
            )

        # Keep a copy of the input tensor,
        # so that we can compute the local error caused by compression later,
        # by comparing this copy and the input tensor updated after decompression.
        input_tensor_cp = torch.clone(input_tensor).detach()

    # Unflatten the input tensor into per-parameter tensors, for layer-wise compression.
    tensors = bucket.gradients()

    # Step I: Divide all the tensors into two groups,
    # one will be compressed before allreduce and the other will be directly allreduced without compression.
    tensors_to_compress, uncompressed_tensors = [], []
    total_Ps_size = 0
    total_Qs_size = 0
    for tensor in tensors:
        # Every gradient is treated as a 2D matrix: the first dimension is kept, the rest are flattened.
        matrix = tensor.view(tensor.shape[0], -1)
        n, m = matrix.shape
        # The rank can never exceed either side of the matrix.
        matrix_approximation_rank = min(n, m, state.matrix_approximation_rank)
        compress_test = _should_compress(
            n, m, matrix_approximation_rank, state.min_compression_rate
        )
        state.total_numel_before_compression += compress_test[1]
        if compress_test[0]:
            tensors_to_compress.append(matrix)
            total_Ps_size += n * matrix_approximation_rank
            total_Qs_size += m * matrix_approximation_rank
            state.total_numel_after_compression += compress_test[2]
        else:
            uncompressed_tensors.append(tensor)
            state.total_numel_after_compression += compress_test[1]

    _report_compression_stats(bucket, state)

    # Step II: Handle uncompressed tensors.
    # Allocate contiguous memory for these tensors to allreduce efficiently.
    uncompressed_tensors_memory = (
        torch.cat([tensor.view(-1) for tensor in uncompressed_tensors])
        if uncompressed_tensors
        else torch.tensor([], device=device, dtype=dtype)
    )

    # Step III: Handle the tensors that should be compressed.
    # Allocate contiguous memory for Ps and Qs to allreduce efficiently.
    # If warm-start is enabled, reuse Ps and Qs from the previous iteration if possible.
    # The memory spaces of Ps and Qs need to be allocated in the first iteration when PowerSGD is applied.
    need_randomize_qs = False
    if not state.warm_start or bucket_index not in state.p_memory_dict:
        need_randomize_qs = True
        # If warm-start is disabled, low-rank tensors will be initialized at every step.
        # Only log this if warm-start to avoid spamming.
        if state.warm_start:
            logger.info(
                "Allocating contiguous memory of length %s for Ps, and of length %s for Qs, respectively.",
                total_Ps_size,
                total_Qs_size,
            )
        state.p_memory_dict[bucket_index] = torch.empty(
            total_Ps_size, device=device, dtype=dtype
        )
        state.q_memory_dict[bucket_index] = torch.empty(
            total_Qs_size, device=device, dtype=dtype
        )

    # Batch tensors to compress by shape.
    shape_to_tensors = defaultdict(list)
    for tensor in tensors_to_compress:
        shape_to_tensors[tensor.shape].append(tensor)

    # This function decides whether to batch tensors with same shape or not according to the argument,
    # so the following process could share the same code.
    def maybe_batched_tensors_to_compress():
        for tensors in shape_to_tensors.values():
            if state.batch_tensors_with_same_shape:
                batch_size = len(tensors)
                if batch_size == 1:
                    # Use the original tensor to avoid copy.
                    yield tensors[0].unsqueeze(0)
                else:
                    yield torch.stack(tensors)
            else:
                for tensor in tensors:
                    yield tensor.unsqueeze(0)

    # Create Ps and Qs that point to the allocated memory.
    # From here on, every entry of `tensors_to_compress` is 3D: (batch_size, n, m).
    tensors_to_compress = []
    ps = []
    qs = []
    p_idx = 0
    q_idx = 0
    for tensor in maybe_batched_tensors_to_compress():
        batch_size, n, m = tensor.shape
        matrix_approximation_rank = min(n, m, state.matrix_approximation_rank)
        tensors_to_compress.append(tensor)
        ps.append(
            state.p_memory_dict[bucket_index][
                p_idx : p_idx + batch_size * n * matrix_approximation_rank
            ].view(batch_size, n, matrix_approximation_rank)
        )
        qs.append(
            state.q_memory_dict[bucket_index][
                q_idx : q_idx + batch_size * m * matrix_approximation_rank
            ].view(batch_size, m, matrix_approximation_rank)
        )
        p_idx += batch_size * n * matrix_approximation_rank
        q_idx += batch_size * m * matrix_approximation_rank

    # If warm-start is enabled, reuse Qs from the previous iteration if possible and skip filling random values.
    # The exception is the first iteration when PowerSGD is applied.
    if not need_randomize_qs:
        for q in qs:
            _orthogonalize(q, state.orthogonalization_epsilon)
    else:
        with torch.random.fork_rng(devices=[]):
            # Fork this RNG to avoid changing the seed globally and affecting the random sampling anywhere else in the training.
            # The seed makes sure that the initial random values are the same across all the DDP replicas.
            # This seed should differ at every step.
            # Since it is very slow to fork RNG state across all the CUDA devices,
            # only fork on CPU and then move the generated tensor to the CUDA device (by overwriting q).
            torch.manual_seed(state.rng.randint(1_000_000_000))
            for q in qs:
                q.copy_(
                    torch.randn(
                        *q.shape,
                        device="cpu",
                        dtype=dtype,
                    )
                )
                _orthogonalize(q, state.orthogonalization_epsilon)

    # Compute Ps.
    for tensor, q, p in zip(tensors_to_compress, qs, ps):
        torch.bmm(tensor, q, out=p)

    # This allreduce is only applied to uncompressed tensors,
    # so it should have been kicked off before the above computation on the compressed tensors to hide more communication costs.
    # However, this somehow requires a separate future chain at this time.
    allreduce_contiguous_uncompressed_tensors_fut = dist.all_reduce(
        uncompressed_tensors_memory, group=group_to_use, async_op=True
    ).get_future()

    def unpack_uncompressed_tensors_and_allreduce_ps(fut):
        # Average the allreduced sum, then scatter it back into the per-parameter tensors.
        uncompressed_tensors_memory = fut.value()[0].div_(world_size)
        idx = 0
        for tensor in uncompressed_tensors:
            tensor.copy_(
                uncompressed_tensors_memory[idx : idx + tensor.numel()].view_as(tensor)
            )
            idx += tensor.numel()

        # Since these Ps will be orthogonalized later, no need to divide them by world size.
        return (
            dist.all_reduce(
                state.p_memory_dict[bucket_index], group=group_to_use, async_op=True
            )
            .get_future()
            .wait()[0]
        )

    def compute_qs(fut):
        state.p_memory_dict[bucket_index] = fut.value()
        for p in ps:
            _orthogonalize(p, state.orthogonalization_epsilon)

        # Compute Qs.
        for tensor, p, q in zip(tensors_to_compress, ps, qs):
            torch.bmm(tensor.transpose(1, 2), p, out=q)

        # TODO: The above procedure does two matmul+allreduce steps per iteration --
        # one left multiplication and one right multiplication.
        # For warm-start, can take one such step at a time, and alternate between them.

        # Allreduce Qs.
        return (
            dist.all_reduce(
                state.q_memory_dict[bucket_index], group=group_to_use, async_op=True
            )
            .get_future()
            .wait()[0]
        )

    def decompress(fut):
        state.q_memory_dict[bucket_index] = fut.value().div_(world_size)

        for p, q, tensor in zip(ps, qs, tensors_to_compress):
            torch.bmm(p, q.transpose(1, 2), out=tensor)

        # Copy batched tensors back to original buffer.
        if state.batch_tensors_with_same_shape:
            for tensor in tensors_to_compress:
                if tensor.shape[0] == 1:
                    # Skip tensor with batch_size == 1 since itself is the original tensor.
                    continue
                original_tensors = shape_to_tensors[tensor.shape[1:]]
                for i, original_tensor in enumerate(original_tensors):
                    original_tensor.copy_(tensor[i])

        # Synchronize only when the bucket actually lives on a CUDA device.
        # Checking `torch.cuda.is_available()` instead would also try to synchronize a CPU
        # bucket on a host that merely has GPUs, and `torch.cuda.synchronize` rejects a
        # non-CUDA device.
        if device.type == "cuda":
            torch.cuda.synchronize(device)

        if state.use_error_feedback:
            # Memorize the local errors.
            state.error_dict[bucket_index] = input_tensor_cp - input_tensor
        if not state.warm_start:
            state.p_memory_dict.clear()
            state.q_memory_dict.clear()

        state.maybe_increase_iter(bucket)

        return input_tensor

    return (
        allreduce_contiguous_uncompressed_tensors_fut.then(
            unpack_uncompressed_tensors_and_allreduce_ps
        )
        .then(compute_qs)
        .then(decompress)
    )
648
649
def batched_powerSGD_hook(
    state: PowerSGDState, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    r"""
    Implement simplified PowerSGD algorithm.

    This DDP communication hook implements a simplified PowerSGD gradient compression
    algorithm described in the `paper <https://arxiv.org/abs/1905.13727>`_.
    This variant does not compress the gradients layer by layer,
    but instead compresses the flattened input tensor that batches all the gradients.
    Therefore, it is **faster** than :meth:`powerSGD_hook`,
    but usually results in a **much lower accuracy**, unless ``matrix_approximation_rank`` is 1.

    .. warning ::
        Increasing ``matrix_approximation_rank`` here may not necessarily increase the accuracy,
        because batching per-parameter tensors without column/row alignment can destroy low-rank structure.
        Therefore, the user should always consider :meth:`powerSGD_hook` first,
        and only consider this variant when a satisfactory accuracy can be achieved when ``matrix_approximation_rank`` is 1.

    Once gradient tensors are aggregated across all workers, this hook applies
    compression as follows:

    1. Views the input flattened 1D gradient tensor as a square-shaped tensor M with 0 paddings;

    2. Creates two low-rank tensors P and Q for decomposing M, such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;

    3. Computes P, which is equal to MQ;

    4. Allreduces P;

    5. Orthogonalizes P;

    6. Computes Q, which is approximately equal to M^TP;

    7. Allreduces Q;

    8. Computes M, which is approximately equal to PQ^T.

    9. Truncates the input tensor to the original length.

    Note that this communication hook enforces vanilla allreduce for the first ``state.start_powerSGD_iter`` iterations.
    This not only gives the user more control over the tradeoff between speedup and accuracy,
    but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers.

    Args:
        state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
            To tune the compression configs, mainly need to tune ``matrix_approximation_rank`` and ``start_powerSGD_iter``.
        bucket (dist.GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
            Note that since DDP comm hook only supports single process single device mode,
            only exactly one tensor is stored in this bucket.

    Returns:
        Future handler of the communication, which updates the gradients in place.

    Example::
        >>> # xdoctest: +SKIP
        >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
        >>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)
    """  # noqa: B950
    process_group = state.process_group
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    world_size = group_to_use.size()

    # The input tensor is a flattened 1D tensor.
    input_tensor = bucket.buffer()

    # Run vanilla allreduce in the first `start_powerSGD_iter` iterations.
    if state.iter < state.start_powerSGD_iter:
        state.maybe_increase_iter(bucket)
        return default._allreduce_fut(group_to_use, input_tensor)

    # Apply PowerSGD after `start_powerSGD_iter` iterations.
    device = input_tensor.device
    total_length = input_tensor.shape[0]
    state.total_numel_before_compression += total_length

    # View the input tensor as a 2D square-shape tensor, and pad 0s if necessary.
    square_side_length = math.ceil(math.sqrt(total_length))
    # P and Q each hold `square_side_length * matrix_approximation_rank` elements.
    state.total_numel_after_compression += (
        square_side_length * state.matrix_approximation_rank * 2
    )
    padded_total_length = square_side_length**2
    # Grow the bucket buffer in place to the padded length and zero the padding;
    # `decompress` shrinks it back to `total_length` before returning.
    input_tensor.resize_(padded_total_length)
    input_tensor[total_length:padded_total_length].fill_(0)

    _report_compression_stats(bucket, state)

    # Incorporate the error from the previous state into the gradients.
    bucket_index = bucket.index()
    input_tensor_cp = None
    if state.use_error_feedback:
        if bucket_index in state.error_dict:
            input_tensor.add_(state.error_dict[bucket_index])
        else:
            logger.info(
                "A zero tensor of length %s that represents local error is created.",
                padded_total_length,
            )
            state.error_dict[bucket_index] = torch.zeros(
                padded_total_length, device=device, dtype=input_tensor.dtype
            )

        # Keep a copy of the input tensor,
        # so that we can compute the local error caused by compression later,
        # by comparing this copy and the input tensor updated after decompression.
        input_tensor_cp = torch.clone(input_tensor).detach()
    # `matrix` is a view of `input_tensor`, so writing into it updates the bucket buffer.
    matrix = input_tensor.view(square_side_length, square_side_length)

    # Reuse P and Q from the previous iteration if possible.
    # The memory spaces of P and Q need to be allocated in the first iteration when PowerSGD is applied.
    if not state.warm_start or bucket_index not in state.p_memory_dict:
        # If warm-start is disabled, low-rank tensors will be initialized at every step.
        # Only log this if warm-start to avoid spamming.
        if state.warm_start:
            logger.info(
                "Initializing low-rank tensors P and Q, each of which has a shape of %s x %s.",
                square_side_length,
                state.matrix_approximation_rank,
            )

        def create_low_rank_tensor(fill_random_values, rng):
            """Return a low-rank 2D tensor of square_side_length * matrix_approximation_rank."""
            if fill_random_values:
                with torch.random.fork_rng(devices=[]):
                    # Fork this RNG to avoid changing the seed globally and affecting the random sampling
                    # anywhere else in the training.
                    # The seed makes sure that the initial random values are the same across all the DDP replicas.
                    # This seed should differ at every step.
                    # Since it is very slow to fork RNG state across all the CUDA devices,
                    # only fork on CPU and then move the generated tensor to the CUDA device.
                    torch.manual_seed(rng.randint(1_000_000_000))
                    return torch.randn(
                        square_side_length,
                        state.matrix_approximation_rank,
                        device="cpu",
                        dtype=input_tensor.dtype,
                    ).to(device)
            else:
                return torch.empty(
                    square_side_length,
                    state.matrix_approximation_rank,
                    device=device,
                    dtype=input_tensor.dtype,
                )

        # P is fully overwritten by the matmul below, so it does not need random values.
        state.p_memory_dict[bucket_index] = create_low_rank_tensor(
            fill_random_values=False, rng=state.rng
        )
        state.q_memory_dict[bucket_index] = create_low_rank_tensor(
            fill_random_values=True, rng=state.rng
        )
    # `_orthogonalize` requires a 3D batch of matrices (it asserts on the rank of its input),
    # whereas P and Q here are single 2D matrices. Pass a batch-of-one view: `unsqueeze(0)`
    # shares storage with the stored tensor, so the in-place orthogonalization still lands in it.
    _orthogonalize(state.q_memory_dict[bucket_index].unsqueeze(0))

    torch.matmul(
        matrix, state.q_memory_dict[bucket_index], out=state.p_memory_dict[bucket_index]
    )
    allreduce_p_fut = dist.all_reduce(
        state.p_memory_dict[bucket_index], group=group_to_use, async_op=True
    ).get_future()

    def compute_q(fut):
        """Orthogonalize the allreduced P, compute Q = M^T P, and allreduce Q."""
        # The allreduce future holds a list of tensors; the first entry is P.
        state.p_memory_dict[bucket_index] = fut.value()[0]
        # Same batch-of-one view as for Q above: `_orthogonalize` only accepts 3D input.
        _orthogonalize(state.p_memory_dict[bucket_index].unsqueeze(0))

        torch.matmul(
            matrix.t(),
            state.p_memory_dict[bucket_index],
            out=state.q_memory_dict[bucket_index],
        )

        # TODO: The above procedure does two matmul+allreduce steps per iteration --
        # one left multiplication and one right multiplication.
        # For warm-start, can take one such step at a time, and alternate between them.

        return (
            dist.all_reduce(
                state.q_memory_dict[bucket_index], group=group_to_use, async_op=True
            )
            .get_future()
            .wait()[0]
        )

    def decompress(fut):
        """Average the allreduced Q, rebuild M = P Q^T in place, and return the truncated buffer."""
        # Only Q is divided by the world size; P is used as orthogonalized in `compute_q`.
        state.q_memory_dict[bucket_index] = fut.value().div_(world_size)
        torch.matmul(
            state.p_memory_dict[bucket_index],
            state.q_memory_dict[bucket_index].t(),
            out=matrix,
        )

        if state.use_error_feedback:
            # Memorize the local errors.
            state.error_dict[bucket_index] = input_tensor_cp - input_tensor
        # Removing this seemingly unnecessary sync somehow may cause failures.
        # See: https://github.com/pytorch/pytorch/pull/54838
        if torch.cuda.is_available():
            torch.cuda.synchronize(device)
        if not state.warm_start:
            state.p_memory_dict.clear()
            state.q_memory_dict.clear()
        # Drop the zero padding so the buffer has its original length again.
        ret = input_tensor.resize_(total_length)

        state.maybe_increase_iter(bucket)

        return ret

    return allreduce_p_fut.then(compute_q).then(decompress)
857