xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/graph_opt.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/tensorexpr/graph_opt.h>
2 
3 #include <torch/csrc/jit/jit_log.h>
4 #include <torch/csrc/jit/passes/dead_code_elimination.h>
5 #include <torch/csrc/jit/passes/tensorexpr_fuser.h>
6 #include <torch/csrc/jit/runtime/symbolic_shape_registry_util.h>
7 #include <torch/csrc/jit/tensorexpr/kernel.h>
8 
9 namespace torch::jit::tensorexpr {
10 
11 // Move the given user of `aten::cat` op to its inputs.
moveCatAfterUse(Node * cat,Node * user,const std::shared_ptr<Graph> & subgraph)12 static Node* moveCatAfterUse(
13     Node* cat,
14     Node* user,
15     const std::shared_ptr<Graph>& subgraph) {
16   // Example IR:
17   //   %1 = ...
18   //   %2 = ...
19   //   %3 = prim::ListConstruct(%1, %2)
20   //   %4 = aten::cat(%3, ...)
21   //   %5 = aten::relu(%4)
22   //   return (%5)
23   //
24   // To be transformed to:
25   //   %1 = ...
26   //   %2 = ...
27   //   %5.1 = aten::relu(%1)
28   //   %5.2 = aten::relu(%2)
29   //   %3 = prim::ListConstruct(%5.1, %5.2)
30   //   %4 = aten::cat(%3, ...)
31   //   return (%4)
32 
33   TORCH_INTERNAL_ASSERT(
34       cat->output()->hasUses(),
35       buildErrorMessage("aten::cat output is not used."));
36   TORCH_INTERNAL_ASSERT(
37       cat->output()->uses().size() == 1,
38       buildErrorMessage("aten::cat output is used in multiple places."));
39   TORCH_INTERNAL_ASSERT(
40       cat->input(0)->node()->kind() == prim::ListConstruct,
41       buildErrorMessage("aten::cat inputs are not expected."));
42   auto cat_list = cat->input(0)->node();
43   auto cat_inputs = cat_list->inputs();
44 
45   auto user_tensor_type = user->output()->type()->cast<c10::TensorType>();
46   TORCH_INTERNAL_ASSERT(
47       user_tensor_type, buildErrorMessage("Unexpected user tensor type"));
48   std::unordered_map<Value*, Value*> new_cat_inputs;
49   for (auto inp : cat_inputs) {
50     auto new_cat_input = subgraph->createClone(
51         user, [&](Value* k) { return (k == cat->output()) ? inp : k; });
52     // Since we are cloning user, its result should be the same scalar type
53     // as the user. But the dims should correspond to that of the input.
54     auto input_tensor_type = inp->type()->cast<c10::TensorType>();
55     TORCH_INTERNAL_ASSERT(
56         input_tensor_type, buildErrorMessage("Unexpected input tensor type"));
57     auto new_input_type =
58         input_tensor_type->withScalarType(user_tensor_type->scalarType());
59     new_cat_input->output()->setType(new_input_type);
60     new_cat_input->insertBefore(cat_list);
61     new_cat_inputs[inp] = new_cat_input->output();
62   }
63   auto new_cat_list = subgraph->createClone(
64       cat_list, [&](Value* k) { return new_cat_inputs[k]; });
65   new_cat_list->insertBefore(cat);
66   auto new_cat = subgraph->createClone(cat, [&](Value* k) {
67     return (k == cat_list->output()) ? new_cat_list->output() : k;
68   });
69   new_cat->output()->setType(user_tensor_type);
70   new_cat->insertBefore(cat);
71 
72   user->output()->replaceAllUsesWith(new_cat->output());
73   user->destroy();
74 
75   TORCH_INTERNAL_ASSERT(
76       !cat->output()->hasUses(),
77       buildErrorMessage("aten::cat output is not used."));
78   cat->destroy();
79 
80   if (!cat_list->output()->hasUses()) {
81     cat_list->destroy();
82   }
83 
84   return new_cat;
85 }
86 
numTensorInputs(Node * node)87 static int numTensorInputs(Node* node) {
88   int count = 0;
89   for (auto v : node->inputs()) {
90     if (v->type()->cast<c10::TensorType>()) {
91       ++count;
92     }
93   }
94   return count;
95 }
96 
97 // Returns true if the given `cat` node promotes types.
98 // If the inputs to `cat` are of different types, then the implementation
99 // of `cat` is expected to promote type.
doesCatPromoteTypes(Node * node)100 static bool doesCatPromoteTypes(Node* node) {
101   TORCH_INTERNAL_ASSERT(
102       node->kind() == aten::cat,
103       buildErrorMessage("Graph node is not aten::cat."));
104   TORCH_INTERNAL_ASSERT(
105       node->input(0)->node()->kind() == prim::ListConstruct,
106       buildErrorMessage("aten::cat inputs are not expected."));
107   auto inputs = node->input(0)->node()->inputs();
108   TORCH_INTERNAL_ASSERT(
109       !inputs.empty(), buildErrorMessage("Empty inputs of ListConstruct"));
110   auto scalar_type =
111       inputs.front()->type()->cast<c10::TensorType>()->scalarType();
112   for (size_t i = 1; i < inputs.size(); ++i) {
113     auto inp_scalar_type =
114         inputs[i]->type()->cast<c10::TensorType>()->scalarType();
115     if (scalar_type != inp_scalar_type) {
116       return true;
117     }
118   }
119   return false;
120 }
121 
122 // Move the users of the given `aten::cat` op to its inputs.
123 // The following constraints need to be satisfied on the cat op and its user.
124 //   * the cat op should have only one use.
125 //   * the user should be an element-wise op.
126 //   * the user should have only one tensor input.
127 //     - If the user has > 1 tensor inputs, that user op cannot be applied on
128 //       the inputs of cat because the other tensor inputs will not be split,
129 //       and hence the shape of those tensors would not match that of the
130 //       inputs of cat.
131 //       For example:
132 //           %1 = ...
133 //           %2 = ...
134 //           %3 = prim::ListConstruct([%1, %2])
135 //           %4 = aten::cat(%3, ...)
136 //           %5 = aten::add(%4, %0)
137 //       In this example, we cannot move `aten::add` to the inputs of
138 //       `aten::cat`, %1 and %2, because the shape of %0 will be different.
139 //    * the cat op does not promote types.
140 //      - When the cat op promote types, the type of inputs to cat after moving
141 //        it user needs to reflect the original type. This is currently not
142 //        handled. TODO
moveCatOpToEnd(Node * cat,const std::shared_ptr<Graph> & subgraph)143 static void moveCatOpToEnd(Node* cat, const std::shared_ptr<Graph>& subgraph) {
144   TORCH_INTERNAL_ASSERT(
145       cat->kind() == aten::cat,
146       buildErrorMessage("Graph node is not aten::cat."));
147   if (cat->output()->uses().size() == 1) {
148     auto use = cat->output()->uses().front();
149     if (get_tensorexpr_elementwise_set().contains(use.user) &&
150         numTensorInputs(use.user) == 1) {
151       if (!doesCatPromoteTypes(cat)) {
152         TORCH_INTERNAL_ASSERT(
153             use.user->output()->owningGraph() == subgraph.get(),
154             buildErrorMessage(
155                 "aten::cat user graph does not math the given subgraph."));
156         auto new_cat = moveCatAfterUse(cat, use.user, subgraph);
157         moveCatOpToEnd(new_cat, subgraph);
158       }
159     }
160   }
161 }
162 
163 // Moves the users of `aten::cat` ops to its inputs whenever possible
164 // in the given subgraph.
moveCatOpsToEnd(const std::shared_ptr<Graph> & subgraph)165 static void moveCatOpsToEnd(const std::shared_ptr<Graph>& subgraph) {
166   std::vector<Node*> cat_nodes;
167   for (Node* n : subgraph->nodes()) {
168     if (n->kind() == aten::cat) {
169       cat_nodes.push_back(n);
170     }
171   }
172   for (auto cat : cat_nodes) {
173     moveCatOpToEnd(cat, subgraph);
174   }
175 }
176 
OptimizeCat(const std::shared_ptr<Graph> & graph)177 bool OptimizeCat(const std::shared_ptr<Graph>& graph) {
178   if (getCatWoConditionals()) {
179     moveCatOpsToEnd(graph);
180     return true;
181   }
182   return false;
183 }
184 
annotateInputShapes(const std::shared_ptr<Graph> & graph,const std::vector<std::optional<at::Tensor>> & example_inputs)185 void annotateInputShapes(
186     const std::shared_ptr<Graph>& graph,
187     const std::vector<std::optional<at::Tensor>>& example_inputs) {
188   TORCH_INTERNAL_ASSERT(
189       graph->inputs().size() == example_inputs.size(),
190       buildErrorMessage("Given inputs do not match the fuser graph inputs."));
191   for (size_t idx = 0; idx < example_inputs.size(); idx++) {
192     if (auto t = example_inputs[idx]) {
193       auto concrete_tensor_type = tensorTypeInCurrentExecutionContext(*t);
194       graph->inputs().at(idx)->setType(concrete_tensor_type);
195     }
196   }
197 }
198 
removeUnusedSelfArgument(const std::shared_ptr<Graph> & graph)199 std::shared_ptr<Graph> removeUnusedSelfArgument(
200     const std::shared_ptr<Graph>& graph) {
201   if (graph->inputs().empty()) {
202     return graph;
203   }
204   jit::Value* self_argument = graph->inputs().at(0);
205   if (!self_argument->uses().empty() || !self_argument->type()->is_module()) {
206     return graph;
207   }
208   graph->eraseInput(0);
209   return graph;
210 }
211 
makeShapesSymbolic(std::shared_ptr<Graph> & graph,const std::vector<int64_t> & size_vals)212 std::vector<int64_t> makeShapesSymbolic(
213     std::shared_ptr<Graph>& graph,
214     const std::vector<int64_t>& size_vals) {
215   std::unordered_set<Value*> values;
216   for (auto v : graph->inputs()) {
217     values.insert(v);
218   }
219   for (auto v : graph->outputs()) {
220     values.insert(v);
221   }
222   for (auto n : graph->nodes()) {
223     for (auto v : n->inputs()) {
224       values.insert(v);
225     }
226     for (auto v : n->outputs()) {
227       values.insert(v);
228     }
229   }
230   std::unordered_map<int64_t, int64_t> shape_to_sym_shape;
231   std::vector<int64_t> new_syms;
232   for (int64_t size_val : size_vals) {
233     auto new_shape_symbol = at::ShapeSymbol::newSymbol().value();
234     shape_to_sym_shape[size_val] = new_shape_symbol;
235     new_syms.push_back(new_shape_symbol);
236     graph->addInput("sym_shape")->setType(IntType::get());
237   }
238 
239   for (auto v : values) {
240     if (!v->type()->cast<TensorType>()) {
241       continue;
242     }
243     auto tt = v->type()->expect<TensorType>();
244     if (!tt->symbolic_sizes().sizes()) {
245       continue;
246     }
247     std::vector<at::ShapeSymbol> shape_vec = *tt->symbolic_sizes().sizes();
248 
249     auto new_sizes = c10::fmap(shape_vec, [&](const at::ShapeSymbol& shape) {
250       auto value = shape.value();
251       if (shape_to_sym_shape.count(value)) {
252         return shape_to_sym_shape.at(value);
253       }
254       return value;
255     });
256     v->setType(tt->withSymbolicShapes(c10::SymbolicShape(new_sizes)));
257   }
258 
259   return new_syms;
260 }
261 
isGraphCompilable(const std::shared_ptr<Graph> & graph)262 bool isGraphCompilable(const std::shared_ptr<Graph>& graph) {
263   for (auto input : graph->inputs()) {
264     auto const& t = input->type();
265     auto const& k = t->kind();
266     if (k != TypeKind::TensorType && k != TypeKind::FloatType &&
267         k != TypeKind::BoolType && k != TypeKind::IntType) {
268       GRAPH_DEBUG("Input %", input->debugName(), " has unsupported type ", *t);
269       return false;
270     }
271   }
272 
273   for (auto n : graph->nodes()) {
274     for (auto v : n->inputs()) {
275       auto const& t = v->type();
276       if (t->kind() == TypeKind::TensorType) {
277         auto tt = t->cast<TensorType>();
278         if (!tt->isComplete()) {
279           GRAPH_DEBUG(
280               "%",
281               v->debugName(),
282               " is not a complete tensor! The type is: ",
283               *t);
284           return false;
285         }
286       }
287     }
288     for (auto v : n->outputs()) {
289       auto const& t = v->type();
290       if (t->kind() == TypeKind::TensorType) {
291         auto tt = t->cast<TensorType>();
292         if (!tt->isComplete()) {
293           GRAPH_DEBUG(
294               "%", v->debugName(), " is not a complete! The type is: ", *t);
295           return false;
296         }
297       }
298     }
299   }
300 
301   // TODO: check if all nodes have lowerings
302   return true;
303 }
304 
// Replaces `v`'s tensor type with a fully-concrete one (dtype, device, sizes,
// strides). Missing dtype/device come from the caller-supplied hints; missing
// strides are assumed contiguous. Non-tensor values, and tensors without
// concrete sizes, are left untouched.
static void fixupTypeInfoForValue(
    Value* v,
    std::optional<at::ScalarType> scalar_type,
    std::optional<at::Device> device) {
  Node* n = v->node();
  auto const& t = v->type();
  if (t->kind() != TypeKind::TensorType) {
    return;
  }

  if (n->kind() == prim::Constant) {
    // For a tensor constant we can read the actual tensor and derive an exact
    // type from it, ignoring the hints.
    // NOTE(review): toIValue(v) is dereferenced unconditionally — presumably a
    // tensor-typed constant always carries an IValue; verify.
    auto const_tensor = toIValue(v)->toTensor();
    auto concrete_tensor_type =
        tensorTypeInCurrentExecutionContext(const_tensor);
    v->setType(concrete_tensor_type);
    return;
  }

  TensorTypePtr new_tt;
  auto tt = t->cast<TensorType>();
  auto sizes = tt->sizes();
  if (!sizes.concrete_sizes()) {
    // Without concrete sizes we cannot build a complete type; bail out.
    GRAPH_DEBUG("No concrete sizes for %", v->debugName());
    return;
  }
  auto strides = tt->strides();
  // Prefer the dtype already present on the value; fall back to the hint.
  auto dtype = tt->scalarType() ? tt->scalarType() : scalar_type;
  auto concrete_sizes = *sizes.concrete_sizes();
  // Fall back to contiguous strides when none are recorded.
  auto concrete_strides = strides.concrete_sizes()
      ? *strides.concrete_sizes()
      : TensorType::contiguousStridesOf(concrete_sizes);
  new_tt = TensorType::create(
      dtype, device, concrete_sizes, concrete_strides, false);

  v->setType(new_tt);
}
341 
// Infers a common scalar type for the tensor inputs of node `n`.
// Returns the scalar type of the first tensor input that has one (possibly
// std::nullopt if no tensor input carries a dtype), or std::nullopt with a
// debug log when two tensor inputs disagree.
static std::optional<at::ScalarType> inferScalarType(Node* n) {
  std::optional<at::ScalarType> scalar_type;
  for (auto v : n->inputs()) {
    auto const& t = v->type();
    if (t->kind() == TypeKind::TensorType) {
      auto tt = t->cast<TensorType>();
      if (!scalar_type) {
        scalar_type = tt->scalarType();
      }
      // Compares a ScalarType against an optional<ScalarType>: unequal only
      // when `scalar_type` is engaged and the values differ, i.e. when two
      // inputs both have a known dtype and they disagree.
      if (tt->scalarType() && *tt->scalarType() != scalar_type) {
        // NOTE(review): `n` is a Node* here, so this logs a pointer rather
        // than the node itself — possibly intended to be `*n`; confirm.
        GRAPH_DEBUG(
            "Inputs of ", n, " have different scalar types, cannot fixup!");
        return std::nullopt;
      }
    }
  }
  return scalar_type;
}
360 
inferDevice(Node * n)361 static std::optional<at::Device> inferDevice(Node* n) {
362   std::optional<at::Device> device;
363   for (auto v : n->inputs()) {
364     auto const& t = v->type();
365     if (t->kind() == TypeKind::TensorType) {
366       auto tt = t->cast<TensorType>();
367       if (!device) {
368         device = tt->device();
369       }
370       if (tt->device() && *tt->device() != device) {
371         GRAPH_DEBUG("Inputs of ", n, " have different devices, cannot fixup!");
372         return std::nullopt;
373       }
374     }
375   }
376   if (!device) {
377     device = at::kCPU;
378   }
379   return device;
380 }
381 
fixupMissingShapeInfo(const std::shared_ptr<Graph> & graph)382 void fixupMissingShapeInfo(const std::shared_ptr<Graph>& graph) {
383   for (auto input : graph->inputs()) {
384     auto const& t = input->type();
385     if (t->kind() == TypeKind::TensorType) {
386       auto tt = t->cast<TensorType>();
387       if (!tt->scalarType()) {
388         GRAPH_DEBUG("No dtype for %", input->debugName());
389         return;
390       }
391       fixupTypeInfoForValue(
392           input, tt->scalarType(), tt->device() ? tt->device() : at::kCPU);
393     }
394   }
395 
396   for (auto n : graph->nodes()) {
397     std::optional<at::ScalarType> scalar_type = inferScalarType(n);
398     std::optional<at::Device> device = inferDevice(n);
399 
400     for (auto v : n->outputs()) {
401       fixupTypeInfoForValue(v, scalar_type, device);
402     }
403   }
404 }
405 
removeGraphOutput(const std::shared_ptr<Graph> & graph,size_t idx)406 std::shared_ptr<Graph> removeGraphOutput(
407     const std::shared_ptr<Graph>& graph,
408     size_t idx) {
409   graph->eraseOutput(idx);
410   return graph;
411 }
412 
replaceListOutputWithTuple(const std::shared_ptr<Graph> & graph)413 std::shared_ptr<Graph> replaceListOutputWithTuple(
414     const std::shared_ptr<Graph>& graph) {
415   auto out = graph->outputs()[0];
416   auto out_node = out->node();
417   if (out_node->kind() != prim::ListConstruct) {
418     return graph;
419   }
420   auto tuple_node = graph->createTuple(out_node->inputs());
421   tuple_node->insertAfter(out_node);
422   out->replaceAllUsesWith(tuple_node->output());
423   return graph;
424 }
425 
trimGraphOnce(const std::shared_ptr<Graph> & graph)426 static bool trimGraphOnce(const std::shared_ptr<Graph>& graph) {
427   Node* ret = graph->return_node();
428   std::unordered_set<Value*> graph_inputs(
429       graph->inputs().begin(), graph->inputs().end());
430   std::unordered_set<Value*> outputs(
431       graph->outputs().begin(), graph->outputs().end());
432   bool changed = false;
433   for (size_t idx = 0; idx < ret->inputs().size(); idx++) {
434     auto v = ret->inputs()[idx];
435     if (graph_inputs.count(v)) {
436       continue;
437     }
438     // Delete the graph output IDX and add all inputs of the node producing that
439     // value to the graph outputs
440     graph->eraseOutput(idx);
441     for (auto v_ins : v->node()->inputs()) {
442       if (outputs.count(v_ins)) {
443         continue;
444       }
445       if (v_ins->node()->kind() == prim::Constant) {
446         continue;
447       }
448 
449       graph->registerOutput(v_ins);
450     }
451     changed = true;
452     break;
453   }
454   return changed;
455 }
456 
// Appends an aten::dequantize node for every quantized tensor output of the
// graph and redirects downstream uses (notably the return node) to the
// dequantized float value. Non-quantized outputs are left untouched.
static std::shared_ptr<Graph> dequantizeResults(
    const std::shared_ptr<Graph>& graph) {
  for (auto v : graph->outputs()) {
    auto& t = v->type();
    if (t->kind() == TypeKind::TensorType) {
      auto tt = t->cast<TensorType>();
      // Skip outputs with unknown dtype and outputs that are not quantized.
      if (!tt->scalarType() || !c10::isQIntType(*tt->scalarType())) {
        continue;
      }
      Node* deq = graph->create(aten::dequantize, {v});
      graph->appendNode(deq);
      deq->output()->setType(tt->withScalarType(c10::kFloat));
      // Only uses *after* the dequantize node are redirected, so `deq`
      // itself keeps consuming the original quantized value.
      v->replaceAllUsesAfterNodeWith(deq, deq->output());
    }
  }
  return graph;
}
474 
trimGraph(const std::shared_ptr<Graph> & graph,int64_t iters)475 std::shared_ptr<Graph> trimGraph(
476     const std::shared_ptr<Graph>& graph,
477     int64_t iters) {
478   bool changed = true;
479   int64_t iter = 0;
480   while (changed && iter++ < iters) {
481     changed = trimGraphOnce(graph);
482     EliminateDeadCode(graph->block());
483   }
484   // Avoid letting quantized values to graph outputs.
485   // Ideally we should allow quantized outputs as well, but currently the main
486   // user of this pass - AOT NNC - does not support it.
487   // TODO: remove output dequantization once NNC supports quantized outputs.
488   dequantizeResults(graph);
489   return graph;
490 }
491 
492 } // namespace torch::jit::tensorexpr
493