1# mypy: allow-untyped-defs 2# Copyright (c) Facebook, Inc. and its affiliates. 3# All rights reserved. 4# 5# This source code is licensed under the BSD-style license found in the 6# LICENSE file in the root directory of this source tree. 7 8import copy 9from typing import ( 10 Any, 11 Callable, 12 Dict, 13 Iterable, 14 List, 15 NoReturn, 16 Sequence, 17 Tuple, 18 Type, 19 Union, 20) 21 22import torch 23import torch.nn as nn 24from torch import Tensor 25from torch.nn.utils._named_member_accessor import NamedMemberAccessor 26 27 28# Utilities to make nn.Module "functional" 29# In particular the goal is to be able to provide a function that takes as input 30# the parameters and evaluate the nn.Module using fixed inputs. 31 32 33def raise_parameter_tying_error() -> NoReturn: 34 raise RuntimeError( 35 "make_functional(module): we don't yet support models that " 36 "do parameter tying (also sometimes known as weight sharing). " 37 "Please try to rewrite your model by replacing all instances of the " 38 "tied parameter with another and/or comment your support in " 39 "https://github.com/pytorch/functorch/issues/446" 40 ) 41 42 43def create_names_map( 44 named_params: Union[Dict[str, Tensor], Iterable[Tuple[str, Tensor]]], 45 tied_named_params: Union[Dict[str, Tensor], Iterable[Tuple[str, Tensor]]], 46) -> Dict[str, List[str]]: 47 """ 48 named_params is a dictionary of tensors: {'A': A, 'B': B} 49 tied_named_params is another dictionary of tensors {'A': A, 'B': B, 'B_tied': B} 50 with potentially tied (or 'duplicated') tensors 51 52 This function creates a mapping from the names in named_params to the 53 names in tied_named_params: {'A': ['A'], 'B': ['B', 'B_tied']}. 54 """ 55 named_params = dict(named_params) 56 tied_named_params = dict(tied_named_params) 57 58 tensors_dict_keys = set(named_params.keys()) 59 tied_tensors_dict_keys = set(tied_named_params.keys()) 60 assert tensors_dict_keys.issubset(tied_tensors_dict_keys) 61 62 tensor_to_mapping: Dict[Tensor, Tuple[str, List[str]]] = {} 63 for key, tensor in named_params.items(): 64 tensor_to_mapping[tensor] = (key, []) 65 for key, tensor in tied_named_params.items(): 66 assert tensor in tensor_to_mapping 67 tensor_to_mapping[tensor][1].append(key) 68 return dict(tensor_to_mapping.values()) 69 70 71def _extract_members( 72 mod: nn.Module, 73 named_members: Callable[..., Iterable[Tuple[str, Tensor]]], 74 subclass: Callable[[Tensor], Tensor], 75) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]: 76 all_named_members = tuple(named_members(remove_duplicate=False)) 77 unique_named_members = tuple(named_members(remove_duplicate=True)) 78 names_map = create_names_map(unique_named_members, all_named_members) 79 80 # Remove all the members in the model 81 memo = {} 82 accessor = NamedMemberAccessor(mod) 83 for name, p in all_named_members: 84 if p not in memo: 85 memo[p] = subclass(torch.empty_like(p, device="meta")) 86 replacement = memo[p] 87 accessor.set_tensor(name, replacement) 88 89 if len(unique_named_members) == 0: 90 names, params = (), () 91 else: 92 names, params = zip(*unique_named_members) # type: ignore[assignment] 93 return params, names, names_map 94 95 96def extract_weights( 97 mod: nn.Module, 98) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]: 99 """ 100 This function removes all the Parameters from the model and 101 return them as a tuple as well as their original attribute names. 102 The weights must be re-loaded with `load_weights` before the model 103 can be used again. 104 Note that this function modifies the model in place and after this 105 call, mod.parameters() will be empty. 106 """ 107 return _extract_members(mod, mod.named_parameters, nn.Parameter) 108 109 110def extract_buffers( 111 mod: nn.Module, 112) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]: 113 return _extract_members(mod, mod.named_buffers, lambda x: x) 114 115 116def load_weights( 117 mod: nn.Module, 118 names: Sequence[str], 119 params: Sequence[Tensor], 120 as_params: bool = False, 121) -> None: 122 """ 123 Reload a set of weights so that `mod` can be used again to perform a forward pass. 124 Note that the `params` are regular Tensors (that can have history) and so are left 125 as Tensors. This means that mod.parameters() will still be empty after this call. 126 """ 127 accessor = NamedMemberAccessor(mod) 128 if as_params: 129 params = [nn.Parameter(p) for p in params] 130 accessor.set_tensors(names, params) 131 132 133def _swap_state( 134 mod: nn.Module, names_map: Dict[str, List[str]], elems: Iterable[Tensor] 135) -> List[Tensor]: 136 result: List[Tensor] = [] 137 accessor = NamedMemberAccessor(mod) 138 for (_, attr_names), elem in zip(names_map.items(), elems): 139 for i, attr_name in enumerate(attr_names): 140 if i == 0: 141 result.append(accessor.swap_tensor(attr_name, elem)) 142 else: 143 accessor.set_tensor(attr_name, elem) 144 return result 145 146 147def load_buffers( 148 mod: nn.Module, 149 names: Sequence[str], 150 buffers: Sequence[Tensor], 151 as_params: bool = False, 152) -> None: 153 accessor = NamedMemberAccessor(mod) 154 accessor.set_tensors(names, buffers) 155 156 157def load_state( 158 model: nn.Module, 159 weights: Sequence[Tensor], 160 weight_names: Sequence[str], 161 buffers: Sequence[Tensor] = (), 162 buffer_names: Sequence[str] = (), 163) -> nn.Module: 164 """load_state(model, weights, weight_names, buffers=(), buffer_names=()) -> model 165 166 load_state takes `weights` and `buffers` and assigns them to the model. 167 This is the inverse operation of `make_functional_deprecated_v1`. 168 """ 169 assert len(weight_names) == len(weights) 170 load_weights(model, weight_names, weights) 171 if len(buffers) > 0: 172 assert len(buffer_names) == len(buffers) 173 load_buffers(model, buffer_names, buffers) 174 return model 175 176 177def make_functional_deprecated_v1(model: nn.Module): 178 """make_functional_deprecated_v1(model) -> weights, func, weight_names 179 180 Given an nn.Module, make_functional_deprecated_v1 extracts the state (weights) 181 and returns a functional version of the model, `func`. This makes 182 it so that it is possible use transforms over the parameters of 183 `model`. 184 185 `func` can be invoked as follows: 186 ``` 187 x = torch.randn(4, 3) 188 model = nn.Linear(3, 3) 189 weights, func, _ = make_functional_deprecated_v1(model) 190 func(weights, (x,)) 191 ``` 192 193 And here is an example of applying the grad transform: 194 ``` 195 x = torch.randn(4, 3) 196 model = nn.Linear(3, 3) 197 weights, _, func = make_functional_deprecated_v1(model) 198 grad_weights = grad(func)(weights, (x,)) 199 ``` 200 201 To put the state back into a model, use `load_state`. 202 """ 203 buffers = list(model.buffers()) 204 if len(buffers) > 0: 205 raise RuntimeError( 206 "make_functional_deprecated_v1(model): `model` has buffers. Please use " 207 "make_functional_with_buffers_deprecated_v1(model) instead." 208 ) 209 weights, descriptors, _ = extract_weights(model) 210 211 def fun(weights, data): 212 mutable_model = copy.deepcopy(model) 213 load_weights(mutable_model, descriptors, weights) 214 return mutable_model(*data) 215 216 return weights, fun, descriptors 217 218 219def make_functional_with_buffers_deprecated_v1(model: nn.Module): 220 """make_functional_with_buffers_deprecated_v1(model) -> weights, buffers, func, weight_names, buffer_names 221 222 Given an nn.Module, make_functional_with_buffers_deprecated_v1 extracts the state (weights and buffers) 223 and returns a functional version of the model, `func`. 224 225 `func` can be invoked as follows: 226 ``` 227 x = torch.randn(4, 3) 228 model = nn.Linear(3, 3) 229 weights, buffers, func, _, _ = make_functional_with_buffers_deprecated_v1(model) 230 func(weights, buffers, (x,)) 231 ``` 232 233 And here is an example of applying the grad transform: 234 ``` 235 x = torch.randn(4, 3) 236 model = nn.Linear(3, 3) 237 weights, buffers, func, _, _ = make_functional_with_buffers_deprecated_v1(model) 238 func(weights, buffers, (x,)) 239 grad_weights = grad(func)(weights, buffers, (x,)) 240 ``` 241 242 To put the state back into a model, use `load_state`. 243 """ 244 weights, weight_descriptors, _ = extract_weights(model) 245 buffers, buf_descriptors, _ = extract_buffers(model) 246 247 def fun(weights, buffers, data): 248 mutable_model = copy.deepcopy(model) 249 load_weights(mutable_model, weight_descriptors, weights) 250 load_buffers(mutable_model, buf_descriptors, buffers) 251 return mutable_model(*data) 252 253 return weights, buffers, fun, weight_descriptors, buf_descriptors 254 255 256class FunctionalModuleWithBuffers(nn.Module): 257 """ 258 This is the callable object returned by :func:`make_functional_with_buffers`. 259 """ 260 261 def __init__( 262 self, 263 stateless_model: nn.Module, 264 param_names: Tuple[str, ...], 265 buffer_names: Tuple[str, ...], 266 param_names_map: Dict[str, List[str]], 267 buffer_names_map: Dict[str, List[str]], 268 ) -> None: 269 super().__init__() 270 self.stateless_model = stateless_model 271 self.param_names = param_names 272 self.buffer_names = buffer_names 273 274 self.all_names_map = dict(param_names_map) 275 self.all_names_map.update(buffer_names_map) 276 277 @staticmethod 278 def _create_from( 279 model: nn.Module, disable_autograd_tracking: bool = False 280 ) -> Tuple["FunctionalModuleWithBuffers", Tuple[Tensor, ...], Tuple[Tensor, ...]]: 281 # TODO: We don't need to copy the model to create a stateless copy 282 model_copy = copy.deepcopy(model) 283 params, param_names, param_names_map = extract_weights(model_copy) 284 buffers, buffer_names, buffer_names_map = extract_buffers(model_copy) 285 if disable_autograd_tracking: 286 for param in params: 287 param.requires_grad_(False) 288 return ( 289 FunctionalModuleWithBuffers( 290 model_copy, param_names, buffer_names, param_names_map, buffer_names_map 291 ), 292 params, 293 buffers, 294 ) 295 296 def forward( 297 self, params: Iterable[Tensor], buffers: Iterable[Tensor], *args, **kwargs 298 ) -> Any: 299 # Temporarily load the state back onto self.stateless_model 300 old_state = _swap_state( 301 self.stateless_model, 302 self.all_names_map, 303 tuple(params) + tuple(buffers), 304 ) 305 try: 306 return self.stateless_model(*args, **kwargs) 307 finally: 308 # Remove the loaded state on self.stateless_model 309 _swap_state(self.stateless_model, self.all_names_map, old_state) 310 311 312class FunctionalModule(nn.Module): 313 """ 314 This is the callable object returned by :func:`make_functional`. 315 """ 316 317 def __init__( 318 self, 319 stateless_model: nn.Module, 320 param_names: Tuple[str, ...], 321 names_map: Dict[str, List[str]], 322 ) -> None: 323 super().__init__() 324 self.stateless_model = stateless_model 325 self.param_names = param_names 326 self.names_map = names_map 327 328 @staticmethod 329 def _create_from( 330 model: nn.Module, disable_autograd_tracking: bool = False 331 ) -> Tuple["FunctionalModule", Tuple[Tensor, ...]]: 332 # TODO: We don't need to copy the model to create a stateless copy 333 model_copy = copy.deepcopy(model) 334 params, param_names, names_map = extract_weights(model_copy) 335 if disable_autograd_tracking: 336 for param in params: 337 param.requires_grad_(False) 338 return FunctionalModule(model_copy, param_names, names_map), params 339 340 def forward(self, params: Iterable[Tensor], *args, **kwargs) -> Any: 341 # Temporarily load the state back onto self.stateless_model 342 old_state = _swap_state(self.stateless_model, self.names_map, params) 343 try: 344 return self.stateless_model(*args, **kwargs) 345 finally: 346 # Remove the loaded state on self.stateless_model 347 _swap_state(self.stateless_model, self.names_map, old_state) 348 349 350def make_functional( 351 model: nn.Module, disable_autograd_tracking: bool = False 352) -> Tuple[FunctionalModule, Tuple[Tensor, ...]]: 353 """make_functional(model, disable_autograd_tracking=False) -> func, params 354 355 Given a ``torch.nn.Module``, :func:`make_functional` extracts the state 356 (params) and returns a functional version of the model, ``func``. This 357 makes it so that it is possible use transforms over the parameters of 358 ``model``. 359 360 ``func`` can be invoked as follows: 361 362 .. code-block:: python 363 364 import torch 365 import torch.nn as nn 366 from functorch import make_functional 367 368 x = torch.randn(4, 3) 369 model = nn.Linear(3, 3) 370 func, params = make_functional(model) 371 func(params, x) 372 373 And here is an example of applying the grad transform over the parameters 374 of a model. 375 376 .. code-block:: python 377 378 import torch 379 import torch.nn as nn 380 from functorch import make_functional, grad 381 382 x = torch.randn(4, 3) 383 t = torch.randn(4, 3) 384 model = nn.Linear(3, 3) 385 func, params = make_functional(model) 386 387 def compute_loss(params, x, t): 388 y = func(params, x) 389 return nn.functional.mse_loss(y, t) 390 391 grad_weights = grad(compute_loss)(params, x, t) 392 393 If the model has any buffers, please use :func:`make_functional_with_buffers` instead. 394 395 Args: 396 model (torch.nn.Module): Input model. 397 disable_autograd_tracking (bool): Flag to disable gradients tracking for output parameters. 398 The returned params are unrelated to the set of params from the original model. If False (default), 399 the params will have ``requires_grad=True`` on them (aka they will be trackable with regular 400 PyTorch autograd), matching the requires_grad-ness of the params from the original model. 401 Otherwise, the returned params will have ``requires_grad=False``. Default, False. 402 If you plan on using regular PyTorch autograd (e.g., if you want to call ``.backward()`` or 403 ``torch.autograd.grad()``, then set ``disable_autograd_tracking=False``. 404 Otherwise, if you're only planning on using functorch's gradient transforms, 405 then please set ``disable_autograd_tracking=True`` to avoid unnecessarily tracking 406 history with PyTorch autograd. 407 408 """ 409 buffers = list(model.buffers()) 410 if len(buffers) > 0: 411 raise RuntimeError( 412 "make_functional(model): `model` has buffers. Please use " 413 "make_functional_with_buffers(model) instead." 414 ) 415 return FunctionalModule._create_from( 416 model, disable_autograd_tracking=disable_autograd_tracking 417 ) 418 419 420def make_functional_with_buffers( 421 model: nn.Module, disable_autograd_tracking: bool = False 422) -> Tuple[FunctionalModuleWithBuffers, Tuple[Tensor, ...], Tuple[Tensor, ...]]: 423 """make_functional_with_buffers(model, disable_autograd_tracking=False) -> func, params, buffers 424 425 Given a ``torch.nn.Module``, make_functional_with_buffers extracts the 426 state (params and buffers) and returns a functional version of the model 427 ``func`` that can be invoked like a function. 428 429 ``func`` can be invoked as follows: 430 431 .. code-block:: python 432 433 import torch 434 import torch.nn as nn 435 from functorch import make_functional_with_buffers 436 437 x = torch.randn(4, 3) 438 model = nn.Linear(3, 3) 439 func, params, buffers = make_functional_with_buffers(model) 440 func(params, buffers, x) 441 442 And here is an example of applying the grad transform over the parameters 443 of a model: 444 445 .. code-block:: python 446 447 import torch 448 import torch.nn as nn 449 from functorch import make_functional_with_buffers, grad 450 451 x = torch.randn(4, 3) 452 t = torch.randn(4, 3) 453 model = nn.Linear(3, 3) 454 func, params, buffers = make_functional_with_buffers(model) 455 456 def compute_loss(params, buffers, x, t): 457 y = func(params, buffers, x) 458 return nn.functional.mse_loss(y, t) 459 460 grad_weights = grad(compute_loss)(params, buffers, x, t) 461 462 Args: 463 model (torch.nn.Module): Input model. 464 disable_autograd_tracking (bool): Flag to disable gradients tracking for output parameters. 465 The returned params are unrelated to the set of params from the original model. If False (default), 466 the params will have ``requires_grad=True`` on them (aka they will be trackable with regular 467 PyTorch autograd), matching the requires_grad-ness of the params from the original model. 468 Otherwise, the returned params will have ``requires_grad=False``. Default, False. 469 If you plan on using regular PyTorch autograd (e.g., if you want to call ``.backward()`` or 470 ``torch.autograd.grad()``, then set ``disable_autograd_tracking=False``. 471 Otherwise, if you're only planning on using functorch's gradient transforms, 472 then please set ``disable_autograd_tracking=True`` to avoid unnecessarily tracking 473 history with PyTorch autograd. 474 475 """ 476 return FunctionalModuleWithBuffers._create_from( 477 model, disable_autograd_tracking=disable_autograd_tracking 478 ) 479 480 481def transpose_stack( 482 tuple_of_tuple_of_tensors: Tuple[Tuple[Tensor, ...], ...] 483) -> Tuple[Tensor, ...]: 484 tuple_of_tuple_of_tensors = tuple(zip(*tuple_of_tuple_of_tensors)) 485 results = tuple( 486 torch.stack(shards).detach() for shards in tuple_of_tuple_of_tensors 487 ) 488 return results 489 490 491def combine_state_for_ensemble( 492 models: Sequence[nn.Module], 493) -> Tuple[FunctionalModuleWithBuffers, Tuple[Tensor, ...], Tuple[Tensor, ...]]: 494 """combine_state_for_ensemble(models) -> func, params, buffers 495 496 Prepares a list of torch.nn.Modules for ensembling with :func:`vmap`. 497 498 Given a list of ``M`` ``nn.Modules`` of the same class, stacks all of their 499 parameters and buffers together to make ``params`` and ``buffers``. 500 Each parameter and buffer in the result will have an additional dimension 501 of size ``M``. 502 503 :func:`combine_state_for_ensemble` also returns ``func``, a functional 504 version of one of the models in :attr:`models`. One cannot directly run 505 ``func(params, buffers, *args, **kwargs)`` directly, you probably want to 506 use ``vmap(func, ...)(params, buffers, *args, **kwargs)`` 507 508 Here's an example of how to ensemble over a very simple model: 509 510 .. code-block:: python 511 512 num_models = 5 513 batch_size = 64 514 in_features, out_features = 3, 3 515 models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] 516 data = torch.randn(batch_size, 3) 517 518 fmodel, params, buffers = combine_state_for_ensemble(models) 519 output = vmap(fmodel, (0, 0, None))(params, buffers, data) 520 521 assert output.shape == (num_models, batch_size, out_features) 522 523 .. warning:: 524 All of the modules being stacked together must be the same (except for 525 the values of their parameters/buffers). For example, they should be in the 526 same mode (training vs eval). 527 528 This API is subject to change -- we're investigating better ways to 529 create ensembles and would love your feedback how to improve this. 530 """ 531 if len(models) == 0: 532 raise RuntimeError( 533 "combine_state_for_ensemble: Expected at least one model, got 0." 534 ) 535 if not (all(m.training for m in models) or all(not m.training for m in models)): 536 raise RuntimeError( 537 "combine_state_for_ensemble: Expected all models to " 538 "have the same training/eval mode." 539 ) 540 model0_typ = type(models[0]) 541 if not all(type(m) == model0_typ for m in models): 542 raise RuntimeError( 543 "combine_state_for_ensemble: Expected all models to be of the same class." 544 ) 545 funcs, params, buffers = zip( 546 *[make_functional_with_buffers(model) for model in models] 547 ) 548 params = transpose_stack(params) 549 buffers = transpose_stack(buffers) 550 return funcs[0], params, buffers 551 552 553def functional_init( 554 model_class: Type[nn.Module], 555 ensemble_shape: Union[Tuple[()], Tuple[int]] = (), 556 device: torch.types.Device = "cpu", 557): 558 def wrapped(*args, **kwargs): 559 if len(ensemble_shape) >= 2: 560 raise ValueError("NYI: ensemble_shape with more than 1 element") 561 if len(ensemble_shape) == 0: 562 model = model_class(*args, **kwargs).to(device) 563 return make_functional_deprecated_v1(model) 564 num_models = ensemble_shape[0] # type: ignore[misc] 565 if num_models <= 0: 566 raise ValueError(f"num_models {num_models} should be > 0") 567 # NB: Not very efficient, more of a POC 568 models = tuple( 569 model_class(*args, **kwargs).to(device) for _ in range(num_models) 570 ) 571 _, fn, names = make_functional_deprecated_v1(model_class(*args, **kwargs)) 572 weights = tuple(make_functional_deprecated_v1(model)[0] for model in models) 573 weights = tuple(zip(*weights)) 574 weights = tuple(torch.stack(shards).detach() for shards in weights) 575 return weights, fn, names 576 577 return wrapped 578 579 580def functional_init_with_buffers( 581 model_class: Type[nn.Module], 582 ensemble_shape: Union[Tuple[()], Tuple[int]] = (), 583 device: torch.types.Device = "cpu", 584): 585 def wrapped(*args, **kwargs): 586 if len(ensemble_shape) >= 2: 587 raise ValueError("NYI: ensemble_shape with more than 1 element") 588 if len(ensemble_shape) == 0: 589 model = model_class(*args, **kwargs).to(device) 590 return make_functional_deprecated_v1(model) 591 num_models = ensemble_shape[0] # type: ignore[misc] 592 if num_models <= 0: 593 raise ValueError(f"num_models {num_models} should be > 0") 594 # NB: Not very efficient, more of a POC 595 models = tuple( 596 model_class(*args, **kwargs).to(device) for _ in range(num_models) 597 ) 598 ( 599 _, 600 _, 601 fn, 602 weight_names, 603 buffer_names, 604 ) = make_functional_with_buffers_deprecated_v1(model_class(*args, **kwargs)) 605 weights, buffers = zip( 606 *tuple( 607 make_functional_with_buffers_deprecated_v1(model)[:2] 608 for model in models 609 ) 610 ) 611 weights = tuple(zip(*weights)) 612 weights = tuple(torch.stack(shards).detach() for shards in weights) 613 buffers = tuple(zip(*buffers)) 614 buffers = tuple(torch.stack(shards).detach() for shards in buffers) 615 return weights, buffers, fn, weight_names, buffer_names 616 617 return wrapped 618