xref: /aosp_15_r20/external/pytorch/torch/nn/utils/_named_member_accessor.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# This source code is licensed under the BSD-style license found in the
2# LICENSE file in the root directory of this source tree.
3
4from typing import Dict, Iterable, List, Tuple
5
6import torch
7
8
# Sentinel meaning "this member does not exist" for swap_tensor()/swap_tensors_dict().
# It is deliberately annotated as torch.Tensor (with the real type hidden behind the
# ignore) so annotated call sites type-check; at runtime it is a plain object that is
# only ever compared by identity (`is _MISSING`), never used as a tensor.
_MISSING: torch.Tensor = object()  # type: ignore[assignment]
10
11
def set_tensor(module: "torch.nn.Module", name: str, tensor: torch.Tensor) -> None:
    """Assign ``tensor`` as the direct member ``name`` of ``module``.

    If ``name`` is already registered as a parameter or a buffer, the
    assignment goes through ``module._parameters`` / ``module._buffers`` so the
    registration kind is preserved; otherwise it falls back to ``setattr``.
    ``tensor`` may be ``None``.

    Raises:
        TypeError: if ``module`` is not an ``nn.Module`` or ``tensor`` is
            neither a ``Tensor`` nor ``None``.
        KeyError: if ``name`` is empty or contains a ``"."`` (only direct
            members are supported).
    """
    if not isinstance(module, torch.nn.Module):
        raise TypeError(f"{module} is not an instance of torch.nn.Module")
    if tensor is not None and not isinstance(tensor, torch.Tensor):
        raise TypeError(f"{tensor} is not an instance of torch.Tensor")
    if "." in name:
        raise KeyError('tensor name can\'t contain "."')
    if not name:
        raise KeyError('tensor name can\'t be empty string ""')

    params, buffers = module._parameters, module._buffers
    if name in params:
        params[name] = tensor  # type: ignore[assignment]
    elif name in buffers:
        buffers[name] = tensor
    else:
        setattr(module, name, tensor)
27
28
def swap_tensor(
    module: "torch.nn.Module",
    name: str,
    tensor: torch.Tensor,
    allow_missing: bool = False,
) -> torch.Tensor:
    """Replace the direct member ``name`` of ``module`` and return the old value.

    ``tensor`` may be the ``_MISSING`` sentinel, in which case the member is
    deleted rather than replaced. If the member does not exist and
    ``allow_missing`` is true, ``_MISSING`` is returned instead of raising.
    Parameter/buffer registration kind is preserved by going through
    ``module._parameters`` / ``module._buffers`` when applicable.

    Raises:
        TypeError: on a non-module, a non-tensor replacement value, or a
            non-tensor existing attribute.
        KeyError: if ``name`` is empty or contains a ``"."``.
        AttributeError: if the member is missing and ``allow_missing`` is False.
    """
    if not isinstance(module, torch.nn.Module):
        raise TypeError(f"{module} is not an instance of torch.nn.Module")
    new_ok = tensor is _MISSING or tensor is None or isinstance(tensor, torch.Tensor)
    if not new_ok:
        raise TypeError(f"{tensor} is not an instance of torch.Tensor")
    if "." in name:
        raise KeyError('tensor name can\'t contain "."')
    if not name:
        raise KeyError('tensor name can\'t be empty string ""')

    orig_tensor: torch.Tensor
    if name in module._parameters:
        # Registered parameter: mutate the parameter dict directly.
        orig_tensor = module._parameters[name]  # type: ignore[assignment]
        if tensor is _MISSING:
            del module._parameters[name]
        else:
            module._parameters[name] = tensor  # type: ignore[assignment]
    elif name in module._buffers:
        # Registered buffer: same treatment via the buffer dict.
        orig_tensor = module._buffers[name]  # type: ignore[assignment]
        if tensor is _MISSING:
            del module._buffers[name]
        else:
            module._buffers[name] = tensor
    else:
        # Plain attribute path.
        if hasattr(module, name):
            orig_tensor = getattr(module, name)
        elif allow_missing:
            orig_tensor = _MISSING
        else:
            raise AttributeError(f"{module._get_name()} has no attribute `{name}`")
        old_ok = (
            orig_tensor is _MISSING
            or orig_tensor is None
            or isinstance(orig_tensor, torch.Tensor)
        )
        if not old_ok:
            raise TypeError(
                f"attribute `{name}`: {orig_tensor} is not an instance of torch.Tensor"
            )
        if tensor is _MISSING:
            if hasattr(module, name):
                delattr(module, name)
        else:
            setattr(module, name, tensor)
    return orig_tensor
81
82
def swap_submodule(
    module: "torch.nn.Module",
    name: str,
    submodule: "torch.nn.Module",
) -> "torch.nn.Module":
    """Replace the direct child ``name`` of ``module`` and return the old child.

    Raises:
        TypeError: if either argument is not an ``nn.Module``, or the existing
            entry under ``name`` is not an ``nn.Module``.
        KeyError: if ``name`` is empty, contains ``"."``, or is not a
            registered child of ``module``.
    """
    # Both ends of the swap must be real modules; report whichever fails first.
    for candidate in (module, submodule):
        if not isinstance(candidate, torch.nn.Module):
            raise TypeError(f"{candidate} is not an instance of torch.nn.Module")
    if "." in name:
        raise KeyError('submodule name can\'t contain "."')
    if not name:
        raise KeyError('submodule name can\'t be empty string ""')
    if name not in module._modules:
        raise KeyError(f"submodule {name} does not exist")

    previous = module._modules[name]
    if not isinstance(previous, torch.nn.Module):
        raise TypeError(f"{name} attribute is not an instance of torch.nn.Module")
    module._modules[name] = submodule
    return previous
104
105
class NamedMemberAccessor:
    """
    A class that provides a way to access the submodules and parameters/buffers of a module.

    It provides caching mechanism to speed up submodule lookups.
    This is useful for functional programming to manipulate the module state.
    """

    def __init__(self, module: "torch.nn.Module") -> None:
        # Root module all dotted paths are resolved against.
        self.module = module
        # Cache mapping a dotted submodule path (e.g. "layer1.conv1") to the
        # resolved submodule. NOTE(review): the cache is never invalidated, so
        # it assumes the module tree structure is not mutated behind the
        # accessor's back (swapping a cached submodule via another route would
        # leave stale entries) — confirm with callers.
        self.memo: Dict[str, torch.nn.Module] = {}

    # Nested attribute access

    def get_submodule(self, name: str) -> "torch.nn.Module":
        """
        Return the submodule specified by the given path.

        For example, to get the submodule mod.layer1.conv1,
        use accessor.get_submodule("layer1.conv1")

        Compare to mod.get_submodule("layer1.conv1"), this method will cache the
        intermediate submodule access to speed up future lookups.

        An empty path returns the root module itself.

        Raises:
            AttributeError: if some component of the path does not exist.
            TypeError: if the resolved attribute is not an nn.Module.
        """
        if not name:
            return self.module

        if name in self.memo:
            return self.memo[name]
        else:
            # Resolve the parent recursively (each prefix gets memoized too),
            # then take the last path component off the parent.
            prefix, dot, attr = name.rpartition(".")
            if dot:
                module = self.get_submodule(prefix)
            else:
                module = self.module
            try:
                submodule = getattr(module, attr)
            except AttributeError as ex:
                raise AttributeError(
                    f"{module._get_name()} has no attribute `{attr}`"
                ) from ex
            if not isinstance(submodule, torch.nn.Module):
                raise TypeError(  # noqa: B904
                    f"submodule `{name}`: {submodule} is not an instance of torch.nn.Module"
                )
            self.memo[name] = submodule
            return submodule

    def swap_submodule(self, path: str, value: "torch.nn.Module") -> "torch.nn.Module":
        """
        Swap the submodule specified by the given ``path`` to ``value``.

        For example, to swap the attribute mod.layer1.conv1 use
        ``accessor.swap_submodule("layer1.conv1", conv2)``.

        Returns the submodule previously registered at ``path``.
        """
        # Resolve the parent via the cache, then delegate the single-level
        # swap to the module-level swap_submodule helper.
        prefix, _, attr = path.rpartition(".")
        return swap_submodule(self.get_submodule(prefix), attr, value)

    def get_tensor(self, name: str) -> torch.Tensor:
        """
        Get the tensor specified by the given path to value.

        For example, to get the attribute mod.layer1.conv1.weight,
        use accessor.get_tensor('layer1.conv1.weight')

        Compare to mod.get_parameter("layer1.conv1.weight"), this method will
        cache the intermediate submodule access to speed up future lookups.

        Raises:
            AttributeError: if the attribute does not exist.
            TypeError: if the attribute is neither a Tensor nor None.
        """
        prefix, _, attr = name.rpartition(".")
        submodule = self.get_submodule(prefix)
        try:
            tensor = getattr(submodule, attr)
        except AttributeError as ex:
            raise AttributeError(
                f"{submodule._get_name()} has no attribute `{name}`"
            ) from ex
        if not isinstance(tensor, torch.Tensor) and tensor is not None:
            raise TypeError(f"{tensor} is not an instance of torch.Tensor")
        return tensor  # type: ignore[return-value]

    def set_tensor(self, name: str, value: torch.Tensor) -> None:
        """
        Set the attribute specified by the given path to value.

        For example, to set the attribute mod.layer1.conv1.weight,
        use accessor.set_tensor("layer1.conv1.weight", value)
        """
        # Delegates to the module-level set_tensor, which preserves
        # parameter/buffer registration kind.
        prefix, _, attr = name.rpartition(".")
        set_tensor(self.get_submodule(prefix), attr, value)

    def del_tensor(self, name: str) -> None:
        """
        Delete the attribute specified by the given path.

        For example, to delete the attribute mod.layer1.conv1.weight,
        use accessor.del_tensor("layer1.conv1.weight")

        Raises:
            AttributeError: if the attribute does not exist.
        """
        prefix, _, attr = name.rpartition(".")
        submodule = self.get_submodule(prefix)
        try:
            delattr(submodule, attr)
        except AttributeError as ex:
            raise AttributeError(
                f"{submodule._get_name()} has no attribute `{name}`"
            ) from ex

    def swap_tensor(
        self, name: str, value: torch.Tensor, allow_missing: bool = False
    ) -> torch.Tensor:
        """
        Swap the attribute specified by the given path to value.

        For example, to swap the attribute mod.layer1.conv1.weight,
        use accessor.swap_tensor("layer1.conv1.weight", value)

        Returns the previous value, or the _MISSING sentinel when the
        attribute did not exist and ``allow_missing`` is True.
        """
        prefix, _, attr = name.rpartition(".")
        return swap_tensor(
            self.get_submodule(prefix), attr, value, allow_missing=allow_missing
        )

    # Batched operations

    def get_tensors(self, names: Iterable[str]) -> List[torch.Tensor]:
        """
        Get the tensors specified by the given paths.

        For example, to get the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.get_tensors(["layer1.conv1.weight",
        "layer1.conv1.bias"])
        """
        return [self.get_tensor(name) for name in names]

    def set_tensors(self, names: Iterable[str], values: Iterable[torch.Tensor]) -> None:
        """
        Set the attributes specified by the given paths to values.

        For example, to set the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.set_tensors(["layer1.conv1.weight",
        "layer1.conv1.bias"], [weight, bias])
        """
        # Materialize one-shot iterables so they can be len()-checked and
        # then iterated.
        if not isinstance(names, (list, tuple)):
            names = list(names)
        if not isinstance(values, (list, tuple)):
            values = list(values)
        assert len(names) == len(values), "names and values must have the same length"

        for name, value in zip(names, values):
            self.set_tensor(name, value)

    def set_tensors_dict(self, named_tensors: Dict[str, torch.Tensor]) -> None:
        """
        Set the attributes specified by the given paths to values.

        For example, to set the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.set_tensors_dict({
            "layer1.conv1.weight": weight,
            "layer1.conv1.bias": bias,
        })
        """
        for name, value in named_tensors.items():
            self.set_tensor(name, value)

    def del_tensors(self, names: Iterable[str]) -> None:
        """
        Delete the attributes specified by the given paths.

        For example, to delete the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.del_tensors(["layer1.conv1.weight",
        "layer1.conv1.bias"])
        """
        for name in names:
            self.del_tensor(name)

    def swap_tensors(
        self,
        names: Iterable[str],
        values: Iterable[torch.Tensor],
        allow_missing: bool = False,
    ) -> List[torch.Tensor]:
        """
        Swap the attributes specified by the given paths to values.

        For example, to swap the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.swap_tensors(["layer1.conv1.weight",
        "layer1.conv1.bias"], [weight, bias])

        Returns the previous values, in the same order as ``names``.
        """
        # Materialize one-shot iterables so they can be len()-checked and
        # then iterated.
        if not isinstance(names, (list, tuple)):
            names = list(names)
        if not isinstance(values, (list, tuple)):
            values = list(values)
        assert len(names) == len(values), "names and values must have the same length"

        return [
            self.swap_tensor(name, value, allow_missing=allow_missing)
            for name, value in zip(names, values)
        ]

    def swap_tensors_dict(
        self, named_tensors: Dict[str, torch.Tensor], allow_missing: bool = False
    ) -> Tuple[Dict[str, torch.Tensor], List[str]]:
        """
        Swap the attributes specified by the given paths to values.

        For example, to swap the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.swap_tensors_dict({
            "layer1.conv1.weight": weight,
            "layer1.conv1.bias": bias,
        })

        Returns a tuple of (dict of original values keyed by path, list of
        paths that were missing from the module). The swap is transactional:
        on any failure all already-swapped tensors are restored before the
        exception propagates.
        """
        orig_named_tensors = {}
        missing_keys = []
        try:
            # Swap with allow_missing=True so missing keys are collected
            # rather than aborting mid-way; missing-key policy is applied
            # after the loop.
            for name, tensor in named_tensors.items():
                orig_tensor = self.swap_tensor(name, tensor, allow_missing=True)
                if orig_tensor is _MISSING:
                    missing_keys.append(name)
                orig_named_tensors[name] = orig_tensor
        except Exception:
            # Swap back if any exception occurs
            # (restoring a _MISSING original deletes the attribute again).
            for name, orig_tensor in orig_named_tensors.items():
                self.swap_tensor(name, orig_tensor, allow_missing=True)
            raise
        if missing_keys and not allow_missing:
            # Swap back if any key is missing when allow_missing is False
            for name, orig_tensor in orig_named_tensors.items():
                self.swap_tensor(name, orig_tensor, allow_missing=True)
            raise RuntimeError(f"Missing key(s): {', '.join(map(repr, missing_keys))}.")
        return orig_named_tensors, missing_keys

    def check_keys(self, keys: Iterable[str]) -> Tuple[List[str], List[str]]:
        """Check that the given keys are valid.

        Returns (missing_keys, unexpected_keys): module members not covered by
        ``keys``, and entries of ``keys`` that match no module member. Both
        lists are sorted.
        """
        keys = set(keys)
        # remove_duplicate=False so shared/tied tensors are counted under
        # every name they are reachable by.
        valid_keys = {name for name, _ in self.named_tensors(remove_duplicate=False)}
        missing_keys = valid_keys - keys
        unexpected_keys = keys - valid_keys
        return sorted(missing_keys), sorted(unexpected_keys)

    # Shortcut methods

    def named_parameters(
        self,
        remove_duplicate: bool = True,
    ) -> Iterable[Tuple[str, torch.Tensor]]:
        """Iterate over all the parameters in the module."""
        yield from self.module.named_parameters(remove_duplicate=remove_duplicate)

    def named_buffers(
        self,
        remove_duplicate: bool = True,
    ) -> Iterable[Tuple[str, torch.Tensor]]:
        """Iterate over all the buffers in the module."""
        yield from self.module.named_buffers(remove_duplicate=remove_duplicate)

    def named_tensors(
        self,
        remove_duplicate: bool = True,
    ) -> Iterable[Tuple[str, torch.Tensor]]:
        """Iterate over all the tensors (parameters first, then buffers) in the module."""
        yield from self.module.named_parameters(remove_duplicate=remove_duplicate)
        yield from self.module.named_buffers(remove_duplicate=remove_duplicate)

    def named_modules(
        self,
        remove_duplicate: bool = True,
    ) -> Iterable[Tuple[str, "torch.nn.Module"]]:
        """Iterate over all the modules in the module."""
        yield from self.module.named_modules(remove_duplicate=remove_duplicate)
373