xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/lower_graph.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/jit/ir/ir.h>
4 
5 namespace torch::jit {
6 
7 using ModulePtr = c10::intrusive_ptr<c10::ivalue::Object>; // object bound to %self
8 
9 // Given the graph of a method whose first argument is %self, lowers it to a
10 // graph in which every attribute access is replaced by an explicit graph input
11 // (instead of the result of a prim::GetAttr executed on %self).
12 //
13 // Returns a pair (graph, parameters): the last module.parameters.size() inputs
14 // of the returned graph are the trainable parameters used by this method, and
15 // the remaining inputs are the method's true inputs. NOTE(review): whether
16 // `graph` itself is mutated is not visible from this declaration.
17 TORCH_API auto LowerGraph(Graph& graph, const ModulePtr& self)
18     -> std::pair<std::shared_ptr<Graph>, std::vector<IValue>>;
19 
20 } // namespace torch::jit
21