// torch/csrc/jit/passes/variadic_ops.cpp
#include <torch/csrc/jit/passes/variadic_ops.h>

#include <c10/util/irange.h>

#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/remove_mutation.h>

namespace torch::jit {

namespace {

// Returns the indices of the schema arguments whose type is List[Tensor].
std::vector<size_t> identifyListArgIndices(const c10::FunctionSchema& schema) {
  std::vector<size_t> list_indices;
  const auto& args = schema.arguments();
  for (const auto i : c10::irange(args.size())) {
    auto list_type = args[i].type()->castRaw<ListType>();
    if (list_type && list_type->getElementType()->castRaw<TensorType>()) {
      list_indices.push_back(i);
    }
  }
  return list_indices;
}

// True iff `node` is a prim::ListConstruct that produces a List[Tensor].
bool isTensorListConstruct(Node* node) {
  if (node->kind() != prim::ListConstruct) {
    return false;
  }
  const auto type = node->output()->type()->castRaw<ListType>();
  TORCH_CHECK(type != nullptr);
  const auto& elem_type = type->getElementType();
  return elem_type->castRaw<TensorType>() != nullptr;
}

// Rewrites nodes of kind `op` into `variadic_op` by splicing the elements of
// their Tensor[] list inputs directly into the argument list.
class VariadicUpdater {
 public:
  VariadicUpdater(
      std::shared_ptr<Graph> graph,
      NodeKind op,
      NodeKind variadic_op)
      : graph_(std::move(graph)),
        alias_db_(graph_),
        op_(op),
        variadic_op_(variadic_op) {}

  // Returns true if the graph was modified.
  bool run() {
    collectOpNodes(graph_->block());
    bool changed = false;
    for (auto n : op_nodes_) {
      changed |= replaceWithVariadicOp(n);
    }
    return changed;
  }

 private:
  // Caches the list-argument indices for each overload of `op_` so each schema
  // is only scanned once. Keyed by overload name, matching getListIndices().
  void recordSchema(Node* op_node) {
    const auto& schema = op_node->schema();
    auto it = schema_to_list_indices_.find(schema.overload_name());
    if (it == schema_to_list_indices_.end()) {
      schema_to_list_indices_.emplace(
          schema.overload_name(), identifyListArgIndices(schema));
    }
  }

  const std::vector<size_t>& getListIndices(Node* op_node) const {
    const auto& schema = op_node->schema();
    auto it = schema_to_list_indices_.find(schema.overload_name());
    TORCH_CHECK(it != schema_to_list_indices_.end());
    return it->second;
  }

  // Recursively collects all nodes of kind `op_` (including those in nested
  // blocks) and records their schemas.
  void collectOpNodes(Block* block) {
    for (auto node : block->nodes()) {
      if (node->kind() == op_) {
        op_nodes_.push_back(node);
        recordSchema(node);
      }
      for (Block* b : node->blocks()) {
        collectOpNodes(b);
      }
    }
  }

  bool allListInputsAreValid(Node* op_node) {
    const size_t num_inputs = op_node->inputs().size();
    for (const auto list_idx : getListIndices(op_node)) {
      TORCH_CHECK(list_idx < num_inputs);
      const auto list = op_node->input(list_idx)->node();
      // We do not transform ops whose list input cannot be moved to the
      // position right before the op. This in turn implies that there is some
      // mutation of the input list between its construction and the op.
      if (!isTensorListConstruct(list) ||
          !alias_db_.couldMoveBeforeTopologically(list, op_node)) {
        return false;
      }
    }
    return true;
  }
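
  // Illustrative sketch of a graph this check rejects (hand-written IR with
  // assumed value names, not taken from a real model): the list is mutated
  // between its construction and the op, so inlining its elements into a
  // variadic call would change semantics.
  //
  //   %xs : Tensor[] = prim::ListConstruct(%a, %b)
  //   %xs2 : Tensor[] = aten::append(%xs, %c)  // mutates %xs in place
  //   %y : Tensor = aten::cat(%xs, %dim)
  //
  // RemoveListMutation (see RemoveListMutationAndUseVariadicOp below) can
  // often eliminate such mutation first, after which the check passes.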

  void insertAllInputsBetween(
      std::vector<Value*>& inputs,
      Node* node,
      size_t start_idx,
      size_t end_idx) const {
    const size_t num_inputs = node->inputs().size();
    TORCH_CHECK(start_idx <= end_idx && end_idx <= num_inputs);
    inputs.insert(
        inputs.end(),
        node->inputs().begin() + start_idx,
        node->inputs().begin() + end_idx);
  }

  // Appends a prim::Constant holding `input` to `inputs`. The constant is
  // prepended to the graph; ConstantPooling (run by UseVariadicOp) later
  // deduplicates any repeated values.
  void insertIntegerInput(std::vector<Value*>& inputs, size_t input) {
    auto constant = graph_->create(prim::Constant);
    constant->output()->setType(c10::IntType::get());
    constant->i_(attr::value, input);
    graph_->prependNode(constant);
    inputs.push_back(constant->output());
  }

  void deleteOpNodeAndLists(Node* op_node) {
    // Collect the lists before we destroy op_node
    std::vector<Node*> lists;
    const auto& list_indices = getListIndices(op_node);
    lists.reserve(list_indices.size());
    for (const size_t list_idx : list_indices) {
      auto* list = op_node->input(list_idx)->node();
      lists.push_back(list);
    }

    GRAPH_UPDATE("Deleting\n", *op_node);
    op_node->destroy();
    for (auto* list : lists) {
      if (!list->hasUses()) {
        GRAPH_UPDATE("Deleting\n", *list);
        list->destroy();
      }
    }
  }

  bool replaceWithVariadicOp(Node* op_node) {
    if (!allListInputsAreValid(op_node)) {
      return false;
    }

    // Build the flattened argument list: non-list inputs are kept in place,
    // and each Tensor[] input is replaced by its elements.
    std::vector<Value*> inputs;
    size_t cur_idx = 0;
    std::vector<size_t> list_lens;
    for (const size_t list_idx : getListIndices(op_node)) {
      insertAllInputsBetween(inputs, op_node, cur_idx, list_idx);
      const auto list = op_node->input(list_idx)->node();
      const auto list_len = list->inputs().size();
      list_lens.push_back(list_len);
      insertAllInputsBetween(inputs, list, 0, list_len);
      cur_idx = list_idx + 1;
    }
    insertAllInputsBetween(inputs, op_node, cur_idx, op_node->inputs().size());

    // We append the list lengths as extra integer arguments only if there is
    // more than one variadic list; with a single list the information is
    // redundant, since the interpreter already knows how many arguments there
    // are.
    if (list_lens.size() > 1) {
      for (const size_t list_len : list_lens) {
        insertIntegerInput(inputs, list_len);
      }
    }
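
    // At this point `inputs` holds the flattened arguments. As an illustrative
    // example (hypothetical op with two Tensor[] arguments; not a schema this
    // pass is currently used with): op([x0, x1], %alpha, [y0, y1, y2]) becomes
    //   variadic_op(x0, x1, %alpha, y0, y1, y2, 2, 3)
    // where the trailing 2 and 3 are the list lengths appended above. For a
    // single-list op such as aten::cat, no lengths are appended:
    //   aten::cat([a, b, c], %dim) -> prim::VarConcat(a, b, c, %dim)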

    auto var_op_node = op_node->owningGraph()->create(variadic_op_, inputs);
    var_op_node->output()->setType(op_node->output()->type());
    GRAPH_UPDATE("Adding\n", *var_op_node);
    var_op_node->insertBefore(op_node);
    GRAPH_UPDATE("Replacing\n", *op_node, "with\n", *var_op_node);
    op_node->output()->replaceAllUsesWith(var_op_node->output());
    deleteOpNodeAndLists(op_node);
    return true;
  }

  std::shared_ptr<Graph> graph_;
  std::vector<Node*> op_nodes_;

  AliasDb alias_db_;

  NodeKind op_;
  NodeKind variadic_op_;

  // Maps each overload name of op_ to the indices of its Tensor[] arguments.
  std::unordered_map<std::string, std::vector<size_t>> schema_to_list_indices_;
};

} // namespace

bool UseVariadicOp(
    const std::shared_ptr<Graph>& graph,
    NodeKind op,
    NodeKind variadic_op) {
  const std::string pass_name = std::string("variadic ") + op.toQualString();
  GRAPH_DUMP("Before " + pass_name, graph);
  bool changed = VariadicUpdater(graph, op, variadic_op).run();
  if (changed) {
    ConstantPooling(graph);
    GRAPH_DUMP("After " + pass_name, graph);
  }
  return changed;
}

bool RemoveListMutationAndUseVariadicOp(
    const std::shared_ptr<Graph>& graph,
    NodeKind op,
    NodeKind variadic_op) {
  // Alternate the two passes until neither makes progress: removing list
  // mutation can expose additional lists that are safe to inline.
  bool changed_in_last_iter = true;
  bool changed = false;
  while (changed_in_last_iter) {
    changed_in_last_iter = RemoveListMutation(graph);
    changed_in_last_iter =
        UseVariadicOp(graph, op, variadic_op) || changed_in_last_iter;
    changed = changed || changed_in_last_iter;
  }
  return changed;
}

bool UseVariadicCat(const std::shared_ptr<Graph>& graph) {
  return UseVariadicOp(graph, aten::cat, prim::VarConcat);
}
bool RemoveListMutationAndUseVariadicCat(const std::shared_ptr<Graph>& graph) {
  return RemoveListMutationAndUseVariadicOp(graph, aten::cat, prim::VarConcat);
}
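
// Illustrative effect of UseVariadicCat on hand-written IR (value names are
// assumptions for the sketch):
//
//   %l : Tensor[] = prim::ListConstruct(%a, %b, %c)
//   %r : Tensor = aten::cat(%l, %dim)
//
// becomes
//
//   %r : Tensor = prim::VarConcat(%a, %b, %c, %dim)
//
// after which the now-unused ListConstruct is destroyed.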

bool UseVariadicStack(const std::shared_ptr<Graph>& graph) {
  return UseVariadicOp(graph, aten::stack, prim::VarStack);
}
bool RemoveListMutationAndUseVariadicStack(
    const std::shared_ptr<Graph>& graph) {
  return RemoveListMutationAndUseVariadicOp(graph, aten::stack, prim::VarStack);
}
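
// Minimal usage sketch (illustrative; assumes a scripted module `module` with
// a "forward" method, which is not part of this file):
//
//   std::shared_ptr<Graph> graph = module.get_method("forward").graph();
//   RemoveListMutationAndUseVariadicCat(graph);  // or ...UseVariadicStack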

} // namespace torch::jit