# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, Iterable, List, Tuple

import torch


# Sentinel meaning "no such attribute". It must be distinct from ``None``
# because ``None`` is an accepted value for a tensor slot. Passing it to
# ``swap_tensor`` deletes the attribute, and ``swap_tensor`` returns it when
# the attribute did not exist (only possible with ``allow_missing=True``).
_MISSING: torch.Tensor = object()  # type: ignore[assignment]


def set_tensor(module: "torch.nn.Module", name: str, tensor: torch.Tensor) -> None:
    """
    Set the direct (non-nested) attribute ``name`` of ``module`` to ``tensor``.

    If ``name`` is an existing parameter or buffer slot, the entry in
    ``module._parameters`` / ``module._buffers`` is overwritten directly, so a
    plain tensor (or ``None``) can occupy a parameter slot without going
    through ``nn.Module.__setattr__``. Any other name is assigned with
    ``setattr``.

    Raises:
        TypeError: if ``module`` is not an ``nn.Module`` or ``tensor`` is
            neither a tensor nor ``None``.
        KeyError: if ``name`` is empty or contains ``"."``.
    """
    if not isinstance(module, torch.nn.Module):
        raise TypeError(f"{module} is not an instance of torch.nn.Module")
    if not isinstance(tensor, torch.Tensor) and tensor is not None:
        raise TypeError(f"{tensor} is not an instance of torch.Tensor")
    if "." in name:
        raise KeyError('tensor name can\'t contain "."')
    if name == "":
        raise KeyError('tensor name can\'t be empty string ""')
    if name in module._parameters:
        module._parameters[name] = tensor  # type: ignore[assignment]
    elif name in module._buffers:
        module._buffers[name] = tensor
    else:
        setattr(module, name, tensor)


def swap_tensor(
    module: "torch.nn.Module",
    name: str,
    tensor: torch.Tensor,
    allow_missing: bool = False,
) -> torch.Tensor:
    """
    Replace the direct attribute ``name`` of ``module`` and return its old value.

    Parameter and buffer slots are updated in place in ``module._parameters`` /
    ``module._buffers``. Any other attribute is updated with ``setattr``.
    Passing ``_MISSING`` as ``tensor`` deletes the attribute instead of
    assigning it.

    Args:
        module: module that owns the attribute.
        name: attribute name; must not be empty or contain ``"."``.
        tensor: new value: a tensor, ``None``, or ``_MISSING`` to delete.
        allow_missing: if ``True``, a non-existent plain attribute is not an
            error, and ``_MISSING`` is returned as the old value.

    Returns:
        The previous value of the attribute, or ``_MISSING`` if it did not exist.

    Raises:
        TypeError: if ``module`` is not an ``nn.Module``; if ``tensor`` is not
            a tensor, ``None`` or ``_MISSING``; or if an existing plain
            attribute holds something other than a tensor or ``None``.
        KeyError: if ``name`` is empty or contains ``"."``.
        AttributeError: if the attribute does not exist and ``allow_missing``
            is ``False``.
    """
    if not isinstance(module, torch.nn.Module):
        raise TypeError(f"{module} is not an instance of torch.nn.Module")
    if (
        tensor is not _MISSING
        and not isinstance(tensor, torch.Tensor)
        and tensor is not None
    ):
        raise TypeError(f"{tensor} is not an instance of torch.Tensor")
    if "." in name:
        raise KeyError('tensor name can\'t contain "."')
    if name == "":
        raise KeyError('tensor name can\'t be empty string ""')

    orig_tensor: torch.Tensor
    if name in module._parameters:
        orig_tensor = module._parameters[name]  # type: ignore[assignment]
        if tensor is not _MISSING:
            module._parameters[name] = tensor  # type: ignore[assignment]
        else:
            del module._parameters[name]
    elif name in module._buffers:
        orig_tensor = module._buffers[name]  # type: ignore[assignment]
        if tensor is not _MISSING:
            module._buffers[name] = tensor
        else:
            del module._buffers[name]
    else:
        if hasattr(module, name):
            orig_tensor = getattr(module, name)
        else:
            if not allow_missing:
                raise AttributeError(f"{module._get_name()} has no attribute `{name}`")
            orig_tensor = _MISSING
        # Refuse to clobber a plain attribute that is not tensor-like
        # (e.g. a submodule or an arbitrary Python object).
        if (
            orig_tensor is not _MISSING
            and not isinstance(orig_tensor, torch.Tensor)
            and orig_tensor is not None
        ):
            raise TypeError(
                f"attribute `{name}`: {orig_tensor} is not an instance of torch.Tensor"
            )
        if tensor is not _MISSING:
            setattr(module, name, tensor)
        elif hasattr(module, name):
            delattr(module, name)
    return orig_tensor


def swap_submodule(
    module: "torch.nn.Module",
    name: str,
    submodule: "torch.nn.Module",
) -> "torch.nn.Module":
    """
    Replace the registered direct child ``name`` of ``module`` and return the old one.

    The entry in ``module._modules`` is overwritten directly.

    Raises:
        TypeError: if ``module`` or ``submodule`` is not an ``nn.Module``, or
            if the current child registered under ``name`` is not one
            (e.g. it is ``None``).
        KeyError: if ``name`` is empty, contains ``"."``, or is not a
            registered child of ``module``.
    """
    if not isinstance(module, torch.nn.Module):
        raise TypeError(f"{module} is not an instance of torch.nn.Module")
    if not isinstance(submodule, torch.nn.Module):
        raise TypeError(f"{submodule} is not an instance of torch.nn.Module")
    if "." in name:
        raise KeyError('submodule name can\'t contain "."')
    if name == "":
        raise KeyError('submodule name can\'t be empty string ""')
    if name not in module._modules:
        raise KeyError(f"submodule {name} does not exist")

    orig_submodule = module._modules[name]
    if not isinstance(orig_submodule, torch.nn.Module):
        raise TypeError(f"{name} attribute is not an instance of torch.nn.Module")
    module._modules[name] = submodule
    return orig_submodule


class NamedMemberAccessor:
    """
    A class that provides a way to access the submodules and parameters/buffers of a module.

    It provides caching mechanism to speed up submodule lookups.
    This is useful for functional programming to manipulate the module state.

    The cache (``memo``) maps dotted submodule paths to the submodule objects
    found at lookup time. ``swap_submodule`` keeps it in sync. Changes made to
    the module tree outside this accessor are not tracked.
    """

    def __init__(self, module: "torch.nn.Module") -> None:
        self.module = module
        # Dotted path -> submodule, filled lazily by ``get_submodule``.
        self.memo: Dict[str, torch.nn.Module] = {}

    # Nested attribute access

    def get_submodule(self, name: str) -> "torch.nn.Module":
        """
        Return the submodule specified by the given path.

        For example, to get the submodule mod.layer1.conv1,
        use accessor.get_submodule("layer1.conv1")

        Compare to mod.get_submodule("layer1.conv1"), this method will cache the
        intermediate submodule access to speed up future lookups.

        An empty ``name`` returns the root module itself.

        Raises:
            AttributeError: if a path component does not exist.
            TypeError: if the object at ``name`` is not an ``nn.Module``.
        """
        if not name:
            return self.module

        if name in self.memo:
            return self.memo[name]
        else:
            prefix, dot, attr = name.rpartition(".")
            # Resolving the parent recursively also caches every ancestor path.
            if dot:
                module = self.get_submodule(prefix)
            else:
                module = self.module
            try:
                submodule = getattr(module, attr)
            except AttributeError as ex:
                raise AttributeError(
                    f"{module._get_name()} has no attribute `{attr}`"
                ) from ex
            if not isinstance(submodule, torch.nn.Module):
                raise TypeError(  # noqa: B904
                    f"submodule `{name}`: {submodule} is not an instance of torch.nn.Module"
                )
            self.memo[name] = submodule
            return submodule

    def _invalidate_memo(self, path: str) -> None:
        """Drop the cached entry for ``path`` and every cached path nested under it."""
        descendant_prefix = path + "."
        stale_keys = [
            key
            for key in self.memo
            if key == path or key.startswith(descendant_prefix)
        ]
        for key in stale_keys:
            del self.memo[key]

    def swap_submodule(self, path: str, value: "torch.nn.Module") -> "torch.nn.Module":
        """
        Swap the submodule specified by the given ``path`` to ``value``.

        For example, to swap the attribute mod.layer1.conv1 use
        ``accessor.swap_submodule("layer1.conv1", conv2)``.

        Returns the submodule that was previously registered at ``path``. Cached
        lookups for ``path`` and any path below it are invalidated, so later
        accessor calls resolve against ``value`` instead of the old submodule.
        """
        prefix, _, attr = path.rpartition(".")
        orig_submodule = swap_submodule(self.get_submodule(prefix), attr, value)
        # Without this, ``get_submodule(path)`` (and any cached descendant such
        # as ``path + ".child"``) would keep returning the detached old module.
        self._invalidate_memo(path)
        return orig_submodule

    def get_tensor(self, name: str) -> torch.Tensor:
        """
        Get the tensor specified by the given path to value.

        For example, to get the attribute mod.layer1.conv1.weight,
        use accessor.get_tensor('layer1.conv1.weight')

        Compare to mod.get_parameter("layer1.conv1.weight"), this method will
        cache the intermediate submodule access to speed up future lookups.

        ``None`` is returned as-is for a tensor slot that holds ``None``.

        Raises:
            AttributeError: if the path does not exist.
            TypeError: if the attribute is neither a tensor nor ``None``.
        """
        prefix, _, attr = name.rpartition(".")
        submodule = self.get_submodule(prefix)
        try:
            tensor = getattr(submodule, attr)
        except AttributeError as ex:
            raise AttributeError(
                f"{submodule._get_name()} has no attribute `{name}`"
            ) from ex
        if not isinstance(tensor, torch.Tensor) and tensor is not None:
            raise TypeError(f"{tensor} is not an instance of torch.Tensor")
        return tensor  # type: ignore[return-value]

    def set_tensor(self, name: str, value: torch.Tensor) -> None:
        """
        Set the attribute specified by the given path to value.

        For example, to set the attribute mod.layer1.conv1.weight,
        use accessor.set_tensor("layer1.conv1.weight", value)
        """
        prefix, _, attr = name.rpartition(".")
        set_tensor(self.get_submodule(prefix), attr, value)

    def del_tensor(self, name: str) -> None:
        """
        Delete the attribute specified by the given path.

        For example, to delete the attribute mod.layer1.conv1.weight,
        use accessor.del_tensor("layer1.conv1.weight")

        Raises:
            AttributeError: if the path does not exist.
        """
        prefix, _, attr = name.rpartition(".")
        submodule = self.get_submodule(prefix)
        try:
            delattr(submodule, attr)
        except AttributeError as ex:
            raise AttributeError(
                f"{submodule._get_name()} has no attribute `{name}`"
            ) from ex

    def swap_tensor(
        self, name: str, value: torch.Tensor, allow_missing: bool = False
    ) -> torch.Tensor:
        """
        Swap the attribute specified by the given path to value.

        For example, to swap the attribute mod.layer1.conv1.weight,
        use accessor.swap_tensor("layer1.conv1.weight", value)

        Returns the previous value, or the module-private ``_MISSING`` sentinel
        when the attribute did not exist and ``allow_missing`` is ``True``.
        """
        prefix, _, attr = name.rpartition(".")
        return swap_tensor(
            self.get_submodule(prefix), attr, value, allow_missing=allow_missing
        )

    # Batched operations

    def get_tensors(self, names: Iterable[str]) -> List[torch.Tensor]:
        """
        Get the tensors specified by the given paths.

        For example, to get the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.get_tensors(["layer1.conv1.weight",
        "layer1.conv1.bias"])
        """
        return [self.get_tensor(name) for name in names]

    def set_tensors(self, names: Iterable[str], values: Iterable[torch.Tensor]) -> None:
        """
        Set the attributes specified by the given paths to values.

        For example, to set the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.set_tensors(["layer1.conv1.weight",
        "layer1.conv1.bias"], [weight, bias])
        """
        # Materialize one-shot iterables so their lengths can be compared.
        if not isinstance(names, (list, tuple)):
            names = list(names)
        if not isinstance(values, (list, tuple)):
            values = list(values)
        assert len(names) == len(values), "names and values must have the same length"

        for name, value in zip(names, values):
            self.set_tensor(name, value)

    def set_tensors_dict(self, named_tensors: Dict[str, torch.Tensor]) -> None:
        """
        Set the attributes specified by the given paths to values.

        For example, to set the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.set_tensors_dict({
            "layer1.conv1.weight": weight,
            "layer1.conv1.bias": bias,
        })
        """
        for name, value in named_tensors.items():
            self.set_tensor(name, value)

    def del_tensors(self, names: Iterable[str]) -> None:
        """
        Delete the attributes specified by the given paths.

        For example, to delete the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.del_tensors(["layer1.conv1.weight",
        "layer1.conv1.bias"])
        """
        for name in names:
            self.del_tensor(name)

    def swap_tensors(
        self,
        names: Iterable[str],
        values: Iterable[torch.Tensor],
        allow_missing: bool = False,
    ) -> List[torch.Tensor]:
        """
        Swap the attributes specified by the given paths to values.

        For example, to swap the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.swap_tensors(["layer1.conv1.weight",
        "layer1.conv1.bias"], [weight, bias])

        Returns the previous values in the same order as ``names``. Unlike
        ``swap_tensors_dict``, a failure part-way through is not rolled back.
        """
        # Materialize one-shot iterables so their lengths can be compared.
        if not isinstance(names, (list, tuple)):
            names = list(names)
        if not isinstance(values, (list, tuple)):
            values = list(values)
        assert len(names) == len(values), "names and values must have the same length"

        return [
            self.swap_tensor(name, value, allow_missing=allow_missing)
            for name, value in zip(names, values)
        ]

    def swap_tensors_dict(
        self, named_tensors: Dict[str, torch.Tensor], allow_missing: bool = False
    ) -> Tuple[Dict[str, torch.Tensor], List[str]]:
        """
        Swap the attributes specified by the given paths to values.

        For example, to swap the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.swap_tensors_dict({
            "layer1.conv1.weight": weight,
            "layer1.conv1.bias": bias,
        })

        Returns:
            A pair ``(orig_named_tensors, missing_keys)``:
            ``orig_named_tensors`` maps each swapped path to its previous value,
            which is the module-private ``_MISSING`` sentinel for paths that did
            not exist. ``missing_keys`` lists those paths.

        Raises:
            RuntimeError: if any path did not exist and ``allow_missing`` is
                ``False``; all swaps are undone first.
            Any exception raised while swapping is re-raised after the swaps
            already performed are undone.
        """
        orig_named_tensors = {}
        missing_keys = []
        try:
            for name, tensor in named_tensors.items():
                orig_tensor = self.swap_tensor(name, tensor, allow_missing=True)
                if orig_tensor is _MISSING:
                    missing_keys.append(name)
                orig_named_tensors[name] = orig_tensor
        except Exception:
            # Swap back if any exception occurs
            for name, orig_tensor in orig_named_tensors.items():
                self.swap_tensor(name, orig_tensor, allow_missing=True)
            raise
        if missing_keys and not allow_missing:
            # Swap back if any key is missing when allow_missing is False
            # (restoring ``_MISSING`` deletes attributes that were newly created).
            for name, orig_tensor in orig_named_tensors.items():
                self.swap_tensor(name, orig_tensor, allow_missing=True)
            raise RuntimeError(f"Missing key(s): {', '.join(map(repr, missing_keys))}.")
        return orig_named_tensors, missing_keys

    def check_keys(self, keys: Iterable[str]) -> Tuple[List[str], List[str]]:
        """
        Check that the given keys are valid.

        Returns ``(missing_keys, unexpected_keys)`` as sorted lists:

        - missing keys are parameter/buffer names of the module absent from ``keys``;
        - unexpected keys are entries of ``keys`` that name no parameter/buffer.

        Shared (tied) tensors are counted under every name they are registered by.
        """
        keys = set(keys)
        valid_keys = {name for name, _ in self.named_tensors(remove_duplicate=False)}
        missing_keys = valid_keys - keys
        unexpected_keys = keys - valid_keys
        return sorted(missing_keys), sorted(unexpected_keys)

    # Shortcut methods

    def named_parameters(
        self,
        remove_duplicate: bool = True,
    ) -> Iterable[Tuple[str, torch.Tensor]]:
        """Iterate over all the parameters in the module."""
        yield from self.module.named_parameters(remove_duplicate=remove_duplicate)

    def named_buffers(
        self,
        remove_duplicate: bool = True,
    ) -> Iterable[Tuple[str, torch.Tensor]]:
        """Iterate over all the buffers in the module."""
        yield from self.module.named_buffers(remove_duplicate=remove_duplicate)

    def named_tensors(
        self,
        remove_duplicate: bool = True,
    ) -> Iterable[Tuple[str, torch.Tensor]]:
        """
        Iterate over all the tensors in the module: parameters first, then buffers.

        ``remove_duplicate`` is applied to the parameter and buffer iterations
        independently.
        """
        yield from self.module.named_parameters(remove_duplicate=remove_duplicate)
        yield from self.module.named_buffers(remove_duplicate=remove_duplicate)

    def named_modules(
        self,
        remove_duplicate: bool = True,
    ) -> Iterable[Tuple[str, "torch.nn.Module"]]:
        """Iterate over all the modules in the module."""
        yield from self.module.named_modules(remove_duplicate=remove_duplicate)