# mypy: allow-untyped-defs
import logging
import math
from collections import defaultdict
from typing import Dict

import torch
import torch.distributed as dist
from torch.distributed import distributed_c10d

from . import default_hooks as default


__all__ = ["PowerSGDState", "powerSGD_hook", "batched_powerSGD_hook"]

logger = logging.getLogger(__name__)


def _orthogonalize(matrices, epsilon=0):
    """
    Decide between Gram-Schmidt or QR factorization to orthogonalize a batch of matrices.

    QR factorization doesn't work with half-precision, but it is usually faster with a rank > 2.

    ``matrices`` is orthogonalized in place. It is normally a 3D batch of shape
    ``(batch, n, rank)`` with ``rank <= n``. A single 2D matrix of shape ``(n, rank)``,
    as used by :func:`batched_powerSGD_hook`, is also accepted and treated as a batch of one.
    """
    if matrices.dim() == 2:
        # ``unsqueeze`` returns a view that shares storage with ``matrices``,
        # so orthogonalizing the view in place also updates the 2D input.
        _orthogonalize(matrices.unsqueeze(0), epsilon=epsilon)
        return

    assert len(matrices.shape) == 3 and matrices.shape[2] <= matrices.shape[1]

    num_matrices = matrices.shape[0]
    rank = matrices.shape[2]
    dtype = matrices.dtype
    if rank <= 2 or dtype in [torch.float16, torch.bfloat16]:
        _orthogonalize_gram_schmidt(matrices, epsilon=epsilon)
    else:
        # Write the Q factor straight back into ``matrices``; R goes to a scratch buffer.
        torch.linalg.qr(
            matrices,
            out=(
                matrices,
                torch.empty(
                    num_matrices, rank, rank, device=matrices.device, dtype=dtype
                ),
            ),
        )


def _orthogonalize_gram_schmidt(matrices, epsilon=0):
    """
    Apply Gram-Schmidt procedure to orthogonalize a batch of matrices in place.

    ``matrices`` has shape ``(batch, n, num_cols)``; each column is normalized and then
    projected out of the remaining columns. If epsilon is 0, the result is the Q factor of a
    QR decomposition (column signs may differ from ``torch.linalg.qr``).

    An all-0s column is left as all 0s, instead of turning into NaNs.
    """
    num_cols = matrices.shape[2]
    for i in range(num_cols):
        # Normalize the i'th column.
        col = matrices[:, :, i : i + 1]
        # If no epsilon is added here, division by zero may be caused by vanishing gradients.
        # This epsilon is not needed if the input batch of matrices covers the gradients of at least one entire layer
        # in the neural network.
        if epsilon == 0:
            # Note that col ** 2 can underflow/overflow if we use FP16.
            # May need to consider multiplying a scaling factor and dividing it later, or using bfloat16 instead.
            norm = torch.norm(col, dim=1, keepdim=True)
            # Tensor division never raises ``ZeroDivisionError``: 0 / 0 silently yields NaN,
            # which would then spread to every later column through the projection below.
            # Divide all-0s columns by 1 instead so they stay all 0s, without forcing a
            # device-to-host sync. Setting a small ``orthogonalization_epsilon`` (e.g. 1e-8)
            # in PowerSGD state is an alternative.
            col /= torch.where(norm == 0, torch.ones_like(norm), norm)
        else:
            col /= torch.norm(col, dim=1, keepdim=True) + epsilon
        # Project it on the rest and remove it.
        if i + 1 < num_cols:
            rest = matrices[:, :, i + 1 :]
            rest -= torch.sum(col * rest, dim=1, keepdim=True) * col


def _should_compress(
    num_rows, num_cols, matrix_approximation_rank, min_compression_rate
):
    """
    Recommend if tensor given is worth compressing.

    Returns a recommendation as to whether the 2D tensor described by the arguments is worth compressing,
    including statistics describing the expected savings from compression. We consider a tensor worth
    compressing when ``min_compression_rate`` < uncompressed size / compressed size, where
    uncompressed size = ``num_rows`` * ``num_cols``,
    and compressed size = (``num_rows`` + ``num_cols``) * ``matrix_approximation_rank``.

    The result of this function is a tuple of the form (compression_recommendation, uncompressed_el_count, compressed_el_count), where:

    compression_recommendation is true if the tensor is worth compressing, and false otherwise (see above);

    uncompressed_el_count is the uncompressed element count, i.e. ``num_rows`` * ``num_cols``; and,

    compressed_el_count is the element count after compression, i.e. (``num_rows`` + ``num_cols``) * ``matrix_approximation_rank``.
    """  # noqa: B950
    uncompressed_size = num_rows * num_cols
    compressed_size = (num_rows + num_cols) * matrix_approximation_rank
    return (
        compressed_size * min_compression_rate < uncompressed_size,
        uncompressed_size,
        compressed_size,
    )


def _report_compression_stats(bucket, state):
    """Report compression stats at frequency of ``compression_stats_logging_frequency`` specified in PowerSGD state."""
    if bucket.is_last() and state.iter >= state.next_stats_report:
        stats = state.compression_stats()
        logger.info(
            "Compression stats: iter %s, total before compression %s, total after compression %s, "
            "rate %s",
            state.iter,
            stats[1],
            stats[2],
            stats[0],
        )
        state.next_stats_report = state.iter + state.compression_stats_logging_frequency


class PowerSGDState:
    r"""
    Store both the algorithm's hyperparameters and internal state for all gradients during training.

    Particularly, ``matrix_approximation_rank`` and ``start_powerSGD_iter`` are the main hyperparameters that should be tuned by the user.
    For performance, we suggest to keep binary hyperparameters ``use_error_feedback`` and ``warm_start`` on.

    1. ``matrix_approximation_rank`` controls the size of compressed low-rank tensors, which determines the compression rate. The lower the rank, the stronger the compression.

        1.1. If ``matrix_approximation_rank`` is too low, the full model quality will need more training steps to reach or will never reach and yield loss in accuracy.

        1.2. The increase of ``matrix_approximation_rank`` can substantially increase the computation costs of the compression, and the accuracy may not be further improved beyond a certain ``matrix_approximation_rank`` threshold.

    To tune ``matrix_approximation_rank``, we suggest to start from 1 and increase by factors of 2 (like an exponential grid search, 1, 2, 4, ...), until a satisfactory accuracy is reached. Typically only a small value 1-4 is used. For some NLP tasks (as shown in Appendix D of the original paper), this value has been increased to 32.

    2. ``start_powerSGD_iter`` defers PowerSGD compression until step ``start_powerSGD_iter``, and vanilla allreduce runs prior to step ``start_powerSGD_iter``. This hybrid scheme of **vanilla allreduce + PowerSGD** can effectively improve the accuracy, even if a relatively small ``matrix_approximation_rank`` is used. This is because the beginning of the training phase is usually very sensitive to inaccurate gradients, and compressing gradients too early may make the training quickly take a suboptimal trajectory, which can result in an irrecoverable impact on the accuracy.

    To tune ``start_powerSGD_iter``, we suggest to start with 10% of total training steps, and increase it until a satisfactory accuracy is reached. If there is a warm-up stage in the training, ``start_powerSGD_iter`` typically should be no less than the number of warm-up steps.

    3. ``min_compression_rate`` is the minimum compression rate required when a layer is compressed. Due to the computation overheads incurred by the compression, a tensor is worth compressing only if there can be sufficient saving in bandwidth, where ``(num_rows + num_cols) * matrix_approximation_rank * min_compression_rate < num_rows * num_cols``. If the specified compression rate threshold cannot be satisfied, the tensor will be directly allreduced without compression.

    Compression statistics are logged every ``compression_stats_logging_frequency`` iterations once PowerSGD compression starts.

    4. ``orthogonalization_epsilon`` can be a very small value (e.g., 1e-8) added to every normalized matrix column in orthogonalization step, to prevent div-by-zero error if any column has all 0s. If this can already be prevented (e.g., by batch normalization), an epsilon of 0 is recommended for accuracy.

    5. ``batch_tensors_with_same_shape`` controls whether to compress and decompress tensors with same shape in a batched operation to achieve higher parallelism. Note that you should also increase the bucket size (i.e., ``bucket_cap_mb`` arg in DDP constructor) to make more same-shaped tensors appear in the same bucket, however this may reduce the overlap between computation and communication, and increase the memory footprint due to stacking the tensors of the same shape. Set to ``True`` if the compression / decompression computation is a bottleneck.

    .. warning ::
        If error feedback or warm-up is enabled, the minimum value of ``start_powerSGD_iter`` allowed in DDP is 2.
        This is because there is another internal optimization that rebuilds buckets at iteration 1 in DDP,
        and this can conflict with any tensor memorized before the rebuild process.
    """  # noqa: B950

    __slots__ = [
        "process_group",
        # The fields below are the hyperparameters that often need to be tuned by the user.
        "matrix_approximation_rank",
        "start_powerSGD_iter",
        # The fields below are the hyperparameters that seldom need be tuned by the user.
        "min_compression_rate",
        "orthogonalization_epsilon",
        # The fields below are the binary hyperparameters recommended to be turned on for performance and accuracy.
        "use_error_feedback",
        "warm_start",
        "batch_tensors_with_same_shape",
        # The fields below are internal state.
        "rng",
        "error_dict",
        "p_memory_dict",
        "q_memory_dict",
        "iter",
        # The fields below are for recording compression stats.
        "total_numel_before_compression",
        "total_numel_after_compression",
        "compression_stats_logging_frequency",
        "next_stats_report",
    ]

    def __init__(
        self,
        process_group,
        matrix_approximation_rank=1,
        start_powerSGD_iter=1_000,
        min_compression_rate=2,
        use_error_feedback=True,
        warm_start=True,
        orthogonalization_epsilon=0,
        random_seed=0,
        compression_stats_logging_frequency=10_000,
        batch_tensors_with_same_shape: bool = False,
    ):
        logger.info(
            "PowerSGD config: matrix_approximation_rank = %s; start_powerSGD_iter = %s; "
            "min_compression_rate = %s; orthogonalization_epsilon = %s; use_error_feedback = %s; warm_start = %s; "
            "random_seed = %s; compression_stats_logging_frequency = %s; batch_tensors_with_same_shape = %s",
            matrix_approximation_rank,
            start_powerSGD_iter,
            min_compression_rate,
            orthogonalization_epsilon,
            use_error_feedback,
            warm_start,
            random_seed,
            compression_stats_logging_frequency,
            batch_tensors_with_same_shape,
        )

        self.process_group = process_group
        self.matrix_approximation_rank = matrix_approximation_rank
        # Deferring PowerSGD compression until step 'start_powerSGD_iter' can have two advantages:
        # 1) It turns out that PowerSGD may lead to a non-trivial accuracy loss,
        # even if the matrix approximation rank is increased to a large value.
        # To mitigate the accuracy loss, a simple yet effective way is mixing vanilla allreduce
        # (or a more conservative compression such as FP16 compression) with PowerSGD.
        # 2) There is an internal optimization of rebuilding buckets process in DDP,
        # in order to save the memory space.
        # This step takes place after the first iteration.
        # However, this means that the shape of input bucketized tensors is subject to change,
        # which will complicate the implementations of error feedback and warm-up.
        # Running vanilla allreduce in the first few iterations can avoid this complexity.
        if (use_error_feedback or warm_start) and start_powerSGD_iter <= 1:
            raise ValueError(
                "Expect `start_powerSGD_iter` > 1 if `use_error_feedback` or `warm_start` is enabled, "
                "because PowerSGD can only be applied after the first two iterations in DDP."
            )
        self.start_powerSGD_iter = start_powerSGD_iter
        self.min_compression_rate = min_compression_rate
        # Error feedback is usually crucial for both convergence and generalization,
        # because PowerSGD is a biased compressor,
        # i.e., compressing and decompressing a random gradient does not yield the original in expectation.
        # This mechanism requires a temporary copy of the input gradients,
        # so it increases the peak memory consumption by the size of the gradient tensor.
        # However, if the target matrices are known to be exactly low-ranked (instead of just low stable rank),
        # sometimes it is possible to converge to the optima without error feedback.
        # See: http://proceedings.mlr.press/v54/yurtsever17a/yurtsever17a.pdf
        self.use_error_feedback = use_error_feedback
        # Warm-start reuses P(s) and Q(s) from the previous iteration.
        # This can improve the approximation quality and hence improve the accuracy.
        # Additionally, by avoiding the initialization of these low-rank tensors at every step,
        # this can also accelerate training.
        # However, this is at the cost of extra memory.
        self.warm_start = warm_start
        # Can use a very small value to prevent div-by-zero error caused by orthogonalization of vanishing gradients.
        self.orthogonalization_epsilon = orthogonalization_epsilon
        # The purpose of this RNG is to generate different random seeds for initializing Q across iterations,
        # but in the same order for all the DDP replicas.
        # Different random seeds across iterations indicate different 'projections' of the gradients at different SGD steps.
        # If the same random projection is used,
        # there will be differences between the gradients that are never synchronized.
        import numpy as np

        self.rng = np.random.RandomState(random_seed)
        # Since there is only a single state instance for all the input buckets,
        # need to maintain a dictionary that maps each bucket index to the local error.
        self.error_dict: Dict[int, torch.Tensor] = {}
        self.p_memory_dict: Dict[int, torch.Tensor] = {}
        self.q_memory_dict: Dict[int, torch.Tensor] = {}
        # Iteration/step in the training loop.
        self.iter = 0
        # Compression stats accumulators
        self.total_numel_before_compression = 0
        self.total_numel_after_compression = 0
        # We'll report compression stats every 'compression_stats_logging_frequency' iterations
        # Note that we always report compression stats at least once.
        self.compression_stats_logging_frequency = max(
            1, compression_stats_logging_frequency
        )
        self.next_stats_report = 0
        # Batching tensors with same shape can increase parallelism in compression / decompression computation.
        # This requires a larger bucket size to make more same-shaped tensor to appear in one bucket, however
        # this may reduce the overlap between computation and communication, and increase the memory footprint
        # due to stacking tensors.
        # Turn on if compression / decompression computation is a bottleneck.
        self.batch_tensors_with_same_shape = batch_tensors_with_same_shape

    def __getstate__(self):
        r"""
        Return a ``Dict[str, Any]`` which will be pickled and saved.

        ``process_group`` is not serializable and excluded from
        a returned state.
        """
        logger.warning(
            "NOTE: Process group is not serializable and excluded from a saved state."
        )
        return {
            slot: getattr(self, slot)
            for slot in self.__slots__
            if slot != "process_group"
        }

    def __setstate__(self, state):
        r"""
        Take a provided ``state`` and set to this ``PowerSGDState`` instance.

        ``process_group`` is set to default.
        """
        self.process_group = distributed_c10d._get_default_group()
        # Implicit concatenation (rather than a backslash continuation inside the literal)
        # keeps the source indentation out of the logged message.
        logger.warning(
            "NOTE: Process group will be set to a default group (i.e. the world size). "
            "If a different group is desired, please set `self.process_group` after PowerSGD state is loaded."
        )
        for slot, value in state.items():
            setattr(self, slot, value)

    def maybe_increase_iter(self, bucket):
        """Track iterations and log a message when PowerSGD compression starts."""
        # Since bucket 0 is the last bucket to allreduce in an iteration,
        # only increase `iter` when bucket 0 is processed.
        if bucket.is_last():
            self.iter += 1

        if self.iter == self.start_powerSGD_iter:
            logger.info("Start to apply PowerSGD after %s iterations.", self.iter)

    def compression_stats(self):
        r"""
        Return latest compression statistics as tuple.

        Returns tuple of form (compress_rate, numel_before_compression, numel_after_compression) where:

        compress_rate is the effective compression rate i.e. (number of elements before compression) / (number of elements after compression);

        numel_before_compression is the total number of elements before compression was applied; and,

        numel_after_compression is the total number of elements after compression was applied.
        """  # noqa: B950
        compress_rate = (
            self.total_numel_before_compression / self.total_numel_after_compression
            if self.total_numel_after_compression > 0
            else 0
        )
        return (
            compress_rate,
            self.total_numel_before_compression,
            self.total_numel_after_compression,
        )


def powerSGD_hook(
    state: PowerSGDState, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    r"""
    Implement PowerSGD algorithm.

    This DDP communication hook implements PowerSGD gradient compression
    algorithm described in the `paper <https://arxiv.org/abs/1905.13727>`_.
    Once gradient tensors are aggregated across all workers, this hook applies
    compression as follows:

    1. Views the input flattened 1D gradient tensor as a list of per-parameter tensors, and divides all the tensors into two groups:

        1.1 The tensors that should be compressed before allreduce, because the compression can give enough saving in bandwidth.

        1.2 Rest of the tensors will be directly allreduced without compression, including all the vector tensors (for biases).

    2. Handles uncompressed tensors:

        2.1. Allocate contiguous memory for those uncompressed tensors, and allreduces all the uncompressed tensors as a batch, without compression;

        2.2. Copies the individual uncompressed tensors from the contiguous memory back to the input tensor.

    3. Handles the tensors that should be compressed by PowerSGD compression:

        3.1. For each tensor M, creates two low-rank tensors P and Q for decomposing M,
        such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;

        3.2. Computes each P in Ps, which is equal to MQ;

        3.3. Allreduces Ps as a batch;

        3.4. Orthogonalizes each P in Ps;

        3.5. Computes each Q in Qs, which is approximately equal to M^TP;

        3.6. Allreduces Qs as a batch;

        3.7. Computes each M among all the compressed tensors, which is approximately equal to PQ^T.

    Note that this communication hook enforces vanilla allreduce for the first ``state.start_powerSGD_iter`` iterations.
    This not only gives the user more control over the tradeoff between speedup and accuracy,
    but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers.

    Args:
        state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
            To tune the compression configs, mainly need to tune ``matrix_approximation_rank``, ``start_powerSGD_iter``
            and ``min_compression_rate``.
        bucket (dist.GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
            Note that since DDP comm hook only supports single process single device mode,
            only exactly one tensor is stored in this bucket.

    Returns:
        Future handler of the communication, which updates the gradients in place.

    Example::
        >>> # xdoctest: +SKIP
        >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1,
                                  start_powerSGD_iter=10, min_compression_rate=0.5)
        >>> ddp_model.register_comm_hook(state, powerSGD_hook)
    """  # noqa: B950
    process_group = state.process_group
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    world_size = group_to_use.size()

    # The input tensor is a flattened 1D tensor.
    input_tensor = bucket.buffer()

    # Run vanilla allreduce in the first `start_powerSGD_iter` iterations.
    if state.iter < state.start_powerSGD_iter:
        state.maybe_increase_iter(bucket)
        return default._allreduce_fut(group_to_use, input_tensor)

    # Apply PowerSGD after `start_powerSGD_iter` iterations.
    device = input_tensor.device
    dtype = input_tensor.dtype

    # Incorporate the error from the previous state into the gradients.
    bucket_index = bucket.index()
    input_tensor_cp = None
    total_length = input_tensor.shape[0]
    if state.use_error_feedback:
        if bucket_index in state.error_dict:
            input_tensor.add_(state.error_dict[bucket_index])
        else:
            logger.info(
                "A zero tensor of length %s that represents local error is created.",
                total_length,
            )
            state.error_dict[bucket_index] = torch.zeros(
                total_length, device=device, dtype=dtype
            )

        # Keep a copy of the input tensor,
        # so that we can compute the local error caused by compression later,
        # by comparing this copy and the input tensor updated after decompression.
        input_tensor_cp = torch.clone(input_tensor).detach()

    # Unflatten the input tensor into per-parameter tensors, for layer-wise compression.
    tensors = bucket.gradients()

    # Step I: Divide all the tensors into two groups,
    # one will be compressed before allreduce and the other will be directly allreduced without compression.
    tensors_to_compress, uncompressed_tensors = [], []
    total_Ps_size = 0
    total_Qs_size = 0
    for tensor in tensors:
        matrix = tensor.view(tensor.shape[0], -1)
        n, m = matrix.shape
        matrix_approximation_rank = min(n, m, state.matrix_approximation_rank)
        compress_test = _should_compress(
            n, m, matrix_approximation_rank, state.min_compression_rate
        )
        state.total_numel_before_compression += compress_test[1]
        if compress_test[0]:
            tensors_to_compress.append(matrix)
            total_Ps_size += n * matrix_approximation_rank
            total_Qs_size += m * matrix_approximation_rank
            state.total_numel_after_compression += compress_test[2]
        else:
            uncompressed_tensors.append(tensor)
            state.total_numel_after_compression += compress_test[1]

    _report_compression_stats(bucket, state)

    # Step II: Handle uncompressed tensors.
    # Allocate contiguous memory for these tensors to allreduce efficiently.
    uncompressed_tensors_memory = (
        torch.cat([tensor.view(-1) for tensor in uncompressed_tensors])
        if uncompressed_tensors
        else torch.tensor([], device=device, dtype=dtype)
    )

    # Step III: Handle the tensors that should be compressed.
    # Allocate contiguous memory for Ps and Qs to allreduce efficiently.
    # If warm-start is enabled, reuse Ps and Qs from the previous iteration if possible.
    # The memory spaces of Ps and Qs need to be allocated in the first iteration when PowerSGD is applied.
    need_randomize_qs = False
    if not state.warm_start or bucket_index not in state.p_memory_dict:
        need_randomize_qs = True
        # If warm-start is disabled, low-rank tensors will be initialized at every step.
        # Only log this if warm-start to avoid spamming.
        if state.warm_start:
            logger.info(
                "Allocating contiguous memory of length %s for Ps, and of length %s for Qs, respectively.",
                total_Ps_size,
                total_Qs_size,
            )
        state.p_memory_dict[bucket_index] = torch.empty(
            total_Ps_size, device=device, dtype=dtype
        )
        state.q_memory_dict[bucket_index] = torch.empty(
            total_Qs_size, device=device, dtype=dtype
        )

    # Batch tensors to compress by shape.
    shape_to_tensors = defaultdict(list)
    for tensor in tensors_to_compress:
        shape_to_tensors[tensor.shape].append(tensor)

    # This function decides whether to batch tensors with same shape or not according to the argument,
    # so the following process could share the same code.
    def maybe_batched_tensors_to_compress():
        for tensors in shape_to_tensors.values():
            if state.batch_tensors_with_same_shape:
                batch_size = len(tensors)
                if batch_size == 1:
                    # Use the original tensor to avoid copy.
                    yield tensors[0].unsqueeze(0)
                else:
                    yield torch.stack(tensors)
            else:
                for tensor in tensors:
                    yield tensor.unsqueeze(0)

    # Create Ps and Qs that point to the allocated memory.
    tensors_to_compress = []
    ps = []
    qs = []
    p_idx = 0
    q_idx = 0
    for tensor in maybe_batched_tensors_to_compress():
        batch_size, n, m = tensor.shape
        matrix_approximation_rank = min(n, m, state.matrix_approximation_rank)
        tensors_to_compress.append(tensor)
        ps.append(
            state.p_memory_dict[bucket_index][
                p_idx : p_idx + batch_size * n * matrix_approximation_rank
            ].view(batch_size, n, matrix_approximation_rank)
        )
        qs.append(
            state.q_memory_dict[bucket_index][
                q_idx : q_idx + batch_size * m * matrix_approximation_rank
            ].view(batch_size, m, matrix_approximation_rank)
        )
        p_idx += batch_size * n * matrix_approximation_rank
        q_idx += batch_size * m * matrix_approximation_rank

    # If warm-start is enabled, reuse Qs from the previous iteration if possible and skip filling random values.
    # The exception is the first iteration when PowerSGD is applied.
    if not need_randomize_qs:
        for q in qs:
            _orthogonalize(q, state.orthogonalization_epsilon)
    else:
        with torch.random.fork_rng(devices=[]):
            # Fork this RNG to avoid changing the seed globally and affecting the random sampling anywhere else in the training.
            # The seed makes sure that the initial random values are the same across all the DDP replicas.
            # This seed should differ at every step.
            # Since it is very slow to fork RNG state across all the CUDA devices,
            # only fork on CPU and then move the generated tensor to the CUDA device (by overwriting q).
            torch.manual_seed(state.rng.randint(1_000_000_000))
            for q in qs:
                q.copy_(
                    torch.randn(
                        *q.shape,
                        device="cpu",
                        dtype=dtype,
                    )
                )
                _orthogonalize(q, state.orthogonalization_epsilon)

    # Compute Ps.
    for tensor, q, p in zip(tensors_to_compress, qs, ps):
        torch.bmm(tensor, q, out=p)

    # This allreduce is only applied to uncompressed tensors,
    # so it should have been kicked off before the above computation on the compressed tensors to hide more communication costs.
    # However, this somehow requires a separate future chain at this time.
    allreduce_contiguous_uncompressed_tensors_fut = dist.all_reduce(
        uncompressed_tensors_memory, group=group_to_use, async_op=True
    ).get_future()

    def unpack_uncompressed_tensors_and_allreduce_ps(fut):
        uncompressed_tensors_memory = fut.value()[0].div_(world_size)
        idx = 0
        for tensor in uncompressed_tensors:
            tensor.copy_(
                uncompressed_tensors_memory[idx : idx + tensor.numel()].view_as(tensor)
            )
            idx += tensor.numel()

        # Since these Ps will be orthogonalized later, no need to divide them by world size.
        return (
            dist.all_reduce(
                state.p_memory_dict[bucket_index], group=group_to_use, async_op=True
            )
            .get_future()
            .wait()[0]
        )

    def compute_qs(fut):
        state.p_memory_dict[bucket_index] = fut.value()
        for p in ps:
            _orthogonalize(p, state.orthogonalization_epsilon)

        # Compute Qs.
        for tensor, p, q in zip(tensors_to_compress, ps, qs):
            torch.bmm(tensor.transpose(1, 2), p, out=q)

        # TODO: The above procedure does two matmul+allreduce steps per iteration --
        # one left multiplication and one right multiplication.
        # For warm-start, can take one such step at a time, and alternate between them.

        # Allreduce Qs.
        return (
            dist.all_reduce(
                state.q_memory_dict[bucket_index], group=group_to_use, async_op=True
            )
            .get_future()
            .wait()[0]
        )

    def decompress(fut):
        state.q_memory_dict[bucket_index] = fut.value().div_(world_size)

        for p, q, tensor in zip(ps, qs, tensors_to_compress):
            torch.bmm(p, q.transpose(1, 2), out=tensor)

        # Copy batched tensors back to original buffer.
        if state.batch_tensors_with_same_shape:
            for tensor in tensors_to_compress:
                if tensor.shape[0] == 1:
                    # Skip tensor with batch_size == 1 since itself is the original tensor.
                    continue
                original_tensors = shape_to_tensors[tensor.shape[1:]]
                for i, original_tensor in enumerate(original_tensors):
                    original_tensor.copy_(tensor[i])

        # Only a CUDA device can be synchronized: ``torch.cuda.synchronize`` raises
        # ``ValueError`` for a CPU device, even on a machine where CUDA is available.
        if device.type == "cuda":
            torch.cuda.synchronize(device)

        if state.use_error_feedback:
            # Memorize the local errors.
            state.error_dict[bucket_index] = input_tensor_cp - input_tensor
        if not state.warm_start:
            state.p_memory_dict.clear()
            state.q_memory_dict.clear()

        state.maybe_increase_iter(bucket)

        return input_tensor

    return (
        allreduce_contiguous_uncompressed_tensors_fut.then(
            unpack_uncompressed_tensors_and_allreduce_ps
        )
        .then(compute_qs)
        .then(decompress)
    )


def batched_powerSGD_hook(
    state: PowerSGDState, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    r"""
    Implement simplified PowerSGD algorithm.

    This DDP communication hook implements a simplified PowerSGD gradient compression
    algorithm described in the `paper <https://arxiv.org/abs/1905.13727>`_.
    This variant does not compress the gradients layer by layer,
    but instead compresses the flattened input tensor that batches all the gradients.
    Therefore, it is **faster** than :meth:`powerSGD_hook`,
    but usually results in a **much lower accuracy**, unless ``matrix_approximation_rank`` is 1.

    .. warning ::
        Increasing ``matrix_approximation_rank`` here may not necessarily increase the accuracy,
        because batching per-parameter tensors without column/row alignment can destroy low-rank structure.
        Therefore, the user should always consider :meth:`powerSGD_hook` first,
        and only consider this variant when a satisfactory accuracy can be achieved when ``matrix_approximation_rank`` is 1.

    Once gradient tensors are aggregated across all workers, this hook applies
    compression as follows:

    1. Views the input flattened 1D gradient tensor as a square-shaped tensor M with 0 paddings;

    2. Creates two low-rank tensors P and Q for decomposing M, such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;

    3. Computes P, which is equal to MQ;

    4. Allreduces P;

    5. Orthogonalizes P;

    6. Computes Q, which is approximately equal to M^TP;

    7. Allreduces Q;

    8. Computes M, which is approximately equal to PQ^T.

    9. Truncates the input tensor to the original length.

    Note that this communication hook enforces vanilla allreduce for the first ``state.start_powerSGD_iter`` iterations.
    This not only gives the user more control over the tradeoff between speedup and accuracy,
    but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers.

    Args:
        state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
            To tune the compression configs, mainly need to tune ``matrix_approximation_rank`` and ``start_powerSGD_iter``.
        bucket (dist.GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
            Note that since DDP comm hook only supports single process single device mode,
            only exactly one tensor is stored in this bucket.

    Returns:
        Future handler of the communication, which updates the gradients in place.

    Example::
        >>> # xdoctest: +SKIP
        >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
        >>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)
    """  # noqa: B950
    process_group = state.process_group
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    world_size = group_to_use.size()

    # The input tensor is a flattened 1D tensor.
    input_tensor = bucket.buffer()

    # Run vanilla allreduce in the first `start_powerSGD_iter` iterations.
    if state.iter < state.start_powerSGD_iter:
        state.maybe_increase_iter(bucket)
        return default._allreduce_fut(group_to_use, input_tensor)

    # Apply PowerSGD after `start_powerSGD_iter` iterations.
    device = input_tensor.device
    total_length = input_tensor.shape[0]
    state.total_numel_before_compression += total_length

    # View the input tensor as a 2D square-shape tensor, and pad 0s if necessary.
    square_side_length = math.ceil(math.sqrt(total_length))
    state.total_numel_after_compression += (
        square_side_length * state.matrix_approximation_rank * 2
    )
    padded_total_length = square_side_length**2
    input_tensor.resize_(padded_total_length)
    input_tensor[total_length:padded_total_length].fill_(0)

    _report_compression_stats(bucket, state)

    # Incorporate the error from the previous state into the gradients.
    bucket_index = bucket.index()
    input_tensor_cp = None
    if state.use_error_feedback:
        if bucket_index in state.error_dict:
            input_tensor.add_(state.error_dict[bucket_index])
        else:
            logger.info(
                "A zero tensor of length %s that represents local error is created.",
                padded_total_length,
            )
            state.error_dict[bucket_index] = torch.zeros(
                padded_total_length, device=device, dtype=input_tensor.dtype
            )

        # Keep a copy of the input tensor,
        # so that we can compute the local error caused by compression later,
        # by comparing this copy and the input tensor updated after decompression.
        input_tensor_cp = torch.clone(input_tensor).detach()
    matrix = input_tensor.view(square_side_length, square_side_length)

    # Reuse P and Q from the previous iteration if possible.
    # The memory spaces of P and Q need to be allocated in the first iteration when PowerSGD is applied.
    if not state.warm_start or bucket_index not in state.p_memory_dict:
        # If warm-start is disabled, low-rank tensors will be initialized at every step.
        # Only log this if warm-start to avoid spamming.
        if state.warm_start:
            logger.info(
                "Initializing low-rank tensors P and Q, each of which has a shape of %s x %s.",
                square_side_length,
                state.matrix_approximation_rank,
            )

        def create_low_rank_tensor(fill_random_values, rng):
            """Return a low-rank 2D tensor of square_side_length * matrix_approximation_rank."""
            if fill_random_values:
                with torch.random.fork_rng(devices=[]):
                    # Fork this RNG to avoid changing the seed globally and affecting the random sampling
                    # anywhere else in the training.
                    # The seed makes sure that the initial random values are the same across all the DDP replicas.
                    # This seed should differ at every step.
                    # Since it is very slow to fork RNG state across all the CUDA devices,
                    # only fork on CPU and then move the generated tensor to the CUDA device.
                    torch.manual_seed(rng.randint(1_000_000_000))
                    return torch.randn(
                        square_side_length,
                        state.matrix_approximation_rank,
                        device="cpu",
                        dtype=input_tensor.dtype,
                    ).to(device)
            else:
                return torch.empty(
                    square_side_length,
                    state.matrix_approximation_rank,
                    device=device,
                    dtype=input_tensor.dtype,
                )

        state.p_memory_dict[bucket_index] = create_low_rank_tensor(
            fill_random_values=False, rng=state.rng
        )
        state.q_memory_dict[bucket_index] = create_low_rank_tensor(
            fill_random_values=True, rng=state.rng
        )
    _orthogonalize(state.q_memory_dict[bucket_index])

    torch.matmul(
        matrix, state.q_memory_dict[bucket_index], out=state.p_memory_dict[bucket_index]
    )
    allreduce_p_fut = dist.all_reduce(
        state.p_memory_dict[bucket_index], group=group_to_use, async_op=True
    ).get_future()

    def compute_q(fut):
        state.p_memory_dict[bucket_index] = fut.value()[0]
        _orthogonalize(state.p_memory_dict[bucket_index])

        torch.matmul(
            matrix.t(),
            state.p_memory_dict[bucket_index],
            out=state.q_memory_dict[bucket_index],
        )

        # TODO: The above procedure does two matmul+allreduce steps per iteration --
        # one left multiplication and one right multiplication.
        # For warm-start, can take one such step at a time, and alternate between them.

        return (
            dist.all_reduce(
                state.q_memory_dict[bucket_index], group=group_to_use, async_op=True
            )
            .get_future()
            .wait()[0]
        )

    def decompress(fut):
        state.q_memory_dict[bucket_index] = fut.value().div_(world_size)
        torch.matmul(
            state.p_memory_dict[bucket_index],
            state.q_memory_dict[bucket_index].t(),
            out=matrix,
        )

        if state.use_error_feedback:
            # Memorize the local errors.
            state.error_dict[bucket_index] = input_tensor_cp - input_tensor
        # Removing this seemingly unnecessary sync somehow may cause failures.
844 # See: https://github.com/pytorch/pytorch/pull/54838 845 if torch.cuda.is_available(): 846 torch.cuda.synchronize(device) 847 if not state.warm_start: 848 state.p_memory_dict.clear() 849 state.q_memory_dict.clear() 850 ret = input_tensor.resize_(total_length) 851 852 state.maybe_increase_iter(bucket) 853 854 return ret 855 856 return allreduce_p_fut.then(compute_q).then(decompress) 857