xref: /aosp_15_r20/external/pytorch/benchmarks/dynamo/torchao_backend.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from typing import Any, Callable
2
3import torch
4
5
def setup_baseline():
    """Apply the torch._dynamo / inductor configuration used as the benchmarking baseline.

    Mutates global config only; returns nothing.
    """
    # Imported lazily so merely importing this module does not require torchao.
    from torchao.quantization.utils import recommended_inductor_config_setter

    # Applies torchao's recommended inductor settings — presumably tuned for
    # quantized kernels; exact flags live in torchao (TODO confirm against torchao).
    recommended_inductor_config_setter()
    # Disable automatic dynamic shapes so every run compiles with static shapes.
    torch._dynamo.config.automatic_dynamic_shapes = False
    # Very large cache limit so repeated recompilations during benchmarking
    # do not trip dynamo's recompile guard.
    torch._dynamo.config.cache_size_limit = 10000
12
13
def torchao_optimize_ctx(quantization: str):
    """Build a wrapper that quantizes a model with torchao before running it.

    Args:
        quantization: Quantization scheme selector. Recognized values are
            ``"int8dynamic"``, ``"int8weightonly"``, ``"int4weightonly"``,
            and ``"autoquant"``; any other value falls through to only
            unwrapping tensor subclasses.

    Returns:
        A function taking ``model_iter_fn`` and returning a callable
        ``(module, example_inputs)`` that quantizes ``module`` exactly once
        (marking it via a ``_quantized`` attribute) and then invokes
        ``model_iter_fn(module, example_inputs)``.

    Raises:
        Exception: (via the inner callable) when ``quantization == "autoquant"``
            and torchao found no autoquantizable layers in the model.
    """
    from torchao.quantization.quant_api import (
        autoquant,
        int4_weight_only,
        int8_dynamic_activation_int8_weight,
        int8_weight_only,
        quantize_,
    )
    from torchao.utils import unwrap_tensor_subclass

    def inner(model_iter_fn: Callable):
        def _torchao_apply(module: torch.nn.Module, example_inputs: Any):
            # Quantize only on first use; the sentinel is stashed on the module.
            if getattr(module, "_quantized", None) is None:
                if quantization == "int8dynamic":
                    quantize_(
                        module,
                        int8_dynamic_activation_int8_weight(),
                        set_inductor_config=False,
                    )
                elif quantization == "int8weightonly":
                    quantize_(module, int8_weight_only(), set_inductor_config=False)
                elif quantization == "int4weightonly":
                    quantize_(module, int4_weight_only(), set_inductor_config=False)
                # NOTE: deliberately `if`, not `elif` — the `else` below must run
                # for every non-autoquant scheme (including the int8/int4 paths
                # above) so their tensor subclasses get unwrapped.
                if quantization == "autoquant":
                    autoquant(module, error_on_unseen=False, set_inductor_config=False)
                    # Run one forward pass so autoquant can profile the layers
                    # and populate its kernel cache.
                    if isinstance(example_inputs, dict):
                        module(**example_inputs)
                    else:
                        module(*example_inputs)
                    from torchao.quantization.autoquant import AUTOQUANT_CACHE

                    if len(AUTOQUANT_CACHE) == 0:
                        # Fix: the original implicitly concatenated the two
                        # literals with no separator, yielding the garbled
                        # message "NotAutoquantizableFound no ...".
                        raise Exception(  # noqa: TRY002
                            "NotAutoquantizable: "
                            f"Found no autoquantizable layers in model {type(module)}, stopping autoquantized run"
                        )
                else:
                    # Non-autoquant schemes: replace tensor subclasses with
                    # plain tensors so the model can be compiled/exported.
                    unwrap_tensor_subclass(module)
                setattr(module, "_quantized", True)  # noqa: B010
            model_iter_fn(module, example_inputs)

        return _torchao_apply

    return inner
58