# mypy: allow-untyped-defs
from typing import Any, Dict, Optional, Set, Tuple, Union
from typing_extensions import deprecated

import torch
from torch import Tensor
from torch.nn.utils._named_member_accessor import NamedMemberAccessor


__all__ = ["functional_call"]


def _untie_named_tensors_map(
    module: "torch.nn.Module",
    parameters_and_buffers: Dict[str, Tensor],
) -> Dict[str, Tensor]:
    """
    Unties all tied tensors in the module to parameters_and_buffers.

    This function returns a new untied_parameters_and_buffers dictionary and leaves the original
    parameters_and_buffers dictionary unchanged. It adds new (missing) keys for tied tensors in the
    module to untied_parameters_and_buffers. The value of each new key is the user-given value in
    the original parameters_and_buffers dictionary.

    If the user passes different values for names that refer to the same tied tensor, this function
    raises an error.

    For example, if the module has two tied weights self.foo and self.tied_foo and the user passes
    {'foo': foo_value, ...}, this will return {'foo': foo_value, 'tied_foo': foo_value, ...}. If the
    user passes {'foo': foo_value, 'tied_foo': tied_foo_value, ...}, it will raise an error. If the
    user passes {'foo': foo_value, 'tied_foo': foo_value, ...}, it will not raise an error.

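    Example (a minimal sketch, not a runnable doctest; Foo is a hypothetical module whose
    self.foo and self.tied_foo attributes refer to the same tensor)::

        >>> # xdoctest: +SKIP
        >>> mod = Foo()
        >>> _untie_named_tensors_map(mod, {'foo': torch.zeros(())})
        {'foo': tensor(0.), 'tied_foo': tensor(0.)}
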
    Args:
        module (torch.nn.Module): the module to determine which tensors are tied.
        parameters_and_buffers (Dict[str, Tensor]): a map of {name: tensor} for reparametrizing the module.

    Returns:
        A new untied version of the parameters_and_buffers dictionary.

    Raises:
        ValueError: if the user passes different values for the same tied tensor.
    """
    # A map of {name: tensor} for all tensors (including tied ones) in the module.
    all_named_tensors: Dict[str, Tensor] = {}
    all_named_tensors.update(module.named_parameters(remove_duplicate=False))
    all_named_tensors.update(module.named_buffers(remove_duplicate=False))

    # A map of {tensor: set(all_tied_names)} for all tensors in the module.
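    # For example (illustrative): {<shared weight tensor>: {'foo', 'tied_foo'}, <bias tensor>: {'bias'}}.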
    tensor_to_tied_names_map: Dict[Tensor, Set[str]] = {}
    for name, tensor in all_named_tensors.items():
        if tensor not in tensor_to_tied_names_map:
            tensor_to_tied_names_map[tensor] = set()
        tensor_to_tied_names_map[tensor].add(name)

    # A map of {tied_name: set(all_tied_names)} for all tensor names in the module.
    # If a name is not tied, it will not be in this map.
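    # For example (illustrative): {'foo': {'foo', 'tied_foo'}, 'tied_foo': {'foo', 'tied_foo'}};
    # an untied name such as 'bias' would not appear here.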
    tied_names_map: Dict[str, Set[str]] = {}
    for tied_names in tensor_to_tied_names_map.values():
        if len(tied_names) > 1:
            for tied_name in tied_names:
                tied_names_map[tied_name] = tied_names

    # Make sure the user didn't pass multiple values for the same tied tensor.
    given_names = set(parameters_and_buffers.keys())
    # same as given_names.intersection(tied_names_map.keys()) but dynamo can't
    # handle that
    given_names_for_tied_tensors: Set[str] = set()
    for name in given_names:
        if name in tied_names_map:
            given_names_for_tied_tensors.add(name)

    for given_name in given_names_for_tied_tensors:
        tied_names = tied_names_map[given_name]
        if (
            # Detect if there are multiple keys present for the same tied tensor.
            len(tied_names.intersection(given_names_for_tied_tensors)) > 1
            # Only raise an error if the user passed multiple values for the same tied tensor.
            # If all given values are the same, don't raise.
            and len({parameters_and_buffers[tied_name] for tied_name in tied_names})
            != 1
        ):
            raise ValueError(
                f"functional_call got multiple values for keys {sorted(tied_names)}, "
                f"which are tied. Consider using tie_weights=False"
            )

    # Untie the given named tensor map.
    # Make a copy so we don't modify the original dict.
    untied_parameters_and_buffers = parameters_and_buffers.copy()
    for given_name in given_names_for_tied_tensors:
        for tied_name in tied_names_map[given_name]:
            untied_parameters_and_buffers[tied_name] = parameters_and_buffers[
                given_name
            ]
    return untied_parameters_and_buffers


class _ReparametrizeModule:
    def __init__(
        self,
        module: "torch.nn.Module",
        parameters_and_buffers: Dict[str, Tensor],
        tie_weights: bool = False,
        strict: bool = False,
        stack_weights: bool = False,
    ):
        self.parameters_and_buffers = parameters_and_buffers
        self.stack_weights = stack_weights

        if tie_weights:
            self.untied_parameters_and_buffers = _untie_named_tensors_map(
                module, parameters_and_buffers
            )
        else:
            self.untied_parameters_and_buffers = parameters_and_buffers

        self.accessor = NamedMemberAccessor(module)
        if strict:
            missing_keys, unexpected_keys = self.accessor.check_keys(
                self.untied_parameters_and_buffers
            )
            error_msgs = []
            if len(unexpected_keys) > 0:
                error_msgs.append(
                    f"Unexpected key(s): {', '.join(map(repr, unexpected_keys))}."
                )
            if len(missing_keys) > 0:
                error_msgs.append(
                    f"Missing key(s): {', '.join(map(repr, missing_keys))}."
                )
            if len(error_msgs) > 0:
                raise RuntimeError(
                    "Error(s) in reparametrizing for {}:\n\t{}".format(
                        module._get_name(), "\n\t".join(error_msgs)
                    )
                )

    def __enter__(self):
        self.orig_parameters_and_buffers, _ = self.accessor.swap_tensors_dict(
            self.untied_parameters_and_buffers, allow_missing=True
        )

    def __exit__(self, exception_type, exception_value, traceback):
        if self.stack_weights:
            # When stacking is enabled, we will restore the weights in LIFO order.
            self.orig_parameters_and_buffers = dict(
                reversed(self.orig_parameters_and_buffers.items())
            )
        new_parameters_and_buffers, _ = self.accessor.swap_tensors_dict(
            self.orig_parameters_and_buffers, allow_missing=True
        )
        # Sometimes the module is not completely stateless and has some in-place modifications on
        # the _parameters and _buffers dictionaries.
        # Write the changed parameters and buffers back to the original dict.
        self.parameters_and_buffers.update(
            {
                k: new_parameters_and_buffers[k]
                for k in self.parameters_and_buffers
                if k in new_parameters_and_buffers
            }
        )


def _reparametrize_module(
    module: "torch.nn.Module",
    parameters_and_buffers: Dict[str, Tensor],
    *,
    tie_weights: bool = False,
    strict: bool = False,
    stack_weights: bool = False,
) -> _ReparametrizeModule:
    return _ReparametrizeModule(
        module,
        parameters_and_buffers,
        tie_weights=tie_weights,
        strict=strict,
        stack_weights=stack_weights,
    )
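

# A minimal usage sketch of the reparametrization context manager (illustrative only; the
# names below are invented for this example and are not part of the API):
#
#     mod = torch.nn.Linear(2, 2)
#     new_weights = {"weight": torch.zeros(2, 2), "bias": torch.zeros(2)}
#     with _reparametrize_module(mod, new_weights):
#         out = mod(torch.ones(1, 2))  # forward pass runs with the swapped-in tensors
#     # on exit, the original parameters and buffers are swapped back in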


@deprecated(
    "`torch.nn.utils.stateless.functional_call` is deprecated as of PyTorch 2.0 "
    "and will be removed in a future version of PyTorch. "
    "Please use `torch.func.functional_call` instead which is a drop-in replacement.",
    category=FutureWarning,
)
def functional_call(
    module: "torch.nn.Module",
    parameters_and_buffers: Dict[str, Tensor],
    args: Union[Any, Tuple],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    tie_weights: bool = True,
    strict: bool = False,
):
    r"""Perform a functional call on the module by replacing the module parameters and buffers with the provided ones.

    .. warning::

        This API is deprecated as of PyTorch 2.0 and will be removed in a future
        version of PyTorch. Please use :func:`torch.func.functional_call` instead,
        which is a drop-in replacement for this API.

    .. note:: If the module has active parametrizations, passing a value in the
        :attr:`parameters_and_buffers` argument with the name set to the regular parameter
        name will completely disable the parametrization.
        If you want to apply the parametrization function to the value passed
        please set the key as ``{submodule_name}.parametrizations.{parameter_name}.original``.

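        Example (a sketch, not a runnable doctest; ``mod`` and its parametrized ``weight`` on a
        ``linear`` submodule are hypothetical)::

            >>> # xdoctest: +SKIP
            >>> values = {'linear.parametrizations.weight.original': torch.randn(3, 3)}
            >>> functional_call(mod, values, torch.randn(1, 3))
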
    .. note:: If the module performs in-place operations on parameters/buffers, these will be reflected
        in the `parameters_and_buffers` input.

        Example::

            >>> a = {'foo': torch.zeros(())}
            >>> # xdoctest: +SKIP
            >>> mod = Foo()  # does self.foo = self.foo + 1
            >>> print(mod.foo)  # tensor(0.)
            >>> functional_call(mod, a, torch.ones(()))
            >>> print(mod.foo)  # tensor(0.)
            >>> print(a['foo'])  # tensor(1.)

    .. note:: If the module has tied weights, whether or not functional_call respects the tying is determined by the
        tie_weights flag.

        Example::

            >>> a = {'foo': torch.zeros(())}
            >>> # xdoctest: +SKIP
            >>> mod = Foo()  # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied
            >>> print(mod.foo)  # tensor(1.)
            >>> mod(torch.zeros(()))  # tensor(2.)
            >>> functional_call(mod, a, torch.zeros(()))  # tensor(0.) since it will change self.foo_tied too
            >>> functional_call(mod, a, torch.zeros(()), tie_weights=False)  # tensor(1.)--self.foo_tied is not updated
            >>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())}
            >>> functional_call(mod, new_a, torch.zeros(()))  # tensor(0.)

    Args:
        module (torch.nn.Module): the module to call
        parameters_and_buffers (dict of str and Tensor): the parameters that will be used in
            the module call.
        args (Any or tuple): arguments to be passed to the module call. If not a tuple, considered a single argument.
        kwargs (dict): keyword arguments to be passed to the module call
        tie_weights (bool, optional): If True, then parameters and buffers tied in the original model will be treated as
            tied in the reparametrized version. Therefore, if True and different values are passed for the tied
            parameters and buffers, it will error. If False, it will not respect the originally tied parameters and
            buffers unless the values passed for both weights are the same. Default: True.
        strict (bool, optional): If True, then the parameters and buffers passed in must match the parameters and
            buffers in the original module. Therefore, if True and there are any missing or unexpected keys, it will
            error. Default: False.

    Returns:
        Any: the result of calling ``module``.
    """
    return _functional_call(
        module,
        parameters_and_buffers,
        args,
        kwargs,
        tie_weights=tie_weights,
        strict=strict,
    )


def _functional_call(
    module: "torch.nn.Module",
    parameters_and_buffers: Dict[str, Tensor],
    args: Union[Any, Tuple],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    tie_weights: bool = True,
    strict: bool = False,
):
    # TODO allow kwargs such as unsafe and others for parametrization
    if (
        torch.jit.is_tracing()
        or torch.jit.is_scripting()
        or isinstance(
            module,
            (
                torch.jit.RecursiveScriptModule,
                torch.jit.ScriptModule,
                torch.jit.ScriptFunction,
            ),
        )
    ):
        raise RuntimeError("The stateless API can't be used with Jitted modules")
    if isinstance(module, torch.nn.DataParallel):
        raise RuntimeError(
            "The stateless API can't be used with nn.DataParallel module"
        )
    if kwargs is None:
        kwargs = {}
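    # A single non-tuple argument is treated as a one-element tuple of positional arguments.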
    if not isinstance(args, tuple):
        args = (args,)
    with _reparametrize_module(
        module, parameters_and_buffers, tie_weights=tie_weights, strict=strict
    ):
        return module(*args, **kwargs)
299