xref: /aosp_15_r20/external/pytorch/torch/distributed/_composable/contract.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import uuid
3from collections import OrderedDict
4from functools import wraps
5from typing import Callable, Dict, List, Optional, Sequence, Type, Union
6
7import torch
8import torch.nn as nn
9from torch.distributed._composable_state import _State
10from torch.distributed.utils import _get_root_modules
11
12
def generate_state_key(string="__composable_api_state_key"):
    """Return ``string`` followed by an underscore and a freshly generated UUID4.

    Each call produces a new random suffix, so two calls with the same prefix
    yield different keys.
    """
    unique_suffix = uuid.uuid4()
    return f"{string}_{unique_suffix}"
15
16
# Keys under which composable-API bookkeeping is stored in ``module.__dict__``.
# Both get a fresh random UUID suffix each time this module is executed.
# STATE_KEY indexes a dict mapping each decorated function to its state object.
STATE_KEY = generate_state_key()
# REGISTRY_KEY indexes a dict mapping each applied API's name to a RegistryItem.
REGISTRY_KEY = generate_state_key()
19
20
21# TODO: we can add additional info to RegistryItem to share across APIs. E.g.,
22# we can add args and kwargs here, and then we can detect whether fully_shard
23# is combined with reentrant activation checkpointing and error out with a clear
24# message.
class RegistryItem:
    """Marker object recorded in a module's registry for each applied composable API.

    It currently carries no data of its own.
    """
27
28
def contract(state_cls: Type[_State] = _State):
    r"""
    Decorate a function as a composable distributed API, where the first
    argument of the function must be an :class:`nn.Module` instance or sequence
    of :class:`nn.Module` instances.

    The decorator verifies that the decorated function does not modify
    fully-qualified names (FQNs) for parameters, buffers, or modules. The
    decorated function can return different module instances than the input
    modules; the FQN invariant will be enforced following the input order.

    When a function ``func`` is decorated by ``@contract()``, a
    ``.state(module: nn.Module)`` method will be installed to the decorated
    function. Then you can retrieve and modify the state on a module by calling
    ``func.state(module)``.

    Args:
        state_cls: class instantiated with no arguments on every call of the
            decorated function. The resulting instance is shared by all
            (root) modules targeted by that call.

    Returns:
        A decorator. The function it produces returns whatever the decorated
        function returns, or the original input if that function returns
        ``None``.

    Raises:
        AssertionError: (when the decorated function is called) if the
            per-module state or registry container is not a ``dict``, if the
            same API has already been applied to one of the modules, or if the
            number of returned modules differs from the number of input
            modules.
        RuntimeError: (when the decorated function is called) if the set or
            the order of parameter, buffer, or module FQNs changed.

    Example::
        >>> # xdoctest: +SKIP
        >>> import torch.nn as nn
        >>>
        >>> class MyModel(nn.Module):
        >>>     def __init__(self) -> None:
        >>>         super().__init__()
        >>>         self.l1 = nn.Linear(10, 10)
        >>>         self.l2 = nn.Linear(10, 10)
        >>>
        >>>     def forward(self, x):
        >>>         return self.l2(self.l1(x))
        >>>
        >>> @contract()
        >>> def my_feature(module: nn.Module) -> nn.Module:
        >>>     my_feature.state(module).some_state = "any value"
        >>>     return module
        >>>
        >>> model = MyModel()
        >>> my_feature(model.l1)
        >>> assert my_feature.state(model.l1).some_state == "any value"
        >>> my_feature(model.l2)
        >>> model(torch.randn(2, 10)).sum().backward()
    """

    # wraps will make functions decorated with contract() pickleable - needed for integration with torch.package
    @wraps(state_cls)
    def inner(func):
        @wraps(func)
        def wrapper(
            module: Union[nn.Module, Sequence[nn.Module]], *args, **kwargs
        ) -> Optional[nn.Module]:
            inp_module = module
            if isinstance(module, nn.Module):
                modules = [module]
            else:
                # If the user passes a sequence of modules, then we assume that
                # we only need to insert the state object on the root modules
                # (i.e. those without a parent) among the passed-in modules.
                modules = _get_root_modules(list(module))
            state = state_cls()  # shared across all modules
            registry_item = RegistryItem()  # shared across all modules

            # `func` is allowed to return different module instances than the
            # input modules as long as FQNs are preserved following the input
            # module order
            all_orig_named_params: List[Dict[str, nn.Parameter]] = []
            all_orig_named_buffers: List[Dict[str, torch.Tensor]] = []
            all_orig_named_modules: List[Dict[str, nn.Module]] = []

            for orig_module in modules:
                default_all_state: Dict[Callable, _State] = OrderedDict()
                default_registry: Dict[str, RegistryItem] = OrderedDict()
                # Bookkeeping lives directly in the instance `__dict__` under
                # the module-level random keys.
                all_state: Dict[Callable, _State] = orig_module.__dict__.setdefault(  # type: ignore[call-overload]
                    STATE_KEY, default_all_state
                )
                if not isinstance(all_state, dict):
                    raise AssertionError(
                        f"Distributed composable API states corrupted: {all_state}"
                    )
                registry: Dict[str, RegistryItem] = orig_module.__dict__.setdefault(  # type: ignore[call-overload]
                    REGISTRY_KEY, default_registry
                )
                if not isinstance(registry, dict):
                    raise AssertionError(
                        f"Distributed composable API registry corrupted: {registry}"
                    )
                if func in all_state or func.__name__ in registry:
                    raise AssertionError(
                        "Each distinct composable distributed API can only be applied to a "
                        f"module once. {func.__name__} has already been applied to the "
                        f"following module:\n{orig_module}"
                    )
                all_state.setdefault(func, state)
                registry.setdefault(func.__name__, registry_item)

                # Snapshot the FQNs before `func` runs so they can be compared
                # against the FQNs afterwards.
                all_orig_named_params.append(
                    OrderedDict(orig_module.named_parameters())
                )
                all_orig_named_buffers.append(OrderedDict(orig_module.named_buffers()))
                all_orig_named_modules.append(OrderedDict(orig_module.named_modules()))

            updated = func(inp_module, *args, **kwargs)
            if updated is None:
                updated = inp_module
            if isinstance(updated, nn.Module):
                updated_modules = [updated]
            else:
                # Validate what `func` actually returned, not the input:
                # `func` may hand back different module instances, and the
                # input may be a single (non-iterable) module.
                updated_modules = _get_root_modules(list(updated))

            all_new_named_params: List[Dict[str, nn.Parameter]] = []
            all_new_named_buffers: List[Dict[str, torch.Tensor]] = []
            all_new_named_modules: List[Dict[str, nn.Module]] = []
            for updated_module in updated_modules:
                all_new_named_params.append(
                    OrderedDict(updated_module.named_parameters())
                )
                all_new_named_buffers.append(
                    OrderedDict(updated_module.named_buffers())
                )
                all_new_named_modules.append(
                    OrderedDict(updated_module.named_modules())
                )

            num_orig_modules = len(all_orig_named_modules)
            num_new_modules = len(all_new_named_modules)
            if num_orig_modules != num_new_modules:
                raise AssertionError(
                    f"{func.__name__} should return the same number of modules as input modules\n"
                    f"Inputs: {num_orig_modules} modules\n"
                    f"Outputs: {num_new_modules} modules"
                )

            def check_fqn(orig_fqns: List[str], new_fqns: List[str], check_key: str):
                """Raise ``RuntimeError`` if the two FQN lists differ in content or order."""
                if orig_fqns == new_fqns:
                    return

                orig_fqn_set, new_fqn_set = set(orig_fqns), set(new_fqns)
                orig_only = orig_fqn_set - new_fqn_set
                new_only = new_fqn_set - orig_fqn_set
                if orig_only or new_only:
                    raise RuntimeError(
                        f"{check_key}"
                        "Composable distributed API implementations cannot modify FQNs.\n"
                        f"FQNs only in original: {orig_only}\n"
                        f"FQNs only in new: {new_only}"
                    )
                else:
                    # Same names, different order: the set differences are
                    # empty here, so report the full ordered lists instead.
                    raise RuntimeError(
                        f"{check_key}"
                        "Composable distributed API implementations cannot modify "
                        "the order of FQNs.\n"
                        f"Original FQNs: {orig_fqns}\n"
                        f"New FQNs: {new_fqns}"
                    )

            # Check order is preserved from the original: all parameters
            # first, then all buffers, then all modules.
            for all_orig_named, all_new_named, check_key in (
                (all_orig_named_params, all_new_named_params, "Checking parameters: "),
                (all_orig_named_buffers, all_new_named_buffers, "Checking buffers: "),
                (all_orig_named_modules, all_new_named_modules, "Checking modules: "),
            ):
                for orig_named, new_named in zip(all_orig_named, all_new_named):
                    check_fqn(list(orig_named), list(new_named), check_key)

            # TODO: verify that installed distributed paradigms are compatible with
            # each other.

            return updated

        def get_state(module: nn.Module) -> Optional[_State]:
            """Return the state ``func`` stored on ``module``, or ``None`` if absent.

            Note that this inserts an empty dict under ``STATE_KEY`` in
            ``module.__dict__`` when no state container exists yet.
            """
            return module.__dict__.setdefault(  # type: ignore[call-overload]
                STATE_KEY,
                {},  # TODO(@yhcharles): this is a temporary fix, need a better way
            ).get(
                func
            )  # type: ignore[call-overload]

        wrapper.state = get_state  # type: ignore[attr-defined]

        return wrapper

    return inner
216
217
def _get_registry(module: nn.Module) -> Optional[Dict[str, RegistryItem]]:
    r"""
    Look up the registry of composable APIs applied to ``module``.

    The registry is an ``OrderedDict`` keyed by API name. ``None`` is returned
    when no API has been applied to ``module``.
    """
    registry = getattr(module, REGISTRY_KEY, None)
    return registry
225