xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/freeze_module.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
/** \brief This file declares the API for freezing TorchScript modules.
 *
 * This API has Python bindings and can be invoked directly or as part of a
 * general optimization pipeline.
 */
5  */
6 #pragma once
7 
8 #include <torch/csrc/jit/api/module.h>
9 #include <torch/csrc/jit/ir/ir.h>
10 
11 /** \brief Freeze Module, i.e., Assume all attributes are constants.
12  *
13  * Freezing module is a functionality that allows the JIT to internalize
14  * immutable attributes. Combined with inlining, the module is aggressively
15  * optimized and significant overhead is optimized away. The freezeModule API
16  * produces a cloned frozen module.
17  */
18 
19 namespace torch::jit {
20 
21 TORCH_API Module freeze_module(
22     const Module& module,
23     std::vector<std::string> preservedAttrs = std::vector<std::string>(),
24     bool freezeInterfaces = true,
25     bool preserveParameters = false);
26 
27 // Clone-free version of freeze_module. This modifies the module inplace.
28 // Use this version to avoid extra memory usage incurred by cloning the module.
29 TORCH_API void freeze_module_inplace(
30     Module* module,
31     std::vector<std::string> preservedAttrs = std::vector<std::string>(),
32     bool freezeInterfaces = true,
33     bool preserveParameters = false);
34 } // namespace torch::jit
35