# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    NoReturn,
    Sequence,
    Tuple,
    Type,
    Union,
)

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.utils._named_member_accessor import NamedMemberAccessor


# Utilities to make an nn.Module "functional".
# In particular, the goal is to be able to provide a function that takes the
# parameters as input and evaluates the nn.Module using fixed inputs.
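#
# For example (an illustrative sketch using the APIs defined below):
#
#   func, params = make_functional(nn.Linear(3, 3))
#   out = func(params, torch.randn(4, 3))  # evaluate the model with explicit params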


def raise_parameter_tying_error() -> NoReturn:
    raise RuntimeError(
        "make_functional(module): we don't yet support models that "
        "do parameter tying (also sometimes known as weight sharing). "
        "Please try to rewrite your model by replacing all instances of the "
        "tied parameter with another and/or voice your support in "
        "https://github.com/pytorch/functorch/issues/446"
    )


def create_names_map(
    named_params: Union[Dict[str, Tensor], Iterable[Tuple[str, Tensor]]],
    tied_named_params: Union[Dict[str, Tensor], Iterable[Tuple[str, Tensor]]],
) -> Dict[str, List[str]]:
    """
    named_params is a dictionary of tensors: {'A': A, 'B': B}
    tied_named_params is another dictionary of tensors {'A': A, 'B': B, 'B_tied': B}
    with potentially tied (or 'duplicated') tensors.

    This function creates a mapping from the names in named_params to the
    names in tied_named_params: {'A': ['A'], 'B': ['B', 'B_tied']}.
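
    For example (an illustrative sketch; `w` is a hypothetical tensor)::

        w = torch.randn(3)
        create_names_map({'A': w}, {'A': w, 'A_tied': w})
        # -> {'A': ['A', 'A_tied']}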
54    """
55    named_params = dict(named_params)
56    tied_named_params = dict(tied_named_params)
57
58    tensors_dict_keys = set(named_params.keys())
59    tied_tensors_dict_keys = set(tied_named_params.keys())
60    assert tensors_dict_keys.issubset(tied_tensors_dict_keys)
61
62    tensor_to_mapping: Dict[Tensor, Tuple[str, List[str]]] = {}
63    for key, tensor in named_params.items():
64        tensor_to_mapping[tensor] = (key, [])
65    for key, tensor in tied_named_params.items():
66        assert tensor in tensor_to_mapping
67        tensor_to_mapping[tensor][1].append(key)
68    return dict(tensor_to_mapping.values())


def _extract_members(
    mod: nn.Module,
    named_members: Callable[..., Iterable[Tuple[str, Tensor]]],
    subclass: Callable[[Tensor], Tensor],
) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]:
    all_named_members = tuple(named_members(remove_duplicate=False))
    unique_named_members = tuple(named_members(remove_duplicate=True))
    names_map = create_names_map(unique_named_members, all_named_members)

    # Remove all the members in the model, replacing each with a meta-device
    # placeholder; tied members share a single placeholder so ties survive.
    memo = {}
    accessor = NamedMemberAccessor(mod)
    for name, p in all_named_members:
        if p not in memo:
            memo[p] = subclass(torch.empty_like(p, device="meta"))
        replacement = memo[p]
        accessor.set_tensor(name, replacement)

    if len(unique_named_members) == 0:
        names, params = (), ()
    else:
        names, params = zip(*unique_named_members)  # type: ignore[assignment]
    return params, names, names_map


def extract_weights(
    mod: nn.Module,
) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]:
    """
    This function removes all the Parameters from the model and
    returns them as a tuple, as well as their original attribute names.
    The weights must be re-loaded with `load_weights` before the model
    can be used again.
    Note that this function modifies the model in place; after this
    call, mod.parameters() will be empty.
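
    For example (an illustrative sketch)::

        mod = nn.Linear(3, 3)
        params, names, names_map = extract_weights(mod)
        # names == ('weight', 'bias'); mod now holds meta-device placeholders
        load_weights(mod, names, params)  # restore before using mod again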
106    """
107    return _extract_members(mod, mod.named_parameters, nn.Parameter)
108
109
110def extract_buffers(
111    mod: nn.Module,
112) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]:
113    return _extract_members(mod, mod.named_buffers, lambda x: x)


def load_weights(
    mod: nn.Module,
    names: Sequence[str],
    params: Sequence[Tensor],
    as_params: bool = False,
) -> None:
    """
    Reload a set of weights so that `mod` can be used again to perform a forward pass.
    Note that the `params` are regular Tensors (that can have history) and so are left
    as Tensors. This means that mod.parameters() will still be empty after this call.
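
    For example (an illustrative sketch)::

        params, names, _ = extract_weights(mod)
        load_weights(mod, names, params)  # mod can run a forward pass again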
126    """
127    accessor = NamedMemberAccessor(mod)
128    if as_params:
129        params = [nn.Parameter(p) for p in params]
130    accessor.set_tensors(names, params)
131
132
133def _swap_state(
134    mod: nn.Module, names_map: Dict[str, List[str]], elems: Iterable[Tensor]
135) -> List[Tensor]:
136    result: List[Tensor] = []
137    accessor = NamedMemberAccessor(mod)
138    for (_, attr_names), elem in zip(names_map.items(), elems):
139        for i, attr_name in enumerate(attr_names):
140            if i == 0:
141                result.append(accessor.swap_tensor(attr_name, elem))
142            else:
143                accessor.set_tensor(attr_name, elem)
144    return result


def load_buffers(
    mod: nn.Module,
    names: Sequence[str],
    buffers: Sequence[Tensor],
    as_params: bool = False,
) -> None:
    accessor = NamedMemberAccessor(mod)
    accessor.set_tensors(names, buffers)


def load_state(
    model: nn.Module,
    weights: Sequence[Tensor],
    weight_names: Sequence[str],
    buffers: Sequence[Tensor] = (),
    buffer_names: Sequence[str] = (),
) -> nn.Module:
    """load_state(model, weights, weight_names, buffers=(), buffer_names=()) -> model

    load_state takes `weights` and `buffers` and assigns them to the model.
    This is the inverse operation of `make_functional_deprecated_v1`.
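
    For example (an illustrative sketch)::

        weights, func, weight_names = make_functional_deprecated_v1(model)
        model = load_state(model, weights, weight_names)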
168    """
169    assert len(weight_names) == len(weights)
170    load_weights(model, weight_names, weights)
171    if len(buffers) > 0:
172        assert len(buffer_names) == len(buffers)
173        load_buffers(model, buffer_names, buffers)
174    return model


def make_functional_deprecated_v1(model: nn.Module):
    """make_functional_deprecated_v1(model) -> weights, func, weight_names

    Given an nn.Module, make_functional_deprecated_v1 extracts the state (weights)
    and returns a functional version of the model, `func`. This makes
    it possible to use transforms over the parameters of `model`.

    `func` can be invoked as follows:
    ```
    x = torch.randn(4, 3)
    model = nn.Linear(3, 3)
    weights, func, _ = make_functional_deprecated_v1(model)
    func(weights, (x,))
    ```

    And here is an example of applying the grad transform:
    ```
    x = torch.randn(4, 3)
    model = nn.Linear(3, 3)
    weights, func, _ = make_functional_deprecated_v1(model)
    grad_weights = grad(func)(weights, (x,))
    ```

    To put the state back into a model, use `load_state`.
    """
    buffers = list(model.buffers())
    if len(buffers) > 0:
        raise RuntimeError(
            "make_functional_deprecated_v1(model): `model` has buffers. Please use "
            "make_functional_with_buffers_deprecated_v1(model) instead."
        )
    weights, descriptors, _ = extract_weights(model)

    def fun(weights, data):
        mutable_model = copy.deepcopy(model)
        load_weights(mutable_model, descriptors, weights)
        return mutable_model(*data)

    return weights, fun, descriptors


def make_functional_with_buffers_deprecated_v1(model: nn.Module):
    """make_functional_with_buffers_deprecated_v1(model) -> weights, buffers, func, weight_names, buffer_names

    Given an nn.Module, make_functional_with_buffers_deprecated_v1 extracts the state
    (weights and buffers) and returns a functional version of the model, `func`.

    `func` can be invoked as follows:
    ```
    x = torch.randn(4, 3)
    model = nn.Linear(3, 3)
    weights, buffers, func, _, _ = make_functional_with_buffers_deprecated_v1(model)
    func(weights, buffers, (x,))
    ```

    And here is an example of applying the grad transform:
    ```
    x = torch.randn(4, 3)
    model = nn.Linear(3, 3)
    weights, buffers, func, _, _ = make_functional_with_buffers_deprecated_v1(model)
    grad_weights = grad(func)(weights, buffers, (x,))
    ```

    To put the state back into a model, use `load_state`.
    """
    weights, weight_descriptors, _ = extract_weights(model)
    buffers, buf_descriptors, _ = extract_buffers(model)

    def fun(weights, buffers, data):
        mutable_model = copy.deepcopy(model)
        load_weights(mutable_model, weight_descriptors, weights)
        load_buffers(mutable_model, buf_descriptors, buffers)
        return mutable_model(*data)

    return weights, buffers, fun, weight_descriptors, buf_descriptors


class FunctionalModuleWithBuffers(nn.Module):
    """
    This is the callable object returned by :func:`make_functional_with_buffers`.
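    Call it as ``fmodel(params, buffers, *args, **kwargs)``.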
259    """
260
261    def __init__(
262        self,
263        stateless_model: nn.Module,
264        param_names: Tuple[str, ...],
265        buffer_names: Tuple[str, ...],
266        param_names_map: Dict[str, List[str]],
267        buffer_names_map: Dict[str, List[str]],
268    ) -> None:
269        super().__init__()
270        self.stateless_model = stateless_model
271        self.param_names = param_names
272        self.buffer_names = buffer_names
273
274        self.all_names_map = dict(param_names_map)
275        self.all_names_map.update(buffer_names_map)
276
277    @staticmethod
278    def _create_from(
279        model: nn.Module, disable_autograd_tracking: bool = False
280    ) -> Tuple["FunctionalModuleWithBuffers", Tuple[Tensor, ...], Tuple[Tensor, ...]]:
281        # TODO: We don't need to copy the model to create a stateless copy
282        model_copy = copy.deepcopy(model)
283        params, param_names, param_names_map = extract_weights(model_copy)
284        buffers, buffer_names, buffer_names_map = extract_buffers(model_copy)
285        if disable_autograd_tracking:
286            for param in params:
287                param.requires_grad_(False)
288        return (
289            FunctionalModuleWithBuffers(
290                model_copy, param_names, buffer_names, param_names_map, buffer_names_map
291            ),
292            params,
293            buffers,
294        )
295
296    def forward(
297        self, params: Iterable[Tensor], buffers: Iterable[Tensor], *args, **kwargs
298    ) -> Any:
299        # Temporarily load the state back onto self.stateless_model
300        old_state = _swap_state(
301            self.stateless_model,
302            self.all_names_map,
303            tuple(params) + tuple(buffers),
304        )
305        try:
306            return self.stateless_model(*args, **kwargs)
307        finally:
308            # Remove the loaded state on self.stateless_model
309            _swap_state(self.stateless_model, self.all_names_map, old_state)


class FunctionalModule(nn.Module):
    """
    This is the callable object returned by :func:`make_functional`.
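    Call it as ``fmodel(params, *args, **kwargs)``.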
315    """
316
317    def __init__(
318        self,
319        stateless_model: nn.Module,
320        param_names: Tuple[str, ...],
321        names_map: Dict[str, List[str]],
322    ) -> None:
323        super().__init__()
324        self.stateless_model = stateless_model
325        self.param_names = param_names
326        self.names_map = names_map
327
328    @staticmethod
329    def _create_from(
330        model: nn.Module, disable_autograd_tracking: bool = False
331    ) -> Tuple["FunctionalModule", Tuple[Tensor, ...]]:
332        # TODO: We don't need to copy the model to create a stateless copy
333        model_copy = copy.deepcopy(model)
334        params, param_names, names_map = extract_weights(model_copy)
335        if disable_autograd_tracking:
336            for param in params:
337                param.requires_grad_(False)
338        return FunctionalModule(model_copy, param_names, names_map), params
339
340    def forward(self, params: Iterable[Tensor], *args, **kwargs) -> Any:
341        # Temporarily load the state back onto self.stateless_model
342        old_state = _swap_state(self.stateless_model, self.names_map, params)
343        try:
344            return self.stateless_model(*args, **kwargs)
345        finally:
346            # Remove the loaded state on self.stateless_model
347            _swap_state(self.stateless_model, self.names_map, old_state)


def make_functional(
    model: nn.Module, disable_autograd_tracking: bool = False
) -> Tuple[FunctionalModule, Tuple[Tensor, ...]]:
    """make_functional(model, disable_autograd_tracking=False) -> func, params

    Given a ``torch.nn.Module``, :func:`make_functional` extracts the state
    (params) and returns a functional version of the model, ``func``. This
    makes it possible to use transforms over the parameters of ``model``.

    ``func`` can be invoked as follows:

    .. code-block:: python

        import torch
        import torch.nn as nn
        from functorch import make_functional

        x = torch.randn(4, 3)
        model = nn.Linear(3, 3)
        func, params = make_functional(model)
        func(params, x)

    And here is an example of applying the grad transform over the parameters
    of a model.

    .. code-block:: python

        import torch
        import torch.nn as nn
        from functorch import make_functional, grad

        x = torch.randn(4, 3)
        t = torch.randn(4, 3)
        model = nn.Linear(3, 3)
        func, params = make_functional(model)

        def compute_loss(params, x, t):
            y = func(params, x)
            return nn.functional.mse_loss(y, t)

        grad_weights = grad(compute_loss)(params, x, t)

    If the model has any buffers, please use :func:`make_functional_with_buffers` instead.

    Args:
        model (torch.nn.Module): Input model.
        disable_autograd_tracking (bool): Flag to disable gradient tracking for the
            output parameters. The returned params are unrelated to the set of params
            from the original model. If False (default), the params will have
            ``requires_grad=True`` on them (i.e., they will be trackable with regular
            PyTorch autograd), matching the requires_grad-ness of the params from the
            original model. Otherwise, the returned params will have
            ``requires_grad=False``. Default: False.
            If you plan on using regular PyTorch autograd (e.g., if you want to call
            ``.backward()`` or ``torch.autograd.grad()``), then set
            ``disable_autograd_tracking=False``. Otherwise, if you're only planning on
            using functorch's gradient transforms, then please set
            ``disable_autograd_tracking=True`` to avoid unnecessarily tracking history
            with PyTorch autograd.

    """
    buffers = list(model.buffers())
    if len(buffers) > 0:
        raise RuntimeError(
            "make_functional(model): `model` has buffers. Please use "
            "make_functional_with_buffers(model) instead."
        )
    return FunctionalModule._create_from(
        model, disable_autograd_tracking=disable_autograd_tracking
    )


def make_functional_with_buffers(
    model: nn.Module, disable_autograd_tracking: bool = False
) -> Tuple[FunctionalModuleWithBuffers, Tuple[Tensor, ...], Tuple[Tensor, ...]]:
    """make_functional_with_buffers(model, disable_autograd_tracking=False) -> func, params, buffers

    Given a ``torch.nn.Module``, make_functional_with_buffers extracts the
    state (params and buffers) and returns a functional version of the model,
    ``func``, that can be invoked like a function.

    ``func`` can be invoked as follows:

    .. code-block:: python

        import torch
        import torch.nn as nn
        from functorch import make_functional_with_buffers

        x = torch.randn(4, 3)
        model = nn.Linear(3, 3)
        func, params, buffers = make_functional_with_buffers(model)
        func(params, buffers, x)

    And here is an example of applying the grad transform over the parameters
    of a model:

    .. code-block:: python

        import torch
        import torch.nn as nn
        from functorch import make_functional_with_buffers, grad

        x = torch.randn(4, 3)
        t = torch.randn(4, 3)
        model = nn.Linear(3, 3)
        func, params, buffers = make_functional_with_buffers(model)

        def compute_loss(params, buffers, x, t):
            y = func(params, buffers, x)
            return nn.functional.mse_loss(y, t)

        grad_weights = grad(compute_loss)(params, buffers, x, t)

    Args:
        model (torch.nn.Module): Input model.
        disable_autograd_tracking (bool): Flag to disable gradient tracking for the
            output parameters. The returned params are unrelated to the set of params
            from the original model. If False (default), the params will have
            ``requires_grad=True`` on them (i.e., they will be trackable with regular
            PyTorch autograd), matching the requires_grad-ness of the params from the
            original model. Otherwise, the returned params will have
            ``requires_grad=False``. Default: False.
            If you plan on using regular PyTorch autograd (e.g., if you want to call
            ``.backward()`` or ``torch.autograd.grad()``), then set
            ``disable_autograd_tracking=False``. Otherwise, if you're only planning on
            using functorch's gradient transforms, then please set
            ``disable_autograd_tracking=True`` to avoid unnecessarily tracking history
            with PyTorch autograd.

    """
    return FunctionalModuleWithBuffers._create_from(
        model, disable_autograd_tracking=disable_autograd_tracking
    )


def transpose_stack(
    tuple_of_tuple_of_tensors: Tuple[Tuple[Tensor, ...], ...]
) -> Tuple[Tensor, ...]:
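    """
    Transpose a tuple of M per-model tensor tuples into per-tensor groups and
    stack each group, so that every returned tensor gains a leading dimension
    of size M; the results are detached from autograd history. (A sketch of
    shapes: M tuples each holding a ``[3, 4]`` weight become one ``[M, 3, 4]``
    tensor.)
    """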
    tuple_of_tuple_of_tensors = tuple(zip(*tuple_of_tuple_of_tensors))
    results = tuple(
        torch.stack(shards).detach() for shards in tuple_of_tuple_of_tensors
    )
    return results


def combine_state_for_ensemble(
    models: Sequence[nn.Module],
) -> Tuple[FunctionalModuleWithBuffers, Tuple[Tensor, ...], Tuple[Tensor, ...]]:
    """combine_state_for_ensemble(models) -> func, params, buffers

    Prepares a list of torch.nn.Modules for ensembling with :func:`vmap`.

    Given a list of ``M`` ``nn.Modules`` of the same class, stacks all of their
    parameters and buffers together to make ``params`` and ``buffers``.
    Each parameter and buffer in the result will have an additional dimension
    of size ``M``.

    :func:`combine_state_for_ensemble` also returns ``func``, a functional
    version of one of the models in :attr:`models`. One cannot run
    ``func(params, buffers, *args, **kwargs)`` directly; you probably want to
    use ``vmap(func, ...)(params, buffers, *args, **kwargs)`` instead.

    Here's an example of how to ensemble over a very simple model:

    .. code-block:: python

        num_models = 5
        batch_size = 64
        in_features, out_features = 3, 3
        models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
        data = torch.randn(batch_size, 3)

        fmodel, params, buffers = combine_state_for_ensemble(models)
        output = vmap(fmodel, (0, 0, None))(params, buffers, data)

        assert output.shape == (num_models, batch_size, out_features)

    .. warning::
        All of the modules being stacked together must be the same (except for
        the values of their parameters/buffers). For example, they should be in the
        same mode (training vs eval).

        This API is subject to change -- we're investigating better ways to
        create ensembles and would love your feedback on how to improve this.
    """
    if len(models) == 0:
        raise RuntimeError(
            "combine_state_for_ensemble: Expected at least one model, got 0."
        )
    if not (all(m.training for m in models) or all(not m.training for m in models)):
        raise RuntimeError(
            "combine_state_for_ensemble: Expected all models to "
            "have the same training/eval mode."
        )
    model0_typ = type(models[0])
    if not all(type(m) == model0_typ for m in models):
        raise RuntimeError(
            "combine_state_for_ensemble: Expected all models to be of the same class."
        )
    funcs, params, buffers = zip(
        *[make_functional_with_buffers(model) for model in models]
    )
    params = transpose_stack(params)
    buffers = transpose_stack(buffers)
    return funcs[0], params, buffers


def functional_init(
    model_class: Type[nn.Module],
    ensemble_shape: Union[Tuple[()], Tuple[int]] = (),
    device: torch.types.Device = "cpu",
):
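    """
    Return a function that constructs ``model_class(*args, **kwargs)`` on
    ``device`` and makes it functional via `make_functional_deprecated_v1`.
    If ``ensemble_shape`` is ``(num_models,)``, ``num_models`` independently
    initialized models are created and their weights are stacked along a new
    leading dimension of size ``num_models``.
    """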
    def wrapped(*args, **kwargs):
        if len(ensemble_shape) >= 2:
            raise ValueError("NYI: ensemble_shape with more than 1 element")
        if len(ensemble_shape) == 0:
            model = model_class(*args, **kwargs).to(device)
            return make_functional_deprecated_v1(model)
        num_models = ensemble_shape[0]  # type: ignore[misc]
        if num_models <= 0:
            raise ValueError(f"num_models {num_models} should be > 0")
        # NB: Not very efficient, more of a POC
        models = tuple(
            model_class(*args, **kwargs).to(device) for _ in range(num_models)
        )
        _, fn, names = make_functional_deprecated_v1(model_class(*args, **kwargs))
        weights = tuple(make_functional_deprecated_v1(model)[0] for model in models)
        weights = tuple(zip(*weights))
        weights = tuple(torch.stack(shards).detach() for shards in weights)
        return weights, fn, names

    return wrapped


def functional_init_with_buffers(
    model_class: Type[nn.Module],
    ensemble_shape: Union[Tuple[()], Tuple[int]] = (),
    device: torch.types.Device = "cpu",
):
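    """
    Like `functional_init`, but extracts and (for ensembles) stacks the
    buffers of the constructed models as well, via
    `make_functional_with_buffers_deprecated_v1`.
    """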
    def wrapped(*args, **kwargs):
        if len(ensemble_shape) >= 2:
            raise ValueError("NYI: ensemble_shape with more than 1 element")
        if len(ensemble_shape) == 0:
            model = model_class(*args, **kwargs).to(device)
            return make_functional_with_buffers_deprecated_v1(model)
        num_models = ensemble_shape[0]  # type: ignore[misc]
        if num_models <= 0:
            raise ValueError(f"num_models {num_models} should be > 0")
        # NB: Not very efficient, more of a POC
        models = tuple(
            model_class(*args, **kwargs).to(device) for _ in range(num_models)
        )
        (
            _,
            _,
            fn,
            weight_names,
            buffer_names,
        ) = make_functional_with_buffers_deprecated_v1(model_class(*args, **kwargs))
        weights, buffers = zip(
            *tuple(
                make_functional_with_buffers_deprecated_v1(model)[:2]
                for model in models
            )
        )
        weights = tuple(zip(*weights))
        weights = tuple(torch.stack(shards).detach() for shards in weights)
        buffers = tuple(zip(*buffers))
        buffers = tuple(torch.stack(shards).detach() for shards in buffers)
        return weights, buffers, fn, weight_names, buffer_names

    return wrapped