#pragma once

#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/quantization/quantization_type.h>

namespace torch {
namespace jit {

/** Replicate quantize nodes into prim::If blocks, so that we can match
 *  quantization patterns inside prim::If blocks.
 */
TORCH_API void ReplicateQuant(std::shared_ptr<Graph>& graph);

/** Replicate a dequantize node once per use, so that each use can be
 *  matched against quantization patterns independently.
 */
TORCH_API void ReplicateDeQuant(std::shared_ptr<Graph>& graph);

/** \brief Insert quantize - dequantize calls for the Tensors
 *  that are observed in the insert_observers pass.
 *
 *  For each Tensor that is observed, get the observer module and call
 *  calculate_qparam on the observer module to get quantization parameters,
 *  then add quantize - int_repr - dequantize function calls using these
 *  parameters. We also have special handling for quantizing "bias" right now.
 *
 *  \param module the input module
 *  \param method_name the method we want to insert quantization calls for
 *  \param inplace presumably whether to mutate \p module directly instead of
 *         working on a clone — NOTE(review): inferred from the name; confirm
 *         against the pass implementation
 *  \param debug presumably preserves extra debug information/ops in the
 *         transformed graph — NOTE(review): confirm against the
 *         pass implementation
 *  \param quant_type the kind of quantization to perform (defaults to
 *         QuantType::STATIC); see quantization_type.h for the variants
 *  \return the transformed module
 */
TORCH_API Module InsertQuantDeQuant(
    Module& module,
    const std::string& method_name,
    bool inplace,
    bool debug,
    QuantType quant_type = QuantType::STATIC);

/** \brief Variant of InsertQuantDeQuant for on-device post-training
 *  quantization (PTQ).
 *
 *  Takes the same parameters as InsertQuantDeQuant; see that declaration
 *  for their meaning.
 */
TORCH_API Module InsertQuantDeQuantOnDevicePTQ(
    Module& module,
    const std::string& method_name,
    bool inplace,
    bool debug,
    QuantType quant_type = QuantType::STATIC);

} // namespace jit
} // namespace torch