xref: /aosp_15_r20/external/pytorch/benchmarks/functional_autograd_benchmark/utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from collections import defaultdict
2from typing import Callable, Dict, List, Tuple, Union
3
4import torch
5from torch import nn, Tensor
6
7
# Type helpers
# Inputs handed to a benchmarked callable: either a single Tensor or a tuple of Tensors.
InputsType = Union[Tensor, Tuple[Tensor, ...]]
# A Getter takes in a device and returns a callable and the inputs to that callable
GetterReturnType = Tuple[Callable[..., Tensor], InputsType]
GetterType = Callable[[torch.device], GetterReturnType]
# V here refers to the v in either vjp, jvp, vhp or hvp.
# None is an accepted value in addition to a Tensor or a tuple of Tensors.
VType = Union[None, Tensor, Tuple[Tensor, ...]]
# Type used to store timing results. The first key is the model name, the second key
# is the task name, the result is a Tuple of: speedup, mean_before, var_before, mean_after, var_after.
# NOTE(review): the tuple length is not always five. The default header of
# to_markdown_table is ("model", "task", "mean", "var"), i.e. a (mean, var) pair per
# entry, which is why the value type is the variable-length Tuple[float, ...].
TimingResultType = Dict[str, Dict[str, Tuple[float, ...]]]
18
19
# Utilities to make nn.Module "functional"
# In particular, the goal is to be able to provide a function that takes the
# parameters as input and evaluates the nn.Module using fixed inputs.
def _del_nested_attr(obj: nn.Module, names: List[str]) -> None:
    """
    Delete the attribute reached by following ``names`` starting at ``obj``.

    Every entry but the last is resolved with ``getattr`` to find the owner;
    the last entry is the attribute removed from that owner.
    For example, ``_del_nested_attr(obj, ['conv', 'weight'])`` deletes
    ``obj.conv.weight``.
    """
    owner = obj
    # Walk down to the object that directly holds the attribute.
    for attr in names[:-1]:
        owner = getattr(owner, attr)
    delattr(owner, names[-1])
33
34
def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
    """
    Assign ``value`` to the attribute reached by following ``names`` from ``obj``.

    Every entry but the last is resolved with ``getattr`` to find the owner;
    the last entry is the attribute set on that owner.
    For example, ``_set_nested_attr(obj, ['conv', 'weight'], value)`` performs
    ``obj.conv.weight = value``.
    """
    owner = obj
    # Walk down to the object that directly holds the attribute.
    for attr in names[:-1]:
        owner = getattr(owner, attr)
    setattr(owner, names[-1], value)
45
46
def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
    """
    Strip the Parameters out of ``mod`` and hand them back to the caller.

    Returns a pair ``(params, names)``: ``params`` is a tuple of plain Tensors
    (detached copies of the original Parameters, each with ``requires_grad``
    switched on) and ``names`` lists the dotted attribute names they were
    stored under, in matching order.

    The model is modified in place: every attribute listed by
    ``mod.named_parameters()`` is deleted, so the weights must be put back
    with `load_weights` before the model can be used again.
    """
    # Snapshot the parameters before any attribute is removed.
    original = tuple(mod.parameters())
    names = [name for name, _ in mod.named_parameters()]

    # Remove all the parameters in the model
    for dotted in names:
        _del_nested_attr(mod, dotted.split("."))

    # Make params regular Tensors instead of nn.Parameter
    detached = tuple(param.detach().requires_grad_() for param in original)
    return detached, names
66
67
def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...]) -> None:
    """
    Put a set of weights back on `mod` so it can run a forward pass again.

    ``names`` and ``params`` are paired up positionally; each tensor is assigned
    to the dotted attribute path given by its name.
    Note that the `params` are regular Tensors (that can have history) and are
    assigned as-is, not wrapped in nn.Parameter. This means that
    mod.parameters() will still be empty after this call.
    """
    for dotted, tensor in zip(names, params):
        path = dotted.split(".")
        _set_nested_attr(mod, path, tensor)
76
77
78# Utilities to read/write markdown table-like content.
def to_markdown_table(
    res: TimingResultType, header: Union[Tuple[str, ...], None] = None
) -> str:
    """
    Render timing results as a markdown table.

    Args:
        res: nested mapping ``{model: {task: values}}``. Each ``values``
            sequence is written after the model and task cells of its row.
        header: column titles; defaults to ``("model", "task", "mean", "var")``.

    Returns:
        The table as a single string: a header row, a ``--`` separator row with
        one cell per header column, then one row per (model, task) pair. Every
        row ends with a newline.
    """
    if header is None:
        header = ("model", "task", "mean", "var")

    def format_row(cells) -> str:
        return f"| {' | '.join(str(cell) for cell in cells)} |\n"

    # Collect the rows and join once instead of growing a string in a loop.
    rows = [format_row(header), format_row(["--"] * len(header))]
    rows.extend(
        # Unpacking (rather than tuple concatenation) accepts any sequence of
        # values, e.g. a list, not only a tuple.
        format_row((model, task, *values))
        for model, tasks in res.items()
        for task, values in tasks.items()
    )
    return "".join(rows)
96
97
def from_markdown_table(data: str) -> TimingResultType:
    """
    Parse a markdown table back into timing results.

    The first two lines of ``data`` (header and separator) are skipped. Every
    remaining non-blank line must hold a model cell, a task cell and at least
    one numeric cell; the numeric cells are converted with ``float`` and stored
    as a tuple under ``result[model][task]``. A table with the
    ``model | task | mean | var`` layout therefore yields ``(mean, var)`` pairs.

    Raises:
        ValueError: if a row has fewer than three cells or a value cell is not
            a valid float.
    """
    rows = data.strip().split("\n")[2:]  # Ignore the header lines

    res: TimingResultType
    # Same container types as before: the outer mapping creates entries on demand.
    res = defaultdict(defaultdict)

    for row in rows:
        row = row.strip()
        if not row:
            # Tolerate blank lines inside the table.
            continue
        # Splitting on "|" leaves empty fragments for the leading and trailing
        # pipes; those are dropped before the cells are stripped.
        model, task, *values = (cell.strip() for cell in row.split("|") if cell)
        if not values:
            raise ValueError(f"expected at least one value column in row: {row!r}")
        res[model][task] = tuple(float(value) for value in values)

    return res
110
111
def check_for_functorch():
    """
    Report whether ``functorch`` can be imported.

    Returns True when ``import functorch`` succeeds and False when it raises
    ImportError.
    """
    try:
        import functorch  # noqa: F401
    except ImportError:
        return False
    return True
119