1# mypy: allow-untyped-defs 2from typing import Optional 3 4import torch 5 6from .expanded_weights_impl import ExpandedWeight 7 8 9def is_batch_first(expanded_args_and_kwargs): 10 batch_first = None 11 for arg in expanded_args_and_kwargs: 12 if not isinstance(arg, ExpandedWeight): 13 continue 14 15 if not batch_first: 16 batch_first = arg.batch_first 17 elif arg.batch_first != batch_first: 18 raise RuntimeError( 19 "Got conflicting batch_first arguments in the same layer" 20 ) 21 return batch_first 22 23 24def standard_kwargs(kwarg_names, expanded_args): 25 r"""Separate args and kwargs from `__torch_function__`s that standardize kwargs. 26 27 Most `__torch_function__`s standardize the kwargs that they give, so this will separate 28 the args and kwargs they pass. Functions that don't are linear and convND. 29 """ 30 kwarg_values = expanded_args[len(expanded_args) - len(kwarg_names) :] 31 expanded_args_without_kwargs = expanded_args[ 32 : len(expanded_args) - len(kwarg_names) 33 ] 34 expanded_kwargs = dict(zip(kwarg_names, kwarg_values)) 35 return expanded_args_without_kwargs, expanded_kwargs 36 37 38def forward_helper(func, expanded_args, expanded_kwargs): 39 r"""Compute the forward pass for a function that has expanded weight(s) passed to it. 40 41 It will run the forward pass where all ExpandedWeights are their original 42 weight. It runs checks on the given arguments and detaches the outputs. 43 44 .. note:: First argument in :attr:`expanded_args` must be the input with the batch 45 dimension as the first element of the shape 46 47 .. note:: :attr:`func` must return a Tensor or tuple of Tensors 48 49 Args: 50 func: The function to be called 51 expanded_args: Arguments to be passed to :attr:`func`. Will include arguments 52 that need to be unpacked because they are ExpandedWeights 53 expanded_kwargs: Keyword arguments to be passed to :attr:`func`. 54 Similar to :attr:`expanded_args`. 55 """ 56 unexpanded_args, unexpanded_kwargs = _check_and_unexpand_args( 57 func, expanded_args, expanded_kwargs 58 ) 59 return func(*unexpanded_args, **unexpanded_kwargs) 60 61 62def _check_and_unexpand_args(func, expanded_args, expanded_kwargs): 63 # input must be the first argument passed 64 input = expanded_args[0] 65 if isinstance(input, ExpandedWeight): 66 raise RuntimeError( 67 "Expanded Weights do not support inputs that are also ExpandedWeights. " 68 f"Input must be a Tensor, got {type(input).__name__} in function {func.__name__}" 69 ) 70 if not isinstance(input, torch.Tensor): 71 raise RuntimeError( 72 "Expanded Weights requires a Tensor as the first input to get the batch dimension, " 73 f"got {type(input).__name__} in function {func.__name__}" 74 ) 75 if len(input.shape) == 0: 76 raise RuntimeError( 77 f"Expanded Weights requires a batch dimension but got an input of size 0 in function {func.__name__}" 78 ) 79 if input.shape[0] == 0: 80 raise RuntimeError( 81 "0 is not a valid batch size for Expanded Weights but got input tensor of " 82 f"{input} in function {func.__name__}" 83 ) 84 for arg in expanded_args + tuple(expanded_kwargs.values()): 85 if not isinstance(arg, ExpandedWeight): 86 continue 87 batch_size = input.shape[0] if arg.batch_first else input.shape[1] 88 if (arg.allow_smaller_batches and batch_size > arg.batch_size) or ( 89 not arg.allow_smaller_batches and arg.batch_size != batch_size 90 ): 91 raise RuntimeError( 92 "Expected ExpandedWeights to have batch size matching input but got " 93 f"input batch size of {batch_size} with ExpandedWeight of batch size {arg.batch_size}" 94 ) 95 96 loss_reduction: Optional[str] = None 97 for arg in expanded_args + tuple(expanded_kwargs.values()): 98 if isinstance(arg, ExpandedWeight): 99 if loss_reduction is None: 100 loss_reduction = arg.loss_reduction 101 elif loss_reduction != arg.loss_reduction: 102 raise RuntimeError( 103 "Expected ExpandedWeights to all have the same loss_reduction argument but got one" 104 f"with {loss_reduction} and one with {arg.loss_reduction}" 105 ) 106 107 unexpanded_args = tuple( 108 arg.orig_weight if isinstance(arg, ExpandedWeight) else arg 109 for arg in expanded_args 110 ) 111 unexpanded_kwargs = { 112 name: arg.orig_weight if isinstance(arg, ExpandedWeight) else arg 113 for (name, arg) in expanded_kwargs.items() 114 } 115 return unexpanded_args, unexpanded_kwargs 116 117 118def maybe_scale_by_batch_size(grad_sample, expanded_weight): 119 if expanded_weight.loss_reduction == "mean": 120 return grad_sample * expanded_weight.batch_size 121 else: 122 return grad_sample 123 124 125def set_grad_sample_if_exists(maybe_expanded_weight, per_sample_grad_fn): 126 unpacked = unpack_expanded_weight_or_tensor(maybe_expanded_weight) 127 if isinstance(maybe_expanded_weight, ExpandedWeight): 128 grad_sample_contribution = maybe_scale_by_batch_size( 129 per_sample_grad_fn(unpacked), maybe_expanded_weight 130 ) 131 132 if maybe_expanded_weight.batch_size > grad_sample_contribution.shape[0]: 133 # this only passes the other checks if the arg allows smaller batch sizes 134 intermediate = torch.zeros( 135 maybe_expanded_weight.batch_size, 136 *grad_sample_contribution.shape[1:], 137 dtype=grad_sample_contribution.dtype, 138 device=grad_sample_contribution.device, 139 ) 140 intermediate[: grad_sample_contribution.shape[0]] = grad_sample_contribution 141 grad_sample_contribution = intermediate 142 143 if hasattr(unpacked, "grad_sample") and unpacked.grad_sample is not None: 144 unpacked.grad_sample = unpacked.grad_sample + grad_sample_contribution 145 else: 146 unpacked.grad_sample = grad_sample_contribution 147 148 149def unpack_expanded_weight_or_tensor(maybe_expanded_weight, func=lambda x: x): 150 if isinstance(maybe_expanded_weight, ExpandedWeight): 151 orig_weight = maybe_expanded_weight.orig_weight 152 return func(orig_weight) 153 elif ( 154 isinstance(maybe_expanded_weight, torch.Tensor) 155 and not maybe_expanded_weight.requires_grad 156 ): 157 return func(maybe_expanded_weight) 158 elif isinstance(maybe_expanded_weight, torch.Tensor): 159 raise RuntimeError( 160 "ExpandedWeights currently does not support a mixture of ExpandedWeight parameters " 161 "and normal Parameters. Please file and issue with pytorch/pytorch" 162 ) 163 164 165def sum_over_all_but_batch_and_last_n( 166 tensor: torch.Tensor, 167 n_dims: int, 168) -> torch.Tensor: 169 r""" 170 Calculate the sum over all dimensions, except the first (batch dimension), and excluding the last n_dims. 171 172 This function will ignore the first dimension and it will 173 not aggregate over the last n_dims dimensions. 174 Args: 175 tensor: An input tensor of shape ``(B, ..., X[n_dims-1])``. 176 n_dims: Number of dimensions to keep. 177 Example: 178 >>> tensor = torch.ones(1, 2, 3, 4, 5) 179 >>> sum_over_all_but_batch_and_last_n(tensor, n_dims=2).shape 180 torch.Size([1, 4, 5]) 181 Returns: 182 A tensor of shape ``(B, ..., X[n_dims-1])`` 183 """ 184 if tensor.dim() == n_dims + 1: 185 return tensor 186 else: 187 dims = list(range(1, tensor.dim() - n_dims)) 188 return tensor.sum(dim=dims) 189