1 #include <torch/csrc/jit/passes/variadic_ops.h>
2
3 #include <torch/csrc/jit/ir/alias_analysis.h>
4 #include <torch/csrc/jit/jit_log.h>
5 #include <torch/csrc/jit/passes/constant_pooling.h>
6 #include <torch/csrc/jit/passes/remove_mutation.h>
7
8 namespace torch::jit {
9
10 namespace {
11
identifyListArgIndices(const c10::FunctionSchema & schema)12 std::vector<size_t> identifyListArgIndices(const c10::FunctionSchema& schema) {
13 std::vector<size_t> list_indices;
14 const auto& args = schema.arguments();
15 for (const auto i : c10::irange(args.size())) {
16 auto list_type = args[i].type()->castRaw<ListType>();
17 if (list_type && list_type->getElementType()->castRaw<TensorType>()) {
18 list_indices.push_back(i);
19 }
20 }
21 return list_indices;
22 }
23
isTensorListConstruct(Node * node)24 bool isTensorListConstruct(Node* node) {
25 if (node->kind() != prim::ListConstruct) {
26 return false;
27 }
28 const auto type = node->output()->type()->castRaw<ListType>();
29 TORCH_CHECK(type != nullptr);
30 const auto& elem_type = type->getElementType();
31 return elem_type->castRaw<TensorType>();
32 }
33
34 class VariadicUpdater {
35 public:
VariadicUpdater(std::shared_ptr<Graph> graph,NodeKind op,NodeKind variadic_op)36 VariadicUpdater(
37 std::shared_ptr<Graph> graph,
38 NodeKind op,
39 NodeKind variadic_op)
40 : graph_(std::move(graph)),
41 alias_db_(graph_),
42 op_(op),
43 variadic_op_(variadic_op) {}
44
run()45 bool run() {
46 collectOpNodes(graph_->block());
47 bool changed = false;
48 for (auto n : op_nodes_) {
49 changed |= replaceWithVariadicOp(n);
50 }
51 return changed;
52 }
53
54 private:
recordSchema(Node * op_node)55 void recordSchema(Node* op_node) {
56 const auto& schema = op_node->schema();
57 auto it = schema_to_list_indices_.find(schema.name());
58 if (it == schema_to_list_indices_.end()) {
59 schema_to_list_indices_.emplace(
60 schema.overload_name(), identifyListArgIndices(schema));
61 }
62 }
63
getListIndices(Node * op_node) const64 const std::vector<size_t>& getListIndices(Node* op_node) const {
65 const auto& schema = op_node->schema();
66 auto it = schema_to_list_indices_.find(schema.overload_name());
67 TORCH_CHECK(it != schema_to_list_indices_.end());
68 return it->second;
69 }
70
collectOpNodes(Block * block)71 void collectOpNodes(Block* block) {
72 for (auto node : block->nodes()) {
73 if (node->kind() == op_) {
74 op_nodes_.push_back(node);
75 recordSchema(node);
76 }
77 for (Block* b : node->blocks()) {
78 collectOpNodes(b);
79 }
80 }
81 }
82
allListInputsAreValid(Node * op_node)83 bool allListInputsAreValid(Node* op_node) {
84 const size_t num_inputs = op_node->inputs().size();
85 for (const auto list_idx : getListIndices(op_node)) {
86 TORCH_CHECK(list_idx < num_inputs);
87 const auto list = op_node->input(list_idx)->node();
88 // We do not transform ops whose list input can not be moved to the
89 // position before op. This in turn implies that there is some mutation
90 // of the input list before op.
91 if (!isTensorListConstruct(list) ||
92 !alias_db_.couldMoveBeforeTopologically(list, op_node)) {
93 return false;
94 }
95 }
96 return true;
97 }
98
insertAllInputsBetween(std::vector<Value * > & inputs,Node * node,size_t start_idx,size_t end_idx) const99 void insertAllInputsBetween(
100 std::vector<Value*>& inputs,
101 Node* node,
102 size_t start_idx,
103 size_t end_idx) const {
104 const size_t num_inputs = node->inputs().size();
105 TORCH_CHECK(start_idx <= end_idx && end_idx <= num_inputs);
106 inputs.insert(
107 inputs.end(),
108 node->inputs().begin() + start_idx,
109 node->inputs().begin() + end_idx);
110 }
111
insertIntegerInput(std::vector<Value * > & inputs,size_t input)112 void insertIntegerInput(std::vector<Value*>& inputs, size_t input) {
113 auto constant = graph_->create(prim::Constant);
114 constant->output()->setType(c10::IntType::get());
115 constant->i_(attr::value, input);
116 graph_->prependNode(constant);
117 inputs.push_back(constant->output());
118 }
119
deleteOpNodeAndLists(Node * op_node)120 void deleteOpNodeAndLists(Node* op_node) {
121 // Collect the lists before we destroy op_node
122 std::vector<Node*> lists;
123 const auto& list_indices = getListIndices(op_node);
124 lists.reserve(list_indices.size());
125 for (const size_t list_idx : list_indices) {
126 auto* list = op_node->input(list_idx)->node();
127 lists.push_back(list);
128 }
129
130 GRAPH_UPDATE("Deleting\n", *op_node);
131 op_node->destroy();
132 for (auto* list : lists) {
133 if (!list->hasUses()) {
134 GRAPH_UPDATE("Deleting\n", *list);
135 list->destroy();
136 }
137 }
138 }
139
replaceWithVariadicOp(Node * op_node)140 bool replaceWithVariadicOp(Node* op_node) {
141 if (!allListInputsAreValid(op_node)) {
142 return false;
143 }
144
145 std::vector<Value*> inputs;
146 size_t cur_idx = 0;
147 std::vector<size_t> list_lens;
148 for (const size_t list_idx : getListIndices(op_node)) {
149 insertAllInputsBetween(inputs, op_node, cur_idx, list_idx);
150 const auto list = op_node->input(list_idx)->node();
151 const auto list_len = list->inputs().size();
152 list_lens.push_back(list_len);
153 insertAllInputsBetween(inputs, list, 0, list_len);
154 cur_idx = list_idx + 1;
155 }
156 insertAllInputsBetween(inputs, op_node, cur_idx, op_node->inputs().size());
157
158 // We insert these extra integers at the end of the argument list only if we
159 // have more than one variadic list (the information is redundant when there
160 // is only one list because the interpreter knows how many arguments there
161 // are).
162 if (list_lens.size() > 1) {
163 for (const size_t list_len : list_lens) {
164 insertIntegerInput(inputs, list_len);
165 }
166 }
167
168 auto var_op_node = op_node->owningGraph()->create(variadic_op_, inputs);
169 var_op_node->output()->setType(op_node->output()->type());
170 GRAPH_UPDATE("Adding\n", *var_op_node);
171 var_op_node->insertBefore(op_node);
172 GRAPH_UPDATE("Replacing\n", *op_node, "with\n", *var_op_node);
173 op_node->output()->replaceAllUsesWith(var_op_node->output());
174 deleteOpNodeAndLists(op_node);
175 return true;
176 }
177
178 std::shared_ptr<Graph> graph_;
179 std::vector<Node*> op_nodes_;
180
181 AliasDb alias_db_;
182
183 NodeKind op_;
184 NodeKind variadic_op_;
185
186 std::unordered_map<std::string, std::vector<size_t>> schema_to_list_indices_;
187 };
188
189 } // namespace
190
UseVariadicOp(const std::shared_ptr<Graph> & graph,NodeKind op,NodeKind variadic_op)191 bool UseVariadicOp(
192 const std::shared_ptr<Graph>& graph,
193 NodeKind op,
194 NodeKind variadic_op) {
195 const std::string pass_name = std::string("variadic ") + op.toQualString();
196 GRAPH_DUMP("Before " + pass_name, graph);
197 bool changed = VariadicUpdater(graph, op, variadic_op).run();
198 if (changed) {
199 ConstantPooling(graph);
200 GRAPH_DUMP("After " + pass_name, graph);
201 }
202 return changed;
203 }
204
RemoveListMutationAndUseVariadicOp(const std::shared_ptr<Graph> & graph,NodeKind op,NodeKind variadic_op)205 bool RemoveListMutationAndUseVariadicOp(
206 const std::shared_ptr<Graph>& graph,
207 NodeKind op,
208 NodeKind variadic_op) {
209 bool changed_in_last_iter = true;
210 bool changed = false;
211 while (changed_in_last_iter) {
212 changed_in_last_iter = RemoveListMutation(graph);
213 changed_in_last_iter =
214 UseVariadicOp(graph, op, variadic_op) || changed_in_last_iter;
215 changed = changed || changed_in_last_iter;
216 }
217 return changed;
218 }
219
UseVariadicCat(const std::shared_ptr<Graph> & graph)220 bool UseVariadicCat(const std::shared_ptr<Graph>& graph) {
221 return UseVariadicOp(graph, aten::cat, prim::VarConcat);
222 }
RemoveListMutationAndUseVariadicCat(const std::shared_ptr<Graph> & graph)223 bool RemoveListMutationAndUseVariadicCat(const std::shared_ptr<Graph>& graph) {
224 return RemoveListMutationAndUseVariadicOp(graph, aten::cat, prim::VarConcat);
225 }
226
UseVariadicStack(const std::shared_ptr<Graph> & graph)227 bool UseVariadicStack(const std::shared_ptr<Graph>& graph) {
228 return UseVariadicOp(graph, aten::stack, prim::VarStack);
229 }
RemoveListMutationAndUseVariadicStack(const std::shared_ptr<Graph> & graph)230 bool RemoveListMutationAndUseVariadicStack(
231 const std::shared_ptr<Graph>& graph) {
232 return RemoveListMutationAndUseVariadicOp(graph, aten::stack, prim::VarStack);
233 }
234
235 } // namespace torch::jit
236