1# mypy: allow-untyped-defs 2from typing import Any, Dict, Optional, Set, Tuple, Union 3from typing_extensions import deprecated 4 5import torch 6from torch import Tensor 7from torch.nn.utils._named_member_accessor import NamedMemberAccessor 8 9 10__all__ = ["functional_call"] 11 12 13def _untie_named_tensors_map( 14 module: "torch.nn.Module", 15 parameters_and_buffers: Dict[str, Tensor], 16) -> Dict[str, Tensor]: 17 """ 18 Unties all tied tensors in the module to parameters_and_buffers. 19 20 This function returns a new untied_parameters_and_buffers dictionary and leave the original 21 untied_parameters_and_buffers dictionary unchanged. It adds new (missing) keys for tied tensors 22 in the module to untied_parameters_and_buffers. The value of the new key is the user-given value 23 in the original parameters_and_buffers dictionary. 24 25 If there are more than one user-given values for the same tied tensor, it will raise an error. 26 27 For example, if the module has two tied weights self.foo and self.tied_foo and the user passes 28 {'foo': foo_value, ...}, this will return {'foo': foo_value, 'tied_foo': foo_value, ...}. If the 29 user passes {'foo': foo_value, 'tied_foo': tied_foo_value, ...}, it will raise an error. If the 30 user passes {'foo': foo_value, 'tied_foo': foo_value, ...}, it will not raise an error. 31 32 Args: 33 module (torch.nn.Module): the module to determine which tensors are tied. 34 parameters_and_buffers (Dict[str, Tensor]): a map of {name: tensor} for reparamaterizing the module. 35 36 Returns: 37 A new untied version of the parameters_and_buffers dictionary. 38 39 Raises: 40 ValueError: if there are more than one user-given values for the same tied tensor. 41 """ 42 # A map of {name: tensor} for all tensors (including tied ones) in the module. 43 all_named_tensors: Dict[str, Tensor] = {} 44 all_named_tensors.update(module.named_parameters(remove_duplicate=False)) 45 all_named_tensors.update(module.named_buffers(remove_duplicate=False)) 46 47 # A map of {tensor: set(all_tied_names)} for all tensor names in the module. 48 tensor_to_tied_names_map: Dict[Tensor, Set[str]] = {} 49 for name, tensor in all_named_tensors.items(): 50 if tensor not in tensor_to_tied_names_map: 51 tensor_to_tied_names_map[tensor] = set() 52 tensor_to_tied_names_map[tensor].add(name) 53 54 # A map of {tied_name: set(all_tied_names)} for all tensor names in the module. 55 # If a name is not tied, it will not be in this map. 56 tied_names_map: Dict[str, Set[str]] = {} 57 for tied_names in tensor_to_tied_names_map.values(): 58 if len(tied_names) > 1: 59 for tied_name in tied_names: 60 tied_names_map[tied_name] = tied_names 61 62 # Make sure the user didn't pass multiple values for the same tied tensor. 63 given_names = set(parameters_and_buffers.keys()) 64 # same as given_names.intersection(tied_names_map.keys()) but dynamo can't 65 # handle that 66 given_names_for_tied_tensors: set[str] = set() 67 for name in given_names: 68 if name in tied_names_map: 69 given_names_for_tied_tensors.add(name) 70 71 for given_name in given_names_for_tied_tensors: 72 tied_names = tied_names_map[given_name] 73 if ( 74 # Detect if there are multiple keys present for the same tied tensor. 75 len(tied_names.intersection(given_names_for_tied_tensors)) > 1 76 # Only raise an error if the user passed multiple values for the same tied tensor. 77 # If all given values are the same, don't raise. 78 and len({parameters_and_buffers[tied_name] for tied_name in tied_names}) 79 != 1 80 ): 81 raise ValueError( 82 f"functional_call got multiple values for keys {sorted(tied_names)}, " 83 f"which are tied. Consider using tie_weights=False" 84 ) 85 86 # Untie the given named tensor map 87 # Make a copy for not modifying the original dict 88 untied_parameters_and_buffers = parameters_and_buffers.copy() 89 for given_name in given_names_for_tied_tensors: 90 for tied_name in tied_names_map[given_name]: 91 untied_parameters_and_buffers[tied_name] = parameters_and_buffers[ 92 given_name 93 ] 94 return untied_parameters_and_buffers 95 96 97class _ReparametrizeModule: 98 def __init__( 99 self, 100 module: "torch.nn.Module", 101 parameters_and_buffers: Dict[str, Tensor], 102 tie_weights: bool = False, 103 strict: bool = False, 104 stack_weights: bool = False, 105 ): 106 self.parameters_and_buffers = parameters_and_buffers 107 self.stack_weights = stack_weights 108 109 if tie_weights: 110 self.untied_parameters_and_buffers = _untie_named_tensors_map( 111 module, parameters_and_buffers 112 ) 113 else: 114 self.untied_parameters_and_buffers = parameters_and_buffers 115 116 self.accessor = NamedMemberAccessor(module) 117 if strict: 118 missing_keys, unexpected_keys = self.accessor.check_keys( 119 self.untied_parameters_and_buffers 120 ) 121 error_msgs = [] 122 if len(unexpected_keys) > 0: 123 error_msgs.append( 124 f"Unexpected key(s): {', '.join(map(repr, unexpected_keys))}." 125 ) 126 if len(missing_keys) > 0: 127 error_msgs.append( 128 f"Missing key(s): {', '.join(map(repr, missing_keys))}." 129 ) 130 if len(error_msgs) > 0: 131 raise RuntimeError( 132 "Error(s) in reparametrizing for {}:\n\t{}".format( 133 module._get_name(), "\n\t".join(error_msgs) 134 ) 135 ) 136 137 def __enter__(self): 138 self.orig_parameters_and_buffers, _ = self.accessor.swap_tensors_dict( 139 self.untied_parameters_and_buffers, allow_missing=True 140 ) 141 142 def __exit__(self, exception_type, exception_value, traceback): 143 if self.stack_weights: 144 # When stacking is enabled, we will restore the weights in LIFO order. 145 self.orig_parameters_and_buffers = dict( 146 reversed(self.orig_parameters_and_buffers.items()) 147 ) 148 new_parameters_and_buffers, _ = self.accessor.swap_tensors_dict( 149 self.orig_parameters_and_buffers, allow_missing=True 150 ) 151 # Sometimes the module is not completely stateless and has some in-place modifications on 152 # the _parameters and _buffers dictionaries. 153 # Write the changed parameters and buffers back to the original dict. 154 self.parameters_and_buffers.update( 155 { 156 k: new_parameters_and_buffers[k] 157 for k in self.parameters_and_buffers 158 if k in new_parameters_and_buffers 159 } 160 ) 161 162 163def _reparametrize_module( 164 module: "torch.nn.Module", 165 parameters_and_buffers: Dict[str, Tensor], 166 *, 167 tie_weights: bool = False, 168 strict: bool = False, 169 stack_weights: bool = False, 170) -> _ReparametrizeModule: 171 return _ReparametrizeModule( 172 module, 173 parameters_and_buffers, 174 tie_weights=tie_weights, 175 strict=strict, 176 stack_weights=stack_weights, 177 ) 178 179 180@deprecated( 181 "`torch.nn.utils.stateless.functional_call` is deprecated as of PyTorch 2.0 " 182 "and will be removed in a future version of PyTorch. " 183 "Please use `torch.func.functional_call` instead which is a drop-in replacement.", 184 category=FutureWarning, 185) 186def functional_call( 187 module: "torch.nn.Module", 188 parameters_and_buffers: Dict[str, Tensor], 189 args: Union[Any, Tuple], 190 kwargs: Optional[Dict[str, Any]] = None, 191 *, 192 tie_weights: bool = True, 193 strict: bool = False, 194): 195 r"""Perform a functional call on the module by replacing the module parameters and buffers with the provided ones. 196 197 .. warning:: 198 199 This API is deprecated as of PyTorch 2.0 and will be removed in a future 200 version of PyTorch. Please use :func:`torch.func.functional_call` instead, 201 which is a drop-in replacement for this API. 202 203 .. note:: If the module has active parametrizations, passing a value in the 204 :attr:`parameters_and_buffers` argument with the name set to the regular parameter 205 name will completely disable the parametrization. 206 If you want to apply the parametrization function to the value passed 207 please set the key as ``{submodule_name}.parametrizations.{parameter_name}.original``. 208 209 .. note:: If the module performs in-place operations on parameters/buffers, these will be reflected 210 in the `parameters_and_buffers` input. 211 212 Example:: 213 214 >>> a = {'foo': torch.zeros(())} 215 >>> # xdoctest: +SKIP 216 >>> mod = Foo() # does self.foo = self.foo + 1 217 >>> print(mod.foo) # tensor(0.) 218 >>> functional_call(mod, a, torch.ones(())) 219 >>> print(mod.foo) # tensor(0.) 220 >>> print(a['foo']) # tensor(1.) 221 222 .. note:: If the module has tied weights, whether or not functional_call respects the tying is determined by the 223 tie_weights flag. 224 225 Example:: 226 227 >>> a = {'foo': torch.zeros(())} 228 >>> # xdoctest: +SKIP 229 >>> mod = Foo() # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied 230 >>> print(mod.foo) # tensor(1.) 231 >>> mod(torch.zeros(())) # tensor(2.) 232 >>> functional_call(mod, a, torch.zeros(())) # tensor(0.) since it will change self.foo_tied too 233 >>> functional_call(mod, a, torch.zeros(()), tie_weights=False) # tensor(1.)--self.foo_tied is not updated 234 >>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())} 235 >>> functional_call(mod, new_a, torch.zeros()) # tensor(0.) 236 237 Args: 238 module (torch.nn.Module): the module to call 239 parameters_and_buffers (dict of str and Tensor): the parameters that will be used in 240 the module call. 241 args (Any or tuple): arguments to be passed to the module call. If not a tuple, considered a single argument. 242 kwargs (dict): keyword arguments to be passed to the module call 243 tie_weights (bool, optional): If True, then parameters and buffers tied in the original model will be treated as 244 tied in the reparamaterized version. Therefore, if True and different values are passed for the tied 245 parameters and buffers, it will error. If False, it will not respect the originally tied parameters and 246 buffers unless the values passed for both weights are the same. Default: True. 247 strict (bool, optional): If True, then the parameters and buffers passed in must match the parameters and 248 buffers in the original module. Therefore, if True and there are any missing or unexpected keys, it will 249 error. Default: False. 250 251 Returns: 252 Any: the result of calling ``module``. 253 """ 254 return _functional_call( 255 module, 256 parameters_and_buffers, 257 args, 258 kwargs, 259 tie_weights=tie_weights, 260 strict=strict, 261 ) 262 263 264def _functional_call( 265 module: "torch.nn.Module", 266 parameters_and_buffers: Dict[str, Tensor], 267 args: Union[Any, Tuple], 268 kwargs: Optional[Dict[str, Any]] = None, 269 *, 270 tie_weights: bool = True, 271 strict: bool = False, 272): 273 # TODO allow kwargs such as unsafe and others for parametrization 274 if ( 275 torch.jit.is_tracing() 276 or torch.jit.is_scripting() 277 or isinstance( 278 module, 279 ( 280 torch.jit.RecursiveScriptModule, 281 torch.jit.ScriptModule, 282 torch.jit.ScriptFunction, 283 ), 284 ) 285 ): 286 raise RuntimeError("The stateless API can't be used with Jitted modules") 287 if isinstance(module, torch.nn.DataParallel): 288 raise RuntimeError( 289 "The stateless API can't be used with nn.DataParallel module" 290 ) 291 if kwargs is None: 292 kwargs = {} 293 if not isinstance(args, tuple): 294 args = (args,) 295 with _reparametrize_module( 296 module, parameters_and_buffers, tie_weights=tie_weights, strict=strict 297 ): 298 return module(*args, **kwargs) 299