from collections import defaultdict
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import nn, Tensor


# Type helpers
InputsType = Union[Tensor, Tuple[Tensor, ...]]
# A Getter takes in a device and returns a callable and the inputs to that callable
GetterReturnType = Tuple[Callable[..., Tensor], InputsType]
GetterType = Callable[[torch.device], GetterReturnType]
# V here refers to the v in either vjp, jvp, vhp or hvp
VType = Union[None, Tensor, Tuple[Tensor, ...]]
# Type used to store timing results. The first key is the model name, the second key
# is the task name, the result is a Tuple of: speedup, mean_before, var_before, mean_after, var_after.
TimingResultType = Dict[str, Dict[str, Tuple[float, ...]]]


# Utilities to make an nn.Module "functional".
# In particular, the goal is to be able to provide a function that takes the
# parameters as input and evaluates the nn.Module using fixed inputs.
def _del_nested_attr(obj: nn.Module, names: List[str]) -> None:
    """
    Deletes the attribute specified by the given list of names.
    For example, to delete the attribute obj.conv.weight,
    use _del_nested_attr(obj, ['conv', 'weight']).
    """
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        _del_nested_attr(getattr(obj, names[0]), names[1:])


def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
    """
    Sets the attribute specified by the given list of names to value.
    For example, to set the attribute obj.conv.weight,
    use _set_nested_attr(obj, ['conv', 'weight'], value).
    """
    if len(names) == 1:
        setattr(obj, names[0], value)
    else:
        _set_nested_attr(getattr(obj, names[0]), names[1:], value)


def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
    """
    This function removes all the Parameters from the model and
    returns them as a tuple, along with their original attribute names.
    The weights must be re-loaded with `load_weights` before the model
    can be used again.
    Note that this function modifies the model in place and, after this
    call, mod.parameters() will be empty.
    """
    orig_params = tuple(mod.parameters())
    # Remove all the parameters in the model
    names = []
    for name, p in list(mod.named_parameters()):
        _del_nested_attr(mod, name.split("."))
        names.append(name)

    # Make params regular Tensors instead of nn.Parameter
    params = tuple(p.detach().requires_grad_() for p in orig_params)
    return params, names


def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...]) -> None:
    """
    Reload a set of weights so that `mod` can be used again to perform a forward pass.
    Note that the `params` are regular Tensors (that can have history) and so are left
    as Tensors. This means that mod.parameters() will still be empty after this call.
    """
    for name, p in zip(names, params):
        _set_nested_attr(mod, name.split("."), p)
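

# Illustrative sketch (not part of the original utilities): how extract_weights and
# load_weights are typically combined to evaluate a module as a pure function of its
# parameters. The module, input shapes and helper name below are arbitrary assumptions.
def _example_functional_forward() -> None:
    mod = nn.Linear(3, 2)
    inp = torch.rand(4, 3)

    # Strip the Parameters out of the module; they come back as plain Tensors that
    # still require grad, plus the dotted attribute names needed to restore them.
    params, names = extract_weights(mod)

    def forward(*new_params: Tensor) -> Tensor:
        # Re-attach the (possibly history-carrying) tensors before the forward pass.
        load_weights(mod, names, new_params)
        return mod(inp)

    out = forward(*params)
    # Gradients flow back to `params` since they are the leaves used by `forward`.
    grads = torch.autograd.grad(out.sum(), params)
    assert len(grads) == len(params)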


# Utilities to read/write markdown table-like content.
def to_markdown_table(
    res: TimingResultType, header: Optional[Tuple[str, ...]] = None
) -> str:
    if header is None:
        header = ("model", "task", "mean", "var")
    out = ""

    def write_line(*args):
        nonlocal out
        out += f"| {' | '.join(str(a) for a in args)} |\n"

    # Make it a markdown table
    write_line(*header)
    write_line(*["--"] * len(header))
    for model, tasks in res.items():
        for task, line in tasks.items():
            write_line(*(model, task) + line)

    return out


def from_markdown_table(data: str) -> TimingResultType:
    out = data.strip().split("\n")
    out = out[2:]  # Ignore the header and separator lines

    res: TimingResultType
    res = defaultdict(defaultdict)

    for line in out:
        model, task, mean, var = (f.strip() for f in line.strip().split("|") if f)
        res[model][task] = (float(mean), float(var))

    return res


def check_for_functorch() -> bool:
    try:
        import functorch  # noqa: F401

        return True
    except ImportError:
        return False
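

# Illustrative sketch (not part of the original utilities): round-tripping a small
# timing-result dict through the markdown helpers above. The model/task names and
# numbers are made up, and the default (model, task, mean, var) header is assumed.
def _example_markdown_roundtrip() -> None:
    res: TimingResultType = defaultdict(defaultdict)
    res["resnet18"]["vjp"] = (0.52, 1e-4)
    res["resnet18"]["hvp"] = (1.37, 3e-4)

    table = to_markdown_table(res)
    # from_markdown_table only parses the default 4-column layout, so the default
    # header must be used for the round trip to be lossless.
    recovered = from_markdown_table(table)
    assert recovered["resnet18"]["vjp"] == (0.52, 1e-4)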