# mypy: allow-untyped-defs
from typing import Any, Dict, Iterable, List, Tuple

from torch.utils._pytree import (
    _dict_flatten,
    _dict_flatten_with_keys,
    _dict_unflatten,
    _list_flatten,
    _list_flatten_with_keys,
    _list_unflatten,
    Context,
    register_pytree_node,
)

from ._compatibility import compatibility


__all__ = ["immutable_list", "immutable_dict"]

_help_mutation = """\
If you are attempting to modify the kwargs or args of a torch.fx.Node object,
instead create a new copy of it and assign the copy to the node:
    new_args = ... # copy and mutate args
    node.args = new_args
"""


def _no_mutation(self, *args, **kwargs):
    """Replacement for every mutating method of the immutable containers.

    Accepts (and ignores) any arguments so it can stand in for methods with
    arbitrary signatures.

    Raises:
        NotImplementedError: always, with a hint on how to update node args.
    """
    raise NotImplementedError(
        f"'{type(self).__name__}' object does not support mutation. {_help_mutation}",
    )


def _create_immutable_container(base, mutable_functions):
    """Build a subclass of ``base`` whose listed methods all raise.

    Args:
        base: the builtin container type to subclass (e.g. ``list``, ``dict``).
        mutable_functions: names of the methods to replace with
            :func:`_no_mutation`.

    Returns:
        A new type named ``"immutable_" + base.__name__``.
    """
    container = type("immutable_" + base.__name__, (base,), {})
    for attr in mutable_functions:
        setattr(container, attr, _no_mutation)
    return container


immutable_list = _create_immutable_container(
    list,
    (
        "__delitem__",
        "__iadd__",
        "__imul__",
        "__setitem__",
        "append",
        "clear",
        "extend",
        "insert",
        "pop",
        "remove",
        "reverse",
        "sort",
    ),
)
# Pickle/copy by rebuilding from a tuple snapshot of the elements.
immutable_list.__reduce__ = lambda self: (immutable_list, (tuple(iter(self)),))
# List equality is order-sensitive, so an order-sensitive tuple hash is
# consistent with ``__eq__``. Elements must be hashable.
immutable_list.__hash__ = lambda self: hash(tuple(self))

compatibility(is_backward_compatible=True)(immutable_list)

immutable_dict = _create_immutable_container(
    dict,
    (
        "__delitem__",
        "__ior__",
        "__setitem__",
        "clear",
        "pop",
        "popitem",
        "setdefault",
        "update",
    ),
)
# Pickle/copy by rebuilding from a concrete tuple of (key, value) pairs rather
# than a live iterator over ``self.items()``; ``dict(pairs)`` accepts either,
# so pickles produced by earlier versions still load.
immutable_dict.__reduce__ = lambda self: (immutable_dict, (tuple(self.items()),))
# Dict equality ignores insertion order, so the hash must as well: equal
# immutable_dicts built in different key orders must hash the same. A
# frozenset of the items is order-independent. Values must be hashable.
immutable_dict.__hash__ = lambda self: hash(frozenset(self.items()))
compatibility(is_backward_compatible=True)(immutable_dict)


# Register immutable collections for PyTree operations
def _immutable_dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
    """Flatten an ``immutable_dict`` exactly like a plain dict."""
    return _dict_flatten(d)


def _immutable_dict_unflatten(
    values: Iterable[Any],
    context: Context,
) -> Dict[Any, Any]:
    """Rebuild a dict from pytree leaves and re-wrap it as ``immutable_dict``."""
    return immutable_dict(_dict_unflatten(values, context))


def _immutable_list_flatten(d: List[Any]) -> Tuple[List[Any], Context]:
    """Flatten an ``immutable_list`` exactly like a plain list."""
    return _list_flatten(d)


def _immutable_list_unflatten(
    values: Iterable[Any],
    context: Context,
) -> List[Any]:
    """Rebuild a list from pytree leaves and re-wrap it as ``immutable_list``."""
    return immutable_list(_list_unflatten(values, context))


register_pytree_node(
    immutable_dict,
    _immutable_dict_flatten,
    _immutable_dict_unflatten,
    serialized_type_name="torch.fx.immutable_collections.immutable_dict",
    flatten_with_keys_fn=_dict_flatten_with_keys,
)
register_pytree_node(
    immutable_list,
    _immutable_list_flatten,
    _immutable_list_unflatten,
    serialized_type_name="torch.fx.immutable_collections.immutable_list",
    flatten_with_keys_fn=_list_flatten_with_keys,
)