# mypy: allow-untyped-defs
import functools
from contextlib import contextmanager
from typing import Callable, Dict

import torch
from torch._decomp import decomposition_table
from torch.utils._pytree import tree_map_only


HANDLED_FUNCTIONS: Dict[Callable, torch.autograd.Function] = {}

aten = torch._ops.ops.aten
# __torch_function__ runs before the pydispatcher, so we need to manually use the same
# decompositions, indexed by their torch equivalents
expanded_weights_rnn_decomps = {
    # func: (input_decomp, data_decomp)
    torch.rnn_relu: (
        decomposition_table[aten.rnn_relu.input],
        decomposition_table[aten.rnn_relu.data],
    ),
    torch.rnn_tanh: (
        decomposition_table[aten.rnn_tanh.input],
        decomposition_table[aten.rnn_tanh.data],
    ),
    torch.lstm: (
        decomposition_table[aten.lstm.input],
        decomposition_table[aten.lstm.data],
    ),
    torch.gru: (
        decomposition_table[aten.gru.input],
        decomposition_table[aten.gru.data],
    ),
}


# All of the RNN decompositions run linear with the batch dimension second, even if batch_first was set.
@contextmanager
def batch_second(args, kwargs):
    def set_batch_second(ew):
        ew.set_batch_first(False)

    def reset_batch_first(ew):
        ew.set_batch_first(True)

    tree_map_only(ExpandedWeight, set_batch_second, args)
    tree_map_only(ExpandedWeight, set_batch_second, kwargs)
    try:
        yield
    finally:
        tree_map_only(ExpandedWeight, reset_batch_first, args)
        tree_map_only(ExpandedWeight, reset_batch_first, kwargs)


# To support packed sequences, we need to allow for smaller batches; the ExpandedWeight's
# batch_size represents the largest batch.
@contextmanager
def allow_smaller_batches(args, kwargs):
    def allow(ew):
        ew.set_allow_smaller_batches(True)

    def reset(ew):
        ew.set_allow_smaller_batches(False)

    tree_map_only(ExpandedWeight, allow, args)
    tree_map_only(ExpandedWeight, allow, kwargs)
    try:
        yield
    finally:
        tree_map_only(ExpandedWeight, reset, args)
        tree_map_only(ExpandedWeight, reset, kwargs)


@contextmanager
def setup_rnn(use_input_variant, args, kwargs):
    with batch_second(args, kwargs) if use_input_variant else allow_smaller_batches(
        args, kwargs
    ):
        yield


def implements_per_sample_grads(torch_function):
    @functools.wraps(torch_function)
    def decorator(autograd_func):
        HANDLED_FUNCTIONS[torch_function] = autograd_func
        return autograd_func

    return decorator


# ExpandedWeight represents a weight (parameter) Tensor that has an expanded
# batch dimension. Operations on the ExpandedWeight Tensor act exactly like
# those without an expanded batch dimension, but a call to .backward() populates
# the original (unexpanded) tensor's grad_sample field with per-sample gradients.
#
# ExpandedWeight has a fallback that always fails, since we cannot know what the batch
# dimension of the input tensor is and therefore cannot know if this is a valid call.
#
# This is a __torch_function__ object, but it could also have been a Tensor extension
# with a dispatch key.
#
# Needs to be a tensor subclass to allow reparametrization.
class ExpandedWeight(torch.Tensor):
    def __init__(self, orig_weight, batch_size, loss_reduction):
        self.batch_size = batch_size
        self.batch_first = True
        self.allow_smaller_batches = False
        self.orig_weight = orig_weight
        self.loss_reduction = loss_reduction

    # Class-level registry of per-sample-grad autograd.Functions, populated by the
    # implements_per_sample_grads decorator above.
    handled_functions = HANDLED_FUNCTIONS

    def __new__(cls, orig_weight, batch_size, loss_reduction):
        if not isinstance(orig_weight, torch.Tensor):
            raise RuntimeError(
                f"Can only make Expanded Weights of Tensors, got {type(orig_weight).__name__}"
            )
        if not orig_weight.requires_grad:
            raise RuntimeError(
                "Can only build ExpandedWeights objects of tensors that require_grad"
            )
        ret = torch.Tensor._make_subclass(cls, orig_weight, True)
        return ret

    @classmethod
    def __torch_function__(cls, func, _, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func in expanded_weights_rnn_decomps:
            # In aten, choosing the input or data variant is done by parsing logic. This mimics some of that.
            decomp_opts = expanded_weights_rnn_decomps[func]
            use_input_variant = isinstance(
                args[2], list
            )  # data variant uses a list here
            decomp = decomp_opts[0] if use_input_variant else decomp_opts[1]

            if decomp is not None:
                with setup_rnn(use_input_variant, args, kwargs):
                    return decomp(*args, **kwargs)
        if func == torch._cudnn_rnn_flatten_weight:
            # since we aren't using the fused cuda kernels for RNNs, don't do this
            return
        if func in cls.handled_functions:
            # kwargs are flattened into positional args (with the keyword names passed
            # first) so the autograd.Function's forward can reconstruct them
            return cls.handled_functions[func].apply(
                tuple(kwargs.keys()), func, *(args + tuple(kwargs.values()))
            )
        # We cannot use a fallback here because we do not know the batch dimension for any regular tensor inputs,
        # i.e. torch.add(torch.Tensor, ExpandedWeight)
        raise RuntimeError(
            f"Expanded Weights encountered but cannot handle function {func.__name__}"
        )

    @property
    def dtype(self):
        return self.orig_weight.dtype

    @property
    def data(self):
        return self.orig_weight.data

    @property
    def shape(self):
        return self.orig_weight.shape

    @property
    def device(self):
        return self.orig_weight.device

    @property
    def is_cuda(self):
        return self.orig_weight.is_cuda

    def data_ptr(self):
        return self.orig_weight.data_ptr()

    def get_device(self):
        return self.orig_weight.get_device()

    def set_allow_smaller_batches(self, is_allow_smaller_batches):
        self.allow_smaller_batches = is_allow_smaller_batches

    def set_batch_first(self, is_batch_first=True):
        self.batch_first = is_batch_first
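

# Minimal illustrative sketch (not part of the module's API): it exercises the behavior
# documented in the comment above ExpandedWeight, using only names defined in this file.
# The shapes, batch_size, and loss_reduction below are arbitrary demo assumptions.
if __name__ == "__main__":
    weight = torch.randn(3, 4, requires_grad=True)
    ew = ExpandedWeight(weight, batch_size=8, loss_reduction="sum")

    # Tensor metadata is forwarded to the original (unexpanded) weight.
    assert ew.shape == weight.shape and ew.dtype == weight.dtype

    # With no per-sample-grad rule registered for torch.add, __torch_function__ hits the
    # always-failing fallback described above, since the batch dimension of the plain
    # tensor operand is unknown.
    try:
        torch.add(ew, torch.randn(3, 4))
    except RuntimeError as err:
        print(f"fallback rejected unhandled op: {err}")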