xref: /aosp_15_r20/external/pytorch/torch/_functorch/functional_call.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-decorators
2# mypy: allow-untyped-defs
3from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
4
5import torch
6import torch.nn as nn
7from torch import Tensor
8from torch._functorch.utils import exposed_in
9
10
11@exposed_in("torch.func")
12def functional_call(
13    module: "torch.nn.Module",
14    parameter_and_buffer_dicts: Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]],
15    args: Union[Any, Tuple],
16    kwargs: Optional[Dict[str, Any]] = None,
17    *,
18    tie_weights: bool = True,
19    strict: bool = False,
20):
21    r"""Performs a functional call on the module by replacing the module parameters
22    and buffers with the provided ones.
23
24    .. note:: If the module has active parametrizations, passing a value in the
25        :attr:`parameter_and_buffer_dicts` argument with the name set to the regular parameter
26        name will completely disable the parametrization.
27        If you want to apply the parametrization function to the value passed
28        please set the key as ``{submodule_name}.parametrizations.{parameter_name}.original``.
29
30    .. note:: If the module performs in-place operations on parameters/buffers, these will be reflected
31        in the ``parameter_and_buffer_dicts`` input.
32
33
34         Example::
35
36            >>> a = {'foo': torch.zeros(())}
37            >>> # xdoctest: +SKIP
38            >>> mod = Foo()  # does self.foo = self.foo + 1
39            >>> print(mod.foo)  # tensor(0.)
40            >>> functional_call(mod, a, torch.ones(()))
41            >>> print(mod.foo)  # tensor(0.)
42            >>> print(a['foo'])  # tensor(1.)
43
44    .. note:: If the module has tied weights, whether or not functional_call respects the tying is determined by the
45        tie_weights flag.
46
47        Example::
48
49            >>> a = {'foo': torch.zeros(())}
50            >>> # xdoctest: +SKIP
51            >>> mod = Foo()  # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied
52            >>> print(mod.foo)  # tensor(1.)
53            >>> mod(torch.zeros(()))  # tensor(2.)
54            >>> functional_call(mod, a, torch.zeros(()))  # tensor(0.) since it will change self.foo_tied too
55            >>> functional_call(mod, a, torch.zeros(()), tie_weights=False)  # tensor(1.)--self.foo_tied is not updated
56            >>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())}
57            >>> functional_call(mod, new_a, torch.zeros()) # tensor(0.)
58
59    An example of passing multiple dictionaries
60
61    .. code-block:: python
62
63            a = ({'weight': torch.ones(1, 1)}, {'buffer': torch.zeros(1)})  # two separate dictionaries
64            mod = nn.Bar(1, 1)  # return self.weight @ x + self.buffer
65            print(mod.weight)  # tensor(...)
66            print(mod.buffer)  # tensor(...)
67            x = torch.randn((1, 1))
68            print(x)
69            functional_call(mod, a, x)  # same as x
70            print(mod.weight)  # same as before functional_call
71
72
73    And here is an example of applying the grad transform over the parameters
74    of a model.
75
76    .. code-block:: python
77
78        import torch
79        import torch.nn as nn
80        from torch.func import functional_call, grad
81
82        x = torch.randn(4, 3)
83        t = torch.randn(4, 3)
84        model = nn.Linear(3, 3)
85
86        def compute_loss(params, x, t):
87            y = functional_call(model, params, x)
88            return nn.functional.mse_loss(y, t)
89
90        grad_weights = grad(compute_loss)(dict(model.named_parameters()), x, t)
91
92    .. note:: If the user does not need grad tracking outside of grad transforms, they can detach all of the
93        parameters for better performance and memory usage
94
95        Example::
96
97            >>> detached_params = {k: v.detach() for k, v in model.named_parameters()}
98            >>> grad_weights = grad(compute_loss)(detached_params, x, t)
99            >>> grad_weights.grad_fn  # None--it's not tracking gradients outside of grad
100
101        This means that the user cannot call ``grad_weight.backward()``. However, if they don't need autograd tracking
102        outside of the transforms, this will result in less memory usage and faster speeds.
103
104    Args:
105        module (torch.nn.Module): the module to call
106        parameters_and_buffer_dicts (Dict[str, Tensor] or tuple of Dict[str, Tensor]): the parameters that will be used in
107            the module call. If given a tuple of dictionaries, they must have distinct keys so that all dictionaries can
108            be used together
109        args (Any or tuple): arguments to be passed to the module call. If not a tuple, considered a single argument.
110        kwargs (dict): keyword arguments to be passed to the module call
111        tie_weights (bool, optional): If True, then parameters and buffers tied in the original model will be treated as
112            tied in the reparameterized version. Therefore, if True and different values are passed for the tied
113            parameters and buffers, it will error. If False, it will not respect the originally tied parameters and
114            buffers unless the values passed for both weights are the same. Default: True.
115        strict (bool, optional): If True, then the parameters and buffers passed in must match the parameters and
116            buffers in the original module. Therefore, if True and there are any missing or unexpected keys, it will
117            error. Default: False.
118
119    Returns:
120        Any: the result of calling ``module``.
121    """
122    if isinstance(parameter_and_buffer_dicts, dict):
123        parameters_and_buffers = parameter_and_buffer_dicts
124    elif isinstance(parameter_and_buffer_dicts, Sequence):
125        if not all(isinstance(d, dict) for d in parameter_and_buffer_dicts):
126            raise ValueError(
127                "Expected all elements of parameter_and_buffer_dicts to be dictionaries"
128            )
129        all_keys = [k for d in parameter_and_buffer_dicts for k in d.keys()]
130        all_keys_counter: Dict[str, int] = {}
131        for k in all_keys:
132            v = all_keys_counter.get(k, 0)
133            all_keys_counter[k] = v + 1
134        repeated_keys = [key for key, n in all_keys_counter.items() if n > 1]
135        if len(repeated_keys) > 0:
136            raise ValueError(
137                f"{repeated_keys} appeared in multiple dictionaries; behavior of functional call is ambiguous"
138            )
139        parameters_and_buffers = {
140            k: v for d in parameter_and_buffer_dicts for k, v in d.items()
141        }
142    else:
143        raise ValueError(
144            f"Expected parameter_and_buffer_dicts to be a dict, or a list/tuple of dicts, "
145            f"but got {type(parameter_and_buffer_dicts)}"
146        )
147
148    return nn.utils.stateless._functional_call(
149        module,
150        parameters_and_buffers,
151        args,
152        kwargs,
153        tie_weights=tie_weights,
154        strict=strict,
155    )
156
157
158@exposed_in("torch.func")
159def stack_module_state(
160    models: List[nn.Module],
161) -> Tuple[Dict[str, Any], Dict[str, Any]]:
162    """stack_module_state(models) -> params, buffers
163
164    Prepares a list of torch.nn.Modules for ensembling with :func:`vmap`.
165
166    Given a list of ``M`` ``nn.Modules`` of the same class, returns two dictionaries
167    that stack all of their parameters and buffers together, indexed by name.
168    The stacked parameters are optimizable (i.e. they are new leaf nodes in the
169    autograd history that are unrelated to the original parameters and can be
170    passed directly to an optimizer).
171
172    Here's an example of how to ensemble over a very simple model:
173
174    .. code-block:: python
175
176        num_models = 5
177        batch_size = 64
178        in_features, out_features = 3, 3
179        models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
180        data = torch.randn(batch_size, 3)
181
182        def wrapper(params, buffers, data):
183            return torch.func.functional_call(models[0], (params, buffers), data)
184
185        params, buffers = stack_module_state(models)
186        output = vmap(wrapper, (0, 0, None))(params, buffers, data)
187
188        assert output.shape == (num_models, batch_size, out_features)
189
190    When there's submodules, this follows state dict naming conventions
191
192    .. code-block:: python
193
194        import torch.nn as nn
195        class Foo(nn.Module):
196            def __init__(self, in_features, out_features):
197                super().__init__()
198                hidden = 4
199                self.l1 = nn.Linear(in_features, hidden)
200                self.l2 = nn.Linear(hidden, out_features)
201
202            def forward(self, x):
203                return self.l2(self.l1(x))
204
205        num_models = 5
206        in_features, out_features = 3, 3
207        models = [Foo(in_features, out_features) for i in range(num_models)]
208        params, buffers = stack_module_state(models)
209        print(list(params.keys()))  # "l1.weight", "l1.bias", "l2.weight", "l2.bias"
210
211    .. warning::
212        All of the modules being stacked together must be the same (except for
213        the values of their parameters/buffers). For example, they should be in the
214        same mode (training vs eval).
215    """
216    if len(models) == 0:
217        raise RuntimeError("stack_module_state: Expected at least one model, got 0.")
218    if not (all(m.training for m in models) or all(not m.training for m in models)):
219        raise RuntimeError(
220            "stack_module_state: Expected all models to have the same training/eval mode."
221        )
222    model0_typ = type(models[0])
223    if not all(type(m) == model0_typ for m in models):
224        raise RuntimeError(
225            "stack_module_state: Expected all models to be of the same class."
226        )
227    all_params = [dict(model.named_parameters()) for model in models]
228    params = {
229        k: construct_stacked_leaf(tuple(params[k] for params in all_params), k)
230        for k in all_params[0]
231    }
232    all_buffers = [dict(model.named_buffers()) for model in models]
233    buffers = {
234        k: construct_stacked_leaf(tuple(buffers[k] for buffers in all_buffers), k)
235        for k in all_buffers[0]
236    }
237
238    return params, buffers
239
240
241def construct_stacked_leaf(
242    tensors: Union[Tuple[Tensor, ...], List[Tensor]], name: str
243) -> Tensor:
244    all_requires_grad = all(t.requires_grad for t in tensors)
245    none_requires_grad = all(not t.requires_grad for t in tensors)
246    if not all_requires_grad and not none_requires_grad:
247        raise RuntimeError(
248            f"Expected {name} from each model to have the same .requires_grad"
249        )
250    result = torch.stack(tensors)
251    if all_requires_grad:
252        result = result.detach().requires_grad_()
253    return result
254