#include <torch/csrc/jit/tensorexpr/graph_opt.h>

#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry_util.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>

namespace torch::jit::tensorexpr {

// Move the given user of `aten::cat` op to its inputs.
static Node* moveCatAfterUse(
    Node* cat,
    Node* user,
    const std::shared_ptr<Graph>& subgraph) {
  // Example IR:
  //   %1 = ...
  //   %2 = ...
  //   %3 = prim::ListConstruct(%1, %2)
  //   %4 = aten::cat(%3, ...)
  //   %5 = aten::relu(%4)
  //   return (%5)
  //
  // To be transformed to:
  //   %1 = ...
  //   %2 = ...
  //   %5.1 = aten::relu(%1)
  //   %5.2 = aten::relu(%2)
  //   %3 = prim::ListConstruct(%5.1, %5.2)
  //   %4 = aten::cat(%3, ...)
  //   return (%4)

  TORCH_INTERNAL_ASSERT(
      cat->output()->hasUses(),
      buildErrorMessage("aten::cat output is not used."));
  TORCH_INTERNAL_ASSERT(
      cat->output()->uses().size() == 1,
      buildErrorMessage("aten::cat output is used in multiple places."));
  TORCH_INTERNAL_ASSERT(
      cat->input(0)->node()->kind() == prim::ListConstruct,
      buildErrorMessage("aten::cat inputs are not expected."));
  auto cat_list = cat->input(0)->node();
  auto cat_inputs = cat_list->inputs();

  auto user_tensor_type = user->output()->type()->cast<c10::TensorType>();
  TORCH_INTERNAL_ASSERT(
      user_tensor_type, buildErrorMessage("Unexpected user tensor type"));
  std::unordered_map<Value*, Value*> new_cat_inputs;
  for (auto inp : cat_inputs) {
    auto new_cat_input = subgraph->createClone(
        user, [&](Value* k) { return (k == cat->output()) ? inp : k; });
    // Since we are cloning the user, its result should have the same scalar
    // type as the user. But the dims should correspond to those of the input.
    auto input_tensor_type = inp->type()->cast<c10::TensorType>();
    TORCH_INTERNAL_ASSERT(
        input_tensor_type, buildErrorMessage("Unexpected input tensor type"));
    auto new_input_type =
        input_tensor_type->withScalarType(user_tensor_type->scalarType());
    new_cat_input->output()->setType(new_input_type);
    new_cat_input->insertBefore(cat_list);
    new_cat_inputs[inp] = new_cat_input->output();
  }
  auto new_cat_list = subgraph->createClone(
      cat_list, [&](Value* k) { return new_cat_inputs[k]; });
  new_cat_list->insertBefore(cat);
  auto new_cat = subgraph->createClone(cat, [&](Value* k) {
    return (k == cat_list->output()) ? new_cat_list->output() : k;
  });
  new_cat->output()->setType(user_tensor_type);
  new_cat->insertBefore(cat);

  user->output()->replaceAllUsesWith(new_cat->output());
  user->destroy();

  TORCH_INTERNAL_ASSERT(
      !cat->output()->hasUses(),
      buildErrorMessage("aten::cat output is still used."));
  cat->destroy();

  if (!cat_list->output()->hasUses()) {
    cat_list->destroy();
  }

  return new_cat;
}

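// Returns the number of tensor inputs of the given node.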
static int numTensorInputs(Node* node) {
  int count = 0;
  for (auto v : node->inputs()) {
    if (v->type()->cast<c10::TensorType>()) {
      ++count;
    }
  }
  return count;
}

// Returns true if the given `cat` node promotes types.
// If the inputs to `cat` are of different types, then the implementation
// of `cat` is expected to promote types.
static bool doesCatPromoteTypes(Node* node) {
  TORCH_INTERNAL_ASSERT(
      node->kind() == aten::cat,
      buildErrorMessage("Graph node is not aten::cat."));
  TORCH_INTERNAL_ASSERT(
      node->input(0)->node()->kind() == prim::ListConstruct,
      buildErrorMessage("aten::cat inputs are not expected."));
  auto inputs = node->input(0)->node()->inputs();
  TORCH_INTERNAL_ASSERT(
      !inputs.empty(), buildErrorMessage("Empty inputs of ListConstruct"));
  auto scalar_type =
      inputs.front()->type()->cast<c10::TensorType>()->scalarType();
  for (size_t i = 1; i < inputs.size(); ++i) {
    auto inp_scalar_type =
        inputs[i]->type()->cast<c10::TensorType>()->scalarType();
    if (scalar_type != inp_scalar_type) {
      return true;
    }
  }
  return false;
}

// Move the users of the given `aten::cat` op to its inputs.
// The following constraints need to be satisfied on the cat op and its user:
//   * the cat op should have only one use.
//   * the user should be an element-wise op.
//   * the user should have only one tensor input.
//     - If the user has > 1 tensor inputs, that user op cannot be applied on
//       the inputs of cat because the other tensor inputs will not be split,
//       and hence the shape of those tensors would not match that of the
//       inputs of cat.
//       For example:
//           %1 = ...
//           %2 = ...
//           %3 = prim::ListConstruct([%1, %2])
//           %4 = aten::cat(%3, ...)
//           %5 = aten::add(%4, %0)
//       In this example, we cannot move `aten::add` to the inputs of
//       `aten::cat`, %1 and %2, because the shape of %0 will be different.
//   * the cat op does not promote types.
//     - When the cat op promotes types, the types of the inputs to cat after
//       moving its user need to reflect the original type. This is currently
//       not handled. TODO
static void moveCatOpToEnd(Node* cat, const std::shared_ptr<Graph>& subgraph) {
  TORCH_INTERNAL_ASSERT(
      cat->kind() == aten::cat,
      buildErrorMessage("Graph node is not aten::cat."));
  if (cat->output()->uses().size() == 1) {
    auto use = cat->output()->uses().front();
    if (get_tensorexpr_elementwise_set().contains(use.user) &&
        numTensorInputs(use.user) == 1) {
      if (!doesCatPromoteTypes(cat)) {
        TORCH_INTERNAL_ASSERT(
            use.user->output()->owningGraph() == subgraph.get(),
            buildErrorMessage(
                "aten::cat user graph does not match the given subgraph."));
        auto new_cat = moveCatAfterUse(cat, use.user, subgraph);
        moveCatOpToEnd(new_cat, subgraph);
      }
    }
  }
}

// Moves the users of `aten::cat` ops to their inputs whenever possible
// in the given subgraph.
static void moveCatOpsToEnd(const std::shared_ptr<Graph>& subgraph) {
  std::vector<Node*> cat_nodes;
  for (Node* n : subgraph->nodes()) {
    if (n->kind() == aten::cat) {
      cat_nodes.push_back(n);
    }
  }
  for (auto cat : cat_nodes) {
    moveCatOpToEnd(cat, subgraph);
  }
}

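// Move the users of aten::cat ops to their inputs when the
// "cat w/o conditionals" flag is set. Returns true if the pass was run.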
bool OptimizeCat(const std::shared_ptr<Graph>& graph) {
  if (getCatWoConditionals()) {
    moveCatOpsToEnd(graph);
    return true;
  }
  return false;
}

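// Annotate the graph inputs with the concrete types of the given example
// inputs. Entries that are std::nullopt leave the corresponding graph input
// unchanged.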
void annotateInputShapes(
    const std::shared_ptr<Graph>& graph,
    const std::vector<std::optional<at::Tensor>>& example_inputs) {
  TORCH_INTERNAL_ASSERT(
      graph->inputs().size() == example_inputs.size(),
      buildErrorMessage("Given inputs do not match the fuser graph inputs."));
  for (size_t idx = 0; idx < example_inputs.size(); idx++) {
    if (auto t = example_inputs[idx]) {
      auto concrete_tensor_type = tensorTypeInCurrentExecutionContext(*t);
      graph->inputs().at(idx)->setType(concrete_tensor_type);
    }
  }
}

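// Erase the first graph input if it is an unused module `self` argument.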
std::shared_ptr<Graph> removeUnusedSelfArgument(
    const std::shared_ptr<Graph>& graph) {
  if (graph->inputs().empty()) {
    return graph;
  }
  jit::Value* self_argument = graph->inputs().at(0);
  if (!self_argument->uses().empty() || !self_argument->type()->is_module()) {
    return graph;
  }
  graph->eraseInput(0);
  return graph;
}

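// Replace every dimension whose size appears in `size_vals` with a fresh
// shape symbol across all tensor types in the graph, and add one int graph
// input ("sym_shape") per symbol to carry its value at runtime. Returns the
// newly created shape symbols.
// For example (a sketch): with size_vals = {10}, an input typed
// Float(10, 20) becomes Float(S, 20) for a fresh symbol S, and one extra
// `sym_shape : int` input is appended to the graph.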
std::vector<int64_t> makeShapesSymbolic(
    std::shared_ptr<Graph>& graph,
    const std::vector<int64_t>& size_vals) {
  std::unordered_set<Value*> values;
  for (auto v : graph->inputs()) {
    values.insert(v);
  }
  for (auto v : graph->outputs()) {
    values.insert(v);
  }
  for (auto n : graph->nodes()) {
    for (auto v : n->inputs()) {
      values.insert(v);
    }
    for (auto v : n->outputs()) {
      values.insert(v);
    }
  }
  std::unordered_map<int64_t, int64_t> shape_to_sym_shape;
  std::vector<int64_t> new_syms;
  for (int64_t size_val : size_vals) {
    auto new_shape_symbol = at::ShapeSymbol::newSymbol().value();
    shape_to_sym_shape[size_val] = new_shape_symbol;
    new_syms.push_back(new_shape_symbol);
    graph->addInput("sym_shape")->setType(IntType::get());
  }

  for (auto v : values) {
    if (!v->type()->cast<TensorType>()) {
      continue;
    }
    auto tt = v->type()->expect<TensorType>();
    if (!tt->symbolic_sizes().sizes()) {
      continue;
    }
    std::vector<at::ShapeSymbol> shape_vec = *tt->symbolic_sizes().sizes();

    auto new_sizes = c10::fmap(shape_vec, [&](const at::ShapeSymbol& shape) {
      auto value = shape.value();
      if (shape_to_sym_shape.count(value)) {
        return shape_to_sym_shape.at(value);
      }
      return value;
    });
    v->setType(tt->withSymbolicShapes(c10::SymbolicShape(new_sizes)));
  }

  return new_syms;
}

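// Check whether the graph can be handled by the TensorExpr kernel: every
// graph input must be a Tensor, Float, Bool, or Int, and every tensor value
// in the graph must have a complete type (concrete dtype, sizes, and
// strides).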
bool isGraphCompilable(const std::shared_ptr<Graph>& graph) {
  for (auto input : graph->inputs()) {
    auto const& t = input->type();
    auto const& k = t->kind();
    if (k != TypeKind::TensorType && k != TypeKind::FloatType &&
        k != TypeKind::BoolType && k != TypeKind::IntType) {
      GRAPH_DEBUG("Input %", input->debugName(), " has unsupported type ", *t);
      return false;
    }
  }

  for (auto n : graph->nodes()) {
    for (auto v : n->inputs()) {
      auto const& t = v->type();
      if (t->kind() == TypeKind::TensorType) {
        auto tt = t->cast<TensorType>();
        if (!tt->isComplete()) {
          GRAPH_DEBUG(
              "%",
              v->debugName(),
              " is not a complete tensor! The type is: ",
              *t);
          return false;
        }
      }
    }
    for (auto v : n->outputs()) {
      auto const& t = v->type();
      if (t->kind() == TypeKind::TensorType) {
        auto tt = t->cast<TensorType>();
        if (!tt->isComplete()) {
          GRAPH_DEBUG(
              "%",
              v->debugName(),
              " is not a complete tensor! The type is: ",
              *t);
          return false;
        }
      }
    }
  }

  // TODO: check if all nodes have lowerings
  return true;
}

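// Fill in missing type info for the given tensor value: constants get the
// concrete type of their payload tensor; other values with known concrete
// sizes get the supplied scalar type and device, plus contiguous strides
// when strides are unknown.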
static void fixupTypeInfoForValue(
    Value* v,
    std::optional<at::ScalarType> scalar_type,
    std::optional<at::Device> device) {
  Node* n = v->node();
  auto const& t = v->type();
  if (t->kind() != TypeKind::TensorType) {
    return;
  }

  if (n->kind() == prim::Constant) {
    auto const_tensor = toIValue(v)->toTensor();
    auto concrete_tensor_type =
        tensorTypeInCurrentExecutionContext(const_tensor);
    v->setType(concrete_tensor_type);
    return;
  }

  TensorTypePtr new_tt;
  auto tt = t->cast<TensorType>();
  auto sizes = tt->sizes();
  if (!sizes.concrete_sizes()) {
    GRAPH_DEBUG("No concrete sizes for %", v->debugName());
    return;
  }
  auto strides = tt->strides();
  auto dtype = tt->scalarType() ? tt->scalarType() : scalar_type;
  auto concrete_sizes = *sizes.concrete_sizes();
  auto concrete_strides = strides.concrete_sizes()
      ? *strides.concrete_sizes()
      : TensorType::contiguousStridesOf(concrete_sizes);
  new_tt = TensorType::create(
      dtype, device, concrete_sizes, concrete_strides, false);

  v->setType(new_tt);
}

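// Infer a scalar type for the outputs of `n` from its tensor inputs.
// Returns std::nullopt if the inputs disagree on the scalar type.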
static std::optional<at::ScalarType> inferScalarType(Node* n) {
  std::optional<at::ScalarType> scalar_type;
  for (auto v : n->inputs()) {
    auto const& t = v->type();
    if (t->kind() == TypeKind::TensorType) {
      auto tt = t->cast<TensorType>();
      if (!scalar_type) {
        scalar_type = tt->scalarType();
      }
      if (tt->scalarType() && *tt->scalarType() != scalar_type) {
        GRAPH_DEBUG(
            "Inputs of ", n, " have different scalar types, cannot fixup!");
        return std::nullopt;
      }
    }
  }
  return scalar_type;
}

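// Infer a device for the outputs of `n` from its tensor inputs. Returns
// std::nullopt if the inputs disagree; defaults to CPU when no input
// specifies a device.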
static std::optional<at::Device> inferDevice(Node* n) {
  std::optional<at::Device> device;
  for (auto v : n->inputs()) {
    auto const& t = v->type();
    if (t->kind() == TypeKind::TensorType) {
      auto tt = t->cast<TensorType>();
      if (!device) {
        device = tt->device();
      }
      if (tt->device() && *tt->device() != device) {
        GRAPH_DEBUG("Inputs of ", n, " have different devices, cannot fixup!");
        return std::nullopt;
      }
    }
  }
  if (!device) {
    device = at::kCPU;
  }
  return device;
}

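// Complete missing type info on graph inputs (bailing out if an input
// tensor has no dtype), then propagate inferred scalar types and devices to
// the outputs of every node.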
void fixupMissingShapeInfo(const std::shared_ptr<Graph>& graph) {
  for (auto input : graph->inputs()) {
    auto const& t = input->type();
    if (t->kind() == TypeKind::TensorType) {
      auto tt = t->cast<TensorType>();
      if (!tt->scalarType()) {
        GRAPH_DEBUG("No dtype for %", input->debugName());
        return;
      }
      fixupTypeInfoForValue(
          input, tt->scalarType(), tt->device() ? tt->device() : at::kCPU);
    }
  }

  for (auto n : graph->nodes()) {
    std::optional<at::ScalarType> scalar_type = inferScalarType(n);
    std::optional<at::Device> device = inferDevice(n);

    for (auto v : n->outputs()) {
      fixupTypeInfoForValue(v, scalar_type, device);
    }
  }
}

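// Erase the graph output at the given index.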
std::shared_ptr<Graph> removeGraphOutput(
    const std::shared_ptr<Graph>& graph,
    size_t idx) {
  graph->eraseOutput(idx);
  return graph;
}

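// If the first graph output is built by prim::ListConstruct, replace the
// list with a tuple holding the same elements.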
std::shared_ptr<Graph> replaceListOutputWithTuple(
    const std::shared_ptr<Graph>& graph) {
  auto out = graph->outputs()[0];
  auto out_node = out->node();
  if (out_node->kind() != prim::ListConstruct) {
    return graph;
  }
  auto tuple_node = graph->createTuple(out_node->inputs());
  tuple_node->insertAfter(out_node);
  out->replaceAllUsesWith(tuple_node->output());
  return graph;
}

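// Shrink the graph by one step: find the first graph output that is not a
// graph input, remove it from the outputs, and expose the inputs of its
// producing node as outputs instead (skipping constants and values that are
// already outputs). Returns true if anything changed.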
static bool trimGraphOnce(const std::shared_ptr<Graph>& graph) {
  Node* ret = graph->return_node();
  std::unordered_set<Value*> graph_inputs(
      graph->inputs().begin(), graph->inputs().end());
  std::unordered_set<Value*> outputs(
      graph->outputs().begin(), graph->outputs().end());
  bool changed = false;
  for (size_t idx = 0; idx < ret->inputs().size(); idx++) {
    auto v = ret->inputs()[idx];
    if (graph_inputs.count(v)) {
      continue;
    }
    // Delete the graph output at `idx` and add all inputs of the node
    // producing that value to the graph outputs.
    graph->eraseOutput(idx);
    for (auto v_ins : v->node()->inputs()) {
      if (outputs.count(v_ins)) {
        continue;
      }
      if (v_ins->node()->kind() == prim::Constant) {
        continue;
      }

      graph->registerOutput(v_ins);
    }
    changed = true;
    break;
  }
  return changed;
}

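// For every quantized graph output, append an aten::dequantize node so that
// the graph returns the dequantized float tensor instead.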
static std::shared_ptr<Graph> dequantizeResults(
    const std::shared_ptr<Graph>& graph) {
  for (auto v : graph->outputs()) {
    auto& t = v->type();
    if (t->kind() == TypeKind::TensorType) {
      auto tt = t->cast<TensorType>();
      if (!tt->scalarType() || !c10::isQIntType(*tt->scalarType())) {
        continue;
      }
      Node* deq = graph->create(aten::dequantize, {v});
      graph->appendNode(deq);
      deq->output()->setType(tt->withScalarType(c10::kFloat));
      v->replaceAllUsesAfterNodeWith(deq, deq->output());
    }
  }
  return graph;
}

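// Iteratively shrink the graph from its outputs, running dead code
// elimination after every step, for at most `iters` iterations.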
std::shared_ptr<Graph> trimGraph(
    const std::shared_ptr<Graph>& graph,
    int64_t iters) {
  bool changed = true;
  int64_t iter = 0;
  while (changed && iter++ < iters) {
    changed = trimGraphOnce(graph);
    EliminateDeadCode(graph->block());
  }
  // Avoid letting quantized values escape to graph outputs.
  // Ideally we should allow quantized outputs as well, but currently the main
  // user of this pass - AOT NNC - does not support them.
  // TODO: remove output dequantization once NNC supports quantized outputs.
  dequantizeResults(graph);
  return graph;
}

} // namespace torch::jit::tensorexpr