from typing import Any, Callable

import torch


def setup_baseline():
    # Apply torchao's recommended inductor settings and pin dynamo's shape and
    # cache behavior so benchmark runs are comparable.
    from torchao.quantization.utils import recommended_inductor_config_setter

    recommended_inductor_config_setter()
    torch._dynamo.config.automatic_dynamic_shapes = False
    torch._dynamo.config.cache_size_limit = 10000


def torchao_optimize_ctx(quantization: str):
    from torchao.quantization.quant_api import (
        autoquant,
        int4_weight_only,
        int8_dynamic_activation_int8_weight,
        int8_weight_only,
        quantize_,
    )
    from torchao.utils import unwrap_tensor_subclass

    def inner(model_iter_fn: Callable):
        def _torchao_apply(module: torch.nn.Module, example_inputs: Any):
            # Quantize the module once, on the first call, then run the
            # benchmark iteration function on every call.
            if getattr(module, "_quantized", None) is None:
                if quantization == "int8dynamic":
                    quantize_(
                        module,
                        int8_dynamic_activation_int8_weight(),
                        set_inductor_config=False,
                    )
                elif quantization == "int8weightonly":
                    quantize_(module, int8_weight_only(), set_inductor_config=False)
                elif quantization == "int4weightonly":
                    quantize_(module, int4_weight_only(), set_inductor_config=False)
                if quantization == "autoquant":
                    autoquant(module, error_on_unseen=False, set_inductor_config=False)
                    # Run one calibration pass so autoquant can observe shapes.
                    if isinstance(example_inputs, dict):
                        module(**example_inputs)
                    else:
                        module(*example_inputs)
                    from torchao.quantization.autoquant import AUTOQUANT_CACHE

                    if len(AUTOQUANT_CACHE) == 0:
                        raise Exception(  # noqa: TRY002
                            "NotAutoquantizable"
                            f"Found no autoquantizable layers in model {type(module)}, stopping autoquantized run"
                        )
                else:
                    unwrap_tensor_subclass(module)
                setattr(module, "_quantized", True)  # noqa: B010
            model_iter_fn(module, example_inputs)

        return _torchao_apply

    return inner
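

# A minimal usage sketch (an editorial addition, not part of the upstream
# benchmark harness): it shows how a caller might wire the two helpers above
# together. The toy model, example inputs, and `run_once` iteration function
# below are hypothetical stand-ins for whatever the real harness provides.
if __name__ == "__main__":
    setup_baseline()

    # Hypothetical eval-mode model and example inputs for illustration only.
    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).eval()
    example_inputs = (torch.randn(8, 64),)

    def run_once(mod: torch.nn.Module, inputs: Any) -> Any:
        # Stand-in for the benchmark's model_iter_fn: a single forward pass.
        with torch.no_grad():
            return mod(*inputs)

    # Wrapping run_once: the first call quantizes the module (int8 weight-only
    # here) and marks it with `_quantized`; subsequent calls just run run_once.
    apply_fn = torchao_optimize_ctx("int8weightonly")(run_once)
    apply_fn(model, example_inputs)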