xref: /aosp_15_r20/external/pytorch/torch/fx/immutable_collections.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from typing import Any, Dict, Iterable, List, Tuple
3
4from torch.utils._pytree import (
5    _dict_flatten,
6    _dict_flatten_with_keys,
7    _dict_unflatten,
8    _list_flatten,
9    _list_flatten_with_keys,
10    _list_unflatten,
11    Context,
12    register_pytree_node,
13)
14
15from ._compatibility import compatibility
16
17
# Public names exported by this module: the two immutable container types.
__all__ = ["immutable_list", "immutable_dict"]

# Guidance text appended to the error raised by `_no_mutation`; it tells the
# user to assign a modified copy instead of mutating in place.
_help_mutation = """\
If you are attempting to modify the kwargs or args of a torch.fx.Node object,
instead create a new copy of it and assign the copy to the node:
    new_args = ... # copy and mutate args
    node.args = new_args
"""
26
27
def _no_mutation(self, *args, **kwargs):
    """Reject any attempt to mutate an immutable container.

    Installed in place of the mutating methods of the immutable container
    classes. All positional and keyword arguments are accepted and ignored.

    Raises:
        NotImplementedError: always; the message names the concrete type of
            ``self`` and appends the ``_help_mutation`` guidance text.
    """
    type_name = type(self).__name__
    message = f"'{type_name}' object does not support mutation. {_help_mutation}"
    raise NotImplementedError(message)
32
33
34def _create_immutable_container(base, mutable_functions):
35    container = type("immutable_" + base.__name__, (base,), {})
36    for attr in mutable_functions:
37        setattr(container, attr, _no_mutation)
38    return container
39
40
# ``list`` subclass in which every method named below raises via
# ``_no_mutation`` instead of modifying the list.
immutable_list = _create_immutable_container(
    list,
    (
        # Item assignment/deletion and in-place operators.
        "__setitem__",
        "__delitem__",
        "__iadd__",
        "__imul__",
        # Named mutating methods.
        "append",
        "extend",
        "insert",
        "pop",
        "remove",
        "clear",
        "reverse",
        "sort",
    ),
)
# Reconstruct from a tuple snapshot of the elements (used by pickle/copy).
immutable_list.__reduce__ = lambda self: (immutable_list, (tuple(self),))
# Hash like the equivalent tuple of elements.
immutable_list.__hash__ = lambda self: hash(tuple(self))

compatibility(is_backward_compatible=True)(immutable_list)
62
# ``dict`` subclass in which every method named below raises via
# ``_no_mutation`` instead of modifying the mapping.
immutable_dict = _create_immutable_container(
    dict,
    (
        "__delitem__",
        "__ior__",
        "__setitem__",
        "clear",
        "pop",
        "popitem",
        "setdefault",
        "update",
    ),
)
# Reconstruct from a concrete tuple of (key, value) pairs, mirroring
# ``immutable_list.__reduce__``, rather than handing a live iterator to the
# reconstruction machinery.
immutable_dict.__reduce__ = lambda self: (immutable_dict, (tuple(self.items()),))
# Dict equality ignores insertion order, so the hash must too: hashing a
# frozenset of the items guarantees that equal dicts hash equal (a tuple of
# items would differ for the same mapping built in a different order).
immutable_dict.__hash__ = lambda self: hash(frozenset(self.items()))
compatibility(is_backward_compatible=True)(immutable_dict)
79
80
81# Register immutable collections for PyTree operations
82def _immutable_dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
83    return _dict_flatten(d)
84
85
def _immutable_dict_unflatten(
    values: Iterable[Any],
    context: Context,
) -> Dict[Any, Any]:
    """Rebuild an ``immutable_dict`` from flattened pytree data.

    Args:
        values: The leaf values to place back into the dictionary.
        context: The context produced when the dictionary was flattened.

    Returns:
        An ``immutable_dict`` wrapping the result of ``_dict_unflatten``.
    """
    rebuilt = _dict_unflatten(values, context)
    return immutable_dict(rebuilt)
91
92
93def _immutable_list_flatten(d: List[Any]) -> Tuple[List[Any], Context]:
94    return _list_flatten(d)
95
96
def _immutable_list_unflatten(
    values: Iterable[Any],
    context: Context,
) -> List[Any]:
    """Rebuild an ``immutable_list`` from flattened pytree data.

    Args:
        values: The leaf values to place back into the list.
        context: The context produced when the list was flattened.

    Returns:
        An ``immutable_list`` wrapping the result of ``_list_unflatten``.
    """
    rebuilt = _list_unflatten(values, context)
    return immutable_list(rebuilt)
102
103
# Import-time side effect: register both immutable container types with
# pytree, each with its own flatten/unflatten pair defined above.
# immutable_dict reuses the plain-dict keyed-flatten helper.
register_pytree_node(
    immutable_dict,
    _immutable_dict_flatten,
    _immutable_dict_unflatten,
    serialized_type_name="torch.fx.immutable_collections.immutable_dict",
    flatten_with_keys_fn=_dict_flatten_with_keys,
)
# immutable_list reuses the plain-list keyed-flatten helper.
register_pytree_node(
    immutable_list,
    _immutable_list_flatten,
    _immutable_list_unflatten,
    serialized_type_name="torch.fx.immutable_collections.immutable_list",
    flatten_with_keys_fn=_list_flatten_with_keys,
)
118