#include <ATen/core/symbol.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/integer_value_refinement.h>
#include <torch/csrc/jit/passes/loop_unrolling.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/peephole_list_idioms.h>
#include <torch/csrc/jit/passes/peephole_non_tensor.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_cache.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/runtime/exception_message.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
#include <algorithm>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

/*
XXX: this is still in prototype phase and has much work left to do, including
but not limited to:
- Refactor APIs
- Add decent coverage of common ops
- Add a shape analysis pass on Graph that handles Loops
- Allow concurrent reads to the operator map
- Support returning partially evaluated shape compute graphs
*/

static bool symbolic_shape_analysis_test_mode = false;

namespace torch::jit {

// This is similar to c10::SymbolicShape, but instead of either having
// a concrete dimension or a symbolic dimension, an argument may be:
// - A Symbolic Dimension
// - A Constant Integer
// - Neither of the above. The third case can occur due to inputs to
// ops like view that accept negative values. Maintaining the distinction
// between an unknown symbolic dimension and an unknown integer allows
// us to optimize out comparisons to values < 0 (symbolic shapes are always
// >= 0).
// For example, in a graph like:
//   graph(%y: Tensor(SS(-1), 10, 10), %inp: int):
//     %five: int = prim::Constant[value=5]()
//     %zero: int = prim::Constant[value=0]()
//     %1 : int = aten::size(%y, %zero)
//     %2 : int[] = prim::ListConstruct(%five, %1, %inp)
//     %y.2: Tensor(5, SS(-1), (New Symbolic Shape)) = aten::view(%y, %2)
//
// the call y.view([5, y.size(0), inp])
// will have inputs equal to [5, SS(-1), std::nullopt]
struct ShapeArg
    : public std::
          pair<std::optional<c10::ShapeSymbol>, std::optional<int64_t>> {
  using pair::pair;

  static ShapeArg unknownInteger() {
    return ShapeArg();
  }

  ShapeArg(int64_t int_value) {
    this->first = std::nullopt;
    this->second = int_value;
  }

  ShapeArg(c10::ShapeSymbol ss) {
    if (ss.is_static()) {
      this->first = std::nullopt;
      this->second = ss.value();
    } else {
      this->first = ss;
      this->second = std::nullopt;
    }
  }

  std::optional<int64_t> asConstantInt() const {
    return this->second;
  }

  std::optional<c10::ShapeSymbol> asShapeSymbol() const {
    return this->first;
  }

 private:
  ShapeArg() {
    this->first = std::nullopt;
    this->second = std::nullopt;
  }
};
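
// Illustrative sketch of how the three ShapeArg cases above are constructed
// (the values here are made up for the example):
//   ShapeArg dim0(int64_t(5));                     // known constant: 5
//   ShapeArg dim1(c10::ShapeSymbol::newSymbol());  // unknown symbolic dim, >= 0
//   ShapeArg dim2 = ShapeArg::unknownInteger();    // unknown int, may be < 0
// dim0.asConstantInt() holds 5, dim1.asShapeSymbol() holds the symbol, and
// both accessors return std::nullopt for dim2.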

static std::ostream& operator<<(std::ostream& out, const ShapeArg& sa) {
  if (auto val = sa.asConstantInt()) {
    out << *val;
  } else if (auto ss = sa.asShapeSymbol()) {
    out << *ss;
  } else {
    out << "UNK";
  }
  return out;
}

struct ShapeArguments {
  // Superset of SymbolicShape, with additional support for unknown,
  // nonsymbolic vals
 public:
  ShapeArguments(const c10::SymbolicShape& ss) {
    has_dim_ = ss.rank().has_value();
    if (has_dim_) {
      for (size_t i = 0; i < *ss.rank(); ++i) {
        maybe_shape_symbols_.emplace_back(ss.at(i));
      }
    }
  }

  ShapeArguments(std::vector<ShapeArg> ss)
      : has_dim_(true), maybe_shape_symbols_(std::move(ss)) {}

  bool has_dim() const {
    return has_dim_;
  }

  int64_t len() const {
    TORCH_INTERNAL_ASSERT(has_dim_, "ShapeArguments has no known dim")
    return (int64_t)maybe_shape_symbols_.size();
  }

  const ShapeArg at(size_t i) const {
    TORCH_INTERNAL_ASSERT(has_dim_, "ShapeArguments has no known dim")
    return maybe_shape_symbols_.at(i);
  }

 private:
  bool has_dim_;
  std::vector<ShapeArg> maybe_shape_symbols_;
};

static std::ostream& operator<<(std::ostream& os, const ShapeArguments& sa) {
  if (!sa.has_dim()) {
    os << "(UNKNOWN DIM)";
    return os;
  }

  os << "(";
  for (const auto i : c10::irange(sa.len())) {
    os << sa.at(i);
  }
  os << ")";

  return os;
}

bool setSymbolicShapeAnalysisTestMode(bool value) {
  bool old_value = symbolic_shape_analysis_test_mode;
  symbolic_shape_analysis_test_mode = value;
  return old_value;
}

bool symbolicShapeAnalysisTestModeEnabled() {
  return symbolic_shape_analysis_test_mode;
}

using SSArgument = std::variant<ShapeArguments, IValue>;

static std::ostream& operator<<(std::ostream& out, const SSArgument& sa) {
  if (const IValue* iv = std::get_if<IValue>(&sa)) {
    out << *iv;
  } else {
    out << std::get<ShapeArguments>(sa);
  }
  return out;
}

namespace {

bool isListOfInts(const TypePtr& type) {
  return type->cast<ListType>() &&
      type->cast<ListType>()->getElementType()->cast<IntType>();
}

bool isListOfListOfInts(const TypePtr& type) {
  // Allows List[Optional[List[Int]]]
  if (!type->cast<ListType>()) {
    return false;
  }
  TypePtr element_type = type->cast<ListType>()->getElementType();
  if (element_type->cast<OptionalType>()) {
    element_type = element_type->cast<OptionalType>()->getElementType();
  }
  return isListOfInts(element_type);
}

bool isListOfTensors(const TypePtr& type) {
  return type->cast<ListType>() &&
      type->cast<ListType>()->getElementType()->cast<TensorType>();
}

std::optional<size_t> normIndex(int64_t index, size_t len) {
  if (index < 0) {
    index = index + static_cast<int64_t>(len);
  }
  if (index >= 0 && index < static_cast<int64_t>(len)) {
    return index;
  } else {
    return std::nullopt;
  }
}
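
// A small illustration of the helper above:
//   normIndex(-1, 4) -> 3
//   normIndex(2, 4)  -> 2
//   normIndex(4, 4)  -> std::nullopt (out of range)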

bool shapeGraphCleanupPasses(std::shared_ptr<Graph> graph) {
  // TODO: lower simple tuples ?
  bool made_change = RemoveListMutation(graph);
  made_change |= UnrollConstantLoops(graph);
  made_change |= ConstantPropagation(graph);
  made_change |= PeepholeOptimizeNonTensor(graph);
  made_change |= PeepholeOptimizeListIdioms(graph, /*refine_list_len*/ true);
  made_change |= RefineIntegerValues(graph);
  made_change |= ConstantPropagation(graph);
  // TODO: add return change for constant pooling
  ConstantPooling(graph);
  made_change |= EliminateCommonSubexpression(graph);
  EliminateDeadCode(graph);
  return made_change;
}

void replaceWithIValue(Value* v, const IValue& val) {
  WithInsertPoint guard(*v->node()->owningBlock()->nodes().begin());
  v->replaceAllUsesWith(v->owningGraph()->insertConstant(val));
}

c10::SymbolicShape extractListShape(
    Value* list,
    std::unordered_map<Value*, int64_t>& symbolic_shape_values,
    const AliasDb& db) {
  if (list->node()->kind() == prim::Constant) {
    auto int_list = toIValue(list)->toIntVector();
    return c10::SymbolicShape(int_list);
  }
  // We need a list construct or a constant output
  // that is not written to in order to analyze the output shape
  if (list->node()->kind() != prim::ListConstruct || db.hasWriters(list)) {
    GRAPH_DEBUG("Could not extract shape");
    return c10::SymbolicShape();
  }
  Node* list_construct = list->node();
  std::vector<std::optional<int64_t>> output_shape;
  for (Value* input : list_construct->inputs()) {
    if (symbolic_shape_values.count(input)) {
      output_shape.emplace_back(symbolic_shape_values[input]);
    } else {
      output_shape.push_back(constant_as<int64_t>(input));
    }
  }
  return c10::SymbolicShape(output_shape);
}

// Symbolic Shape Analysis works by iteratively partially evaluating
// a TorchScript shape compute graph, substituting in properties from the
// input Tensors. We can substitute in properties like `len(x)` and `x[1]`
// if they are statically known on the input Tensors. We can also use
// assertions like `assert len(x) == 4` in order to refine the input
// length and unroll loops over its elements. We iteratively optimize and
// substitute in properties until we are unable to make any further
// optimizations. Finally, we try to extract Tensor properties from the output.
// For instance, given `return [1, 2, inp[2] + 1, inp[3]]` we know that the
// output will be length 4 with the first two dimensions equal to 1 and 2. We
// can also deduce that the 4th dimension has the same symbolic shape as
// inp[3], which means that we do not know its concrete value statically, but
// we can assign sets of tensor dimensions which must be equal at runtime.
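//
// As a concrete sketch (the shape function below is illustrative, not the
// exact registered definition for any op), partially evaluating a unary
// shape function
//   def unary(self: List[int]) -> List[int]:
//     out: List[int] = []
//     for elem in self:
//       out.append(elem)
//     return out
// against an input typed Tensor(SS(-2), 4) lets us substitute len(self) == 2,
// unroll the loop, and resolve self[1] to the constant 4, leaving a graph
// that returns [self[0], 4]. Extracting the output then yields the
// SymbolicShape (SS(-2), 4).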

struct SymbolicShapeOpAnalyzer {
  std::shared_ptr<Graph> shape_compute_graph_;
  const FunctionSchema* schema_;
  std::vector<SSArgument> inputs_;

  // For the case where we have a JIT graph,
  // substitute optional types for their component types
  // if the type is known. This doesn't need to be done
  // for known IValues.
  void refineInputUnionTypes(const Node* parent_graph_node) {
    for (size_t op_in_index = 0;
         op_in_index < shape_compute_graph_->inputs().size();
         op_in_index++) {
      auto type = parent_graph_node->input(op_in_index)->type();
      if (auto opt_type = shape_compute_graph_->inputs()
                              .at(op_in_index)
                              ->type()
                              ->cast<OptionalType>()) {
        // None will get handled with constant substitution later
        if (!type->cast<OptionalType>() &&
            !NoneType::get()->isSubtypeOf(*type)) {
          shape_compute_graph_->inputs()
              .at(op_in_index)
              ->setType(opt_type->getElementType());
        }
      } else if (shape_compute_graph_->inputs()
                     .at(op_in_index)
                     ->type()
                     ->cast<NumberType>()) {
        shape_compute_graph_->inputs().at(op_in_index)->setType(type);
      }
    }
  }
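
  // For example (an illustrative sketch): for a schema argument typed `int?`
  // whose shape function input is therefore `Optional[int]`, if the node in
  // the enclosing graph passes a plain `int`, the shape graph input is retyped
  // to `int` so downstream uses can be constant-folded; a None input is
  // instead handled later by constant substitution.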

  // We handle non-constant values in the shape propagation step
  void substituteConstantInputs() {
    if (shape_compute_graph_->inputs().empty()) {
      return;
    }

    bool seen_tensor_list = false;

    size_t op_in_index = 0;
    while (op_in_index < shape_compute_graph_->inputs().size()) {
      Value* graph_in_var = shape_compute_graph_->inputs().at(op_in_index);
      if (!isListOfListOfInts(graph_in_var->type())) {
        op_in_index++;
        continue;
      }

      // Modify the shape compute graph so that it does not use the tensor
      // list construct directly.

      // When we partially evaluate an op that takes a list of Tensors, like
      // cat(Tensor[]), we have a few problems:
      // - optimizing out calls to the length of the list: len(tensors)
      // - resolving accesses of the list to the symbolic sizes of the
      //   corresponding list element
      // We can solve both of these problems by replacing the partial
      // evaluation of cat([x, y]):
      //   def cat(tensors: List[List[int]], dim: int):
      //     body
      // with:
      //   def cat(x, y, dim: int):
      //     tensors = [x, y]
      //     body
      TORCH_INTERNAL_ASSERT(
          !seen_tensor_list,
          "SSA doesn't handle the case of multiple tensor lists")
      seen_tensor_list = true;

      uint64_t li_length = inputs_.size() - (schema_->arguments().size() - 1);
      std::vector<Value*> li_inputs;

      TypePtr element_type =
          graph_in_var->type()->cast<ListType>()->getElementType();
      for (size_t j = op_in_index; j < op_in_index + li_length; ++j) {
        auto new_inp = shape_compute_graph_->insertInput(op_in_index + j);
        new_inp->setType(element_type);
        li_inputs.push_back(new_inp);
      }
      WithInsertPoint guard(*shape_compute_graph_->block()->nodes().begin());
      auto new_li = shape_compute_graph_->insertNode(
          shape_compute_graph_->createList(element_type, li_inputs));
      graph_in_var->replaceAllUsesWith(new_li->output());
      shape_compute_graph_->eraseInput(op_in_index + li_length);
    }

    TORCH_INTERNAL_ASSERT(
        shape_compute_graph_->inputs().size() <= inputs_.size(),
        "Shape compute graph expected to have no more inputs than the node");

    for (size_t op_in_index = 0;
         op_in_index < shape_compute_graph_->inputs().size();
         op_in_index++) {
      SSArgument& argument = inputs_[op_in_index];
      Value* graph_in_var = shape_compute_graph_->inputs().at(op_in_index);

      if (IValue* cur_val = std::get_if<IValue>(&argument)) {
        GRAPH_DEBUG("Substituting constant input ", *cur_val);
        replaceWithIValue(graph_in_var, *cur_val);
      } else {
        auto cur_arg = std::get<ShapeArguments>(argument);
        if (cur_arg.has_dim()) {
          graph_in_var->setType(ListType::ofInts());
        }
      }
    }
  }

  void substituteSymbolicProperties(
      std::unordered_map<Value*, int64_t>* symbolic_shape_values) {
    // clang-format off
    // here we iteratively substitute properties of the node's input tensors
    // into the shape compute graph. we can substitute in constants for
    // expressions like len(inp) or inp[0] if the tensor has a fixed length or
    // a fixed first dimension. we also try to resolve symbolic shapes of the
    // same symbolic value to the same Value * in the shape compute graph.
    // for the shape logic:
    //   dim1 = inp1[0]
    //   dim2 = inp2[0]
    //   return dim1 if dim2 == 1 else dim2
    // if we see that inp1[0] and inp2[0] both have the same symbolic shape
    // value, then it is a valid transformation to replace dim2 with dim1 or
    // vice versa. to do this we collect all Value * for a particular symbolic
    // shape. Then, we replace all Value * within that set with their dominator.
    // In the example above, this allows us to infer that the output will be the
    // symbolic dimension value of dim1.

    // if `symbolic_shape_values` is not null, record list accesses
    // which resolve to symbolic dimension values with their concrete symbolic
    // shape value. Because symbolic dimensions are represented as negative numbers and
    // are not real values, inserting them as constants in the graph would invalidate
    // the graph for further use. Instead, we keep track of what their value would be
    // for extracting output shapes.
    // clang-format on

    std::unordered_map<int64_t, std::vector<Value*>> symbolic_shape_map;

    TORCH_INTERNAL_ASSERT(
        inputs_.size() >= shape_compute_graph_->inputs().size(),
        "Missing Arg for Shape Graph");
    for (const auto index :
         c10::irange(shape_compute_graph_->inputs().size())) {
      auto shape_arguments = std::get_if<ShapeArguments>(&inputs_[index]);
      if (!shape_arguments || !shape_arguments->has_dim()) {
        continue;
      }
      // Add support for testing symbolic shapes with dynamic dims

      for (const Use& use : shape_compute_graph_->inputs().at(index)->uses()) {
        // TODO: either decompose composite ops like slice or add handling here
        switch (use.user->kind()) {
          case aten::len: {
            size_t len = shape_arguments->len();
            replaceWithIValue(use.user->output(), static_cast<int64_t>(len));
          } break;
          case aten::__getitem__: {
            auto index = constant_as<int64_t>(use.user->inputs().at(1));
            if (!index) {
              continue;
            }
            auto norm_index = normIndex(*index, shape_arguments->len());
            if (!norm_index) {
              continue;
            }
            auto shape_arg = shape_arguments->at(*norm_index);
            if (auto const_int = shape_arg.asConstantInt()) {
              replaceWithIValue(use.user->output(), const_int);
              continue;
            }
            auto maybe_shape_symbol = shape_arg.asShapeSymbol();
            if (!maybe_shape_symbol) {
              continue;
            }
            auto shape_symbol = *maybe_shape_symbol;
            if (symbolic_shape_values) {
              symbolic_shape_values->emplace(
                  use.user->output(), shape_symbol.value());
            } else {
              int64_t symbolic_index = shape_symbol.value();
              symbolic_shape_map[symbolic_index].push_back(use.user->output());
            }
            for (const auto& sym_uses : use.user->output()->uses()) {
              auto k = sym_uses.user->kind();
              if (k != aten::ge && k != aten::le && k != aten::ne &&
                  k != aten::eq && k != aten::lt && k != aten::gt) {
                break;
              }
              auto other_index = 1 - sym_uses.offset;
              auto other_value =
                  constant_as<int64_t>(sym_uses.user->input(other_index));
              if (!other_value) {
                continue;
              }

              // check for dim >= 0, 0 <= dim
              // dim >= 0
              if (k == aten::ge && *other_value == 0 && other_index == 1) {
                replaceWithIValue(sym_uses.user->output(), true);
                continue;
              }
              // 0 <= dim
              if (k == aten::le && *other_value == 0 && other_index == 0) {
                replaceWithIValue(sym_uses.user->output(), true);
                continue;
              }

              // check for dim comparisons to negative number
              if (*other_value >= 0) {
                continue;
              }
              if (k == aten::eq || k == aten::ne) {
                // True if:
                // -2 != {Positive}
                replaceWithIValue(sym_uses.user->output(), k == aten::ne);
              } else {
                // True if:
                // -2 <= / < {Positive}
                // {Positive} >= / > {-2}
                bool true_val =
                    ((other_index == 0 && (k == aten::le || k == aten::lt)) ||
                     (other_index == 1 && (k == aten::ge || k == aten::gt)));
                replaceWithIValue(sym_uses.user->output(), true_val);
              }
            }
          }
        }
      }

      for (const auto& symbolic_set : symbolic_shape_map) {
        mergeSymbolicShapeSets(symbolic_set.second);
      }
    }
  }

  void mergeSymbolicShapeSets(const std::vector<Value*>& symbolic_set) {
    // `symbolic_set` represents a set of Value * which are all equal
    // to each other. Here, we optimize the graph by replacing values
    // in the set with other dominating values.
    // In the following example, where a, b and c are all in the same
    // symbolic set:
    //   if cond:
    //     a = li[0]
    //     b = li[1]
    //     return [a, b]
    //   else:
    //     c = li[0]
    //     return [c, c]
    // we can replace `b` with `a` because it is dominated by `a`,
    // but we cannot replace `c` with another dominating value.

    // There are ways to compute this more efficiently, but the number of
    // Values in each symbolic set is typically low, so this is cheap to run.
    for (const auto i : c10::irange(symbolic_set.size())) {
      Value* v = symbolic_set[i];
      Value* dominating_value = v;
      for (const auto& sym_set : symbolic_set) {
        if (dominating_value->node()->isDominatedBy(sym_set->node())) {
          dominating_value = sym_set;
        }
      }
      if (dominating_value != v) {
        v->replaceAllUsesWith(dominating_value);
      }
    }
  }

  std::vector<c10::SymbolicShape> propagateShapesInGraph() {
    bool made_change = true;
    constexpr size_t MAX_ATTEMPTS = 8;
    for (unsigned attempt_num = 0; made_change && attempt_num < MAX_ATTEMPTS;
         attempt_num++) {
      // symbolic shape concrete values are only used in final shape extraction
      GRAPH_DUMP("Before substitution: ", shape_compute_graph_);
      substituteSymbolicProperties(/*symbolic_shape_values*/ nullptr);
      GRAPH_DUMP("Before Opt: ", shape_compute_graph_);
      made_change = shapeGraphCleanupPasses(shape_compute_graph_);
    }
    std::unordered_map<Value*, int64_t> symbolic_shape_values;
    substituteSymbolicProperties(&symbolic_shape_values);
    GRAPH_DUMP("Done with partial evaluation", shape_compute_graph_);

    return extractOutputShape(symbolic_shape_values);
  }

  std::vector<c10::SymbolicShape> extractOutputShape(
      std::unordered_map<Value*, int64_t>& symbolic_shape_values) {
    TORCH_INTERNAL_ASSERT(
        shape_compute_graph_->outputs().size() == schema_->returns().size());
    // TODO: it would be nice if there were an easy facility to look at the
    // uses and see if they are all pure instead of instantiating the AliasDb.
    auto res = std::vector<c10::SymbolicShape>();
    AliasDb db(shape_compute_graph_);
    for (size_t i = 0; i < shape_compute_graph_->outputs().size(); ++i) {
      auto output = shape_compute_graph_->outputs().at(i);
      auto type = output->type();
      TORCH_INTERNAL_ASSERT(isListOfInts(type));
      c10::SymbolicShape ss =
          extractListShape(output, symbolic_shape_values, db);
      GRAPH_DEBUG("Extracted Output: ", ss);
      res.push_back(ss);
    }
    return res;
  }

 public:
  SymbolicShapeOpAnalyzer(const FunctionSchema* schema) : schema_(schema) {
    shape_compute_graph_ = nullptr;
    if (!schema_) {
      return;
    }
    auto maybe_graph = shapeComputeGraphForSchema(*schema_);
    if (!maybe_graph) {
      return;
    }
    shape_compute_graph_ = (*maybe_graph)->copy();
  }

  SymbolicShapeOpAnalyzer(
      const FunctionSchema* schema,
      const std::shared_ptr<Graph>& graph)
      : schema_(schema) {
    shape_compute_graph_ = graph->copy();
  }

  std::optional<std::vector<c10::SymbolicShape>> run(
      std::vector<SSArgument>& inputs) {
    if (!shape_compute_graph_) {
      return std::nullopt;
    }
    inputs_ = inputs;
    substituteConstantInputs();
    GRAPH_DEBUG(inputs_)
    return propagateShapesInGraph();
  }

  std::shared_ptr<Graph> getShapeComputeGraph() {
    return shape_compute_graph_;
  }
};

SSArgument tensorShapeArg(Value* tensor_v) {
  auto tt = tensor_v->type()->expect<TensorType>();
  c10::SymbolicShape symbolic_shapes = tt->symbolic_sizes();

  // for testing, we don't insert complete tensor shapes and rely on our
  // partial evaluation pipeline to propagate information.
  // this is a good proxy for our ability to propagate non-complete shape
  // information.
  if (symbolic_shapes.isComplete() && !symbolic_shape_analysis_test_mode) {
    return IValue(tt->sizes().concrete_sizes());
  }
  if (toIValue(tensor_v)) {
    auto size = constant_as<at::Tensor>(tensor_v)->sizes();
    if (!symbolic_shape_analysis_test_mode) {
      return IValue(size);
    } else {
      return c10::SymbolicShape(size);
    }
  }
  return symbolic_shapes;
}
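
// For example (illustrative): a Value typed Tensor(2, 4) with fully known
// sizes is passed to the shape function as the concrete IValue [2, 4],
// whereas a Value typed Tensor(SS(-2), 4) is passed as its SymbolicShape so
// that the partial evaluator can still use the known second dimension.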

std::vector<SSArgument> getNodeInputShapes(Node* n, const AliasDb& db) {
  // TODO: fix the List of integers implementation, and
  // extract out the shape changes, otherwise this is complete
  // NB: shape compute graphs may have fewer inputs than their node
  // counterparts to allow e.g. sharing one single unary definition,
  // so iterate on the # of shape inputs.
  // We make lists of Tensor inputs variadic, which results in an
  // offset between a node index and its corresponding graph index.
  std::vector<SSArgument> input_shapes = std::vector<SSArgument>();

  for (size_t node_index = 0; node_index < n->inputs().size(); ++node_index) {
    auto type = n->input(node_index)->type();

    if (type->castRaw<TensorType>()) {
      input_shapes.push_back(tensorShapeArg(n->input(node_index)));
      continue;
    }
    if (isListOfTensors(type)) {
      // waiting for more use cases to decide on the best generalization
      if (n->input(node_index)->node()->kind() == prim::Constant) {
        auto ival = toIValue(n->input(node_index));
        for (const auto& ten : ival->toTensorVector()) {
          input_shapes.emplace_back(c10::List<int64_t>(ten.sizes()));
        }
      } else if (
          n->input(node_index)->node()->kind() == prim::ListConstruct &&
          !db.hasWriters(n->input(node_index))) {
        auto li_construct_node = n->input(node_index)->node();
        for (size_t j = 0; j < li_construct_node->inputs().size(); ++j) {
          input_shapes.push_back(tensorShapeArg(li_construct_node->input(j)));
        }
      } else {
        TORCH_INTERNAL_ASSERT(false, "Unhandled List, we shouldn't get here");
      }
      continue;
    }
    if (auto ival = toIValue(n->input(node_index))) {
      input_shapes.emplace_back(*ival);
      continue;
    }
    if (type->cast<ListType>() &&
        type->cast<ListType>()->getElementType()->cast<IntType>()) {
      auto input_src_node = n->input(node_index)->node();
      if (input_src_node->kind() == prim::ListConstruct &&
          !db.hasWriters(n->input(node_index))) {
        // it is very common in graphs to see patterns like:
        //   z = x.view(y.size())
        // or:
        //   z = x.view(1, 10, y.size(0), y.size(1))
        // We want to propagate symbolic dimensions and concrete sizes
        // from y to z. To do this we try to associate symbolic dimensions
        // or concrete sizes with the integer list inputs that have a
        // constructor taken from constants or y.size() or y.size(0)
        auto list_construct = n->input(node_index)->node();
        std::vector<ShapeArg> shape;
        for (Value* v : list_construct->inputs()) {
          if (auto constant = constant_as<int64_t>(v)) {
            shape.emplace_back(*constant);
          } else if (v->node()->kind() == aten::size) {
            auto const_index = constant_as<int64_t>(v->node()->input(1));
            auto tt = v->node()->input(0)->type()->expect<TensorType>();
            auto ss = tt->symbolic_sizes();
            if (!ss.rank() || !const_index) {
              // if we are getting a size of a tensor, it is an unknown
              // symbolic dimension instead of an unknown integer (must be
              // >= 0)
              shape.emplace_back(at::ShapeSymbol::newSymbol());
              continue;
            }
            auto norm_index = normIndex(*const_index, *ss.rank());
            if (!norm_index) {
              shape.emplace_back(at::ShapeSymbol::newSymbol());
              continue;
            }
            shape.emplace_back(ss[*norm_index]);
          } else {
            shape.emplace_back(ShapeArg::unknownInteger());
          }
        }
        input_shapes.emplace_back(ShapeArguments(shape));
        continue;
      }
      if (input_src_node->kind() == aten::size &&
          !db.hasWriters(n->input(node_index))) {
        auto ten_inp = input_src_node->input();
        auto ss = ten_inp->type()->expect<TensorType>()->symbolic_sizes();
        input_shapes.emplace_back(ss);
        continue;
      }
    }
    GRAPH_DEBUG(
        "Unhandled input: ",
        n->kind().toDisplayString(),
        " arg num: ",
        node_index);
    input_shapes.emplace_back(c10::SymbolicShape());
  }
  TORCH_INTERNAL_ASSERT(
      input_shapes.size() >= n->inputs().size(),
      "input_shapes size: ",
      input_shapes.size(),
      " n inputs size: ",
      n->inputs().size());
  return input_shapes;
}
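
// For example (illustrative): for a node %r = aten::cat(%list, %dim), where
// %list = prim::ListConstruct(%a, %b) has no writers and %dim is a constant,
// the returned vector is [shape(%a), shape(%b), IValue(dim)]: the tensor list
// is flattened into one ShapeArguments entry per element, matching the
// variadic rewrite performed later in substituteConstantInputs().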

void applyOutputShapeToGraph(
    Node* node,
    const std::vector<c10::SymbolicShape>& output_shapes) {
  TORCH_INTERNAL_ASSERT(
      node->outputs().size() == output_shapes.size(),
      "Output shape size mismatch");
  for (size_t i = 0; i < output_shapes.size(); ++i) {
    auto& ss = output_shapes.at(i);
    node->output(i)->setType(
        node->output(i)->type()->expect<TensorType>()->withSymbolicShapes(ss));
  }
}

std::shared_ptr<Graph> PropagateShapesWithShapeFunction(
    Node* n,
    const AliasDb& db) {
  const FunctionSchema* func_schema = n->maybeSchema();
  if (!func_schema) {
    return nullptr;
  }
  auto op_analyzer = SymbolicShapeOpAnalyzer(func_schema);
  if (!op_analyzer.getShapeComputeGraph()) {
    return nullptr;
  }
  auto input_shapes = getNodeInputShapes(n, db);
  op_analyzer.refineInputUnionTypes(n);

  if (auto output_shapes = op_analyzer.run(input_shapes)) {
    applyOutputShapeToGraph(n, *output_shapes);
  }

  return op_analyzer.getShapeComputeGraph();
}

c10::SymbolicShape combine_bounds(
    c10::SymbolicShape& lower_bound,
    c10::SymbolicShape& upper_bound) {
  // TODO: At some point we might want to add support for dynamic dims
  TORCH_INTERNAL_ASSERT(lower_bound.rank() == upper_bound.rank());
  if (lower_bound.rank() == std::nullopt) {
    return c10::SymbolicShape();
  }
  std::vector<c10::ShapeSymbol> merged_shapes;
  for (const auto i : c10::irange(*lower_bound.rank())) {
    // TODO: Merge equivalent expressions (not needed for current use case)
    if (lower_bound[i] == upper_bound[i]) {
      merged_shapes.push_back(lower_bound[i]);
    } else {
      merged_shapes.push_back(c10::ShapeSymbol::newSymbol());
    }
  }
  return c10::SymbolicShape(std::move(merged_shapes));
}
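
// For example (illustrative): combining a lower bound of (2, SS(-2), 3) with
// an upper bound of (2, SS(-2), 5) yields (2, SS(-2), SS(new)); dimensions
// that agree are kept, and any dimension that differs between the bounds is
// replaced with a fresh symbolic dimension.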

struct SymbolicShapeGraphAnalyzer {
  SymbolicShapeGraphAnalyzer(
      std::shared_ptr<Graph>& graph,
      Node* beg,
      Node* end)
      : graph_(graph), beg_(beg), end_(end) {
    TORCH_INTERNAL_ASSERT(
        beg_->owningBlock() == end_->owningBlock() && end_->isAfter(beg_));
  }

  std::optional<ShapeComputeGraphMapping> run() {
    AliasDb db(graph_);
    std::unordered_map<Node*, std::shared_ptr<Graph>> partial_evaluated_graphs =
        propagateShapesAndGatherPartialEvalShapeGraphs(db);

    auto stitched_shape_compute_graph = std::make_shared<Graph>();
    // We want to build up a computational graph which computes all shapes
    // we don't know statically - that is, all symbolic shapes within
    // the region [beg, end). It must be executable before beg.
    // TODO: don't require dimensions of tensors to be set AOT ?
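
    // Illustrative sketch: for a region containing
    //   %z : Tensor(SS(-2), SS(-4)) = aten::mm(%x, %y)
    // with %x : Tensor(SS(-2), SS(-3)) and %y : Tensor(SS(-3), SS(-4)),
    // the stitched graph ends up taking the int[] sizes of %x and %y as
    // inputs and returning the runtime values of SS(-2), SS(-3), and SS(-4),
    // so every symbolic dimension used in the region can be computed before
    // the region executes.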

    for (auto it = beg_->iterator(); it != end_->iterator(); it++) {
      auto curr = *it;
      if (curr->kind() == prim::Constant) {
        continue;
      }
      // TODO: generalize logic for other tensor input ops when they are
      // added
      if (curr->kind() == prim::ListConstruct) {
        auto uses = curr->output()->uses();
        if (!std::all_of(uses.begin(), uses.end(), [](const Use& use) {
              return use.user->kind() == aten::cat;
            })) {
          GRAPH_DEBUG("Non cat list use ", getHeader(curr));
          return std::nullopt;
        }
        continue;
      }

      if (!partial_evaluated_graphs.count(curr)) {
        GRAPH_DEBUG("No graph ", getHeader(curr));
        return std::nullopt;
      }

      auto outputs = curr->outputs();
      for (Value* v : outputs) {
        auto tt = v->type()->cast<TensorType>();
        if (!tt) {
          GRAPH_DEBUG("Non tensor node ", getHeader(curr));
          return std::nullopt;
        }
        auto symbolic_sizes = tt->symbolic_sizes();
        // TODO: don't require # of dimensions of tensors set ?
        if (!symbolic_sizes.rank()) {
          GRAPH_DEBUG("No rank on output ", getHeader(curr));
          return std::nullopt;
        }
      }
      auto partial_eval_graph = partial_evaluated_graphs[curr];
      joinPartialEvaluatedShapeGraphToLargeShapeGraph(
          curr, partial_eval_graph, stitched_shape_compute_graph);
    }

    size_t MAX_ITER = 8;
    bool made_change = true;
    size_t i = 0;
    while (i < MAX_ITER && made_change) {
      i++;
      made_change = shapeGraphCleanupPasses(stitched_shape_compute_graph);
    }

    // for any output that is duplicated, the symbolic shape must be equal.
    // we take the symbolic shape that is generated first and map the
    // equivalent ones onto it.
    std::unordered_map<int64_t, int64_t> discovered_sym_shape_equalities;
    std::unordered_map<Value*, int64_t> graph_output_to_symbolic_shape_dim;
    std::vector<size_t> erase_indices;

    for (size_t i = 0; i < stitched_shape_compute_graph->outputs().size();
         ++i) {
      Value* output = stitched_shape_compute_graph->outputs().at(i);
      // this Value is already contained, so the symbolic shape for i must be
      // equal to the symbolic shape at the existing index
      if (graph_output_to_symbolic_shape_dim.count(output)) {
        auto curr_sym_shape = output_index_to_symbolic_shape_[i];
        auto existing_sym_shape = graph_output_to_symbolic_shape_dim[output];
        discovered_sym_shape_equalities[curr_sym_shape] = existing_sym_shape;
        erase_indices.push_back(i);
      } else {
        graph_output_to_symbolic_shape_dim[output] =
            output_index_to_symbolic_shape_[i];
      }
    }
    for (auto i = static_cast<int64_t>(erase_indices.size()) - 1; i >= 0; i--) {
      stitched_shape_compute_graph->eraseOutput(erase_indices[i]);
    }
    for (size_t i = 0; i < stitched_shape_compute_graph->inputs().size();) {
      if (!stitched_shape_compute_graph->inputs().at(i)->hasUses()) {
        enclosing_graph_value_to_shape_graph_input_.erase(
            stitched_shape_compute_graph->inputs().at(i));
        stitched_shape_compute_graph->eraseInput(i);
      } else {
        ++i;
      }
    }

    updateGraphWithSymbolicShapeEqualities(discovered_sym_shape_equalities);
    return ShapeComputeGraphMapping(
        std::move(stitched_shape_compute_graph),
        enclosing_graph_value_to_shape_graph_input_,
        std::move(graph_output_to_symbolic_shape_dim));
  }

  void updateGraphWithSymbolicShapeEqualities(
      std::unordered_map<int64_t, int64_t>& sym_shape_equalities) {
    for (auto it = beg_->iterator(); it != end_->iterator(); it++) {
      auto curr = *it;
      for (size_t i = 0; i < curr->outputs().size(); ++i) {
        auto output = curr->output(i);
        auto tt = output->type()->cast<TensorType>();
        if (!tt || !tt->symbolic_sizes().rank()) {
          continue;
        }
        bool changed = false;
        std::vector<at::ShapeSymbol> shape_vec = *tt->symbolic_sizes().sizes();
        auto new_sizes =
            c10::fmap(shape_vec, [&](const at::ShapeSymbol& shape) {
              auto value = shape.value();
              if (sym_shape_equalities.count(value)) {
                changed = true;
                return sym_shape_equalities[value];
              }
              return value;
            });
        if (changed) {
          output->setType(
              tt->withSymbolicShapes(c10::SymbolicShape(new_sizes)));
        }
      }
    }
  }

  void registerStitchedComputeOutput(
      const std::shared_ptr<Graph>& stitched_shape_compute_graph,
      Value* output,
      int64_t symbolic_shape) {
    stitched_shape_compute_graph->registerOutput(output);
    output_index_to_symbolic_shape_
        [stitched_shape_compute_graph->outputs().size() - 1] = symbolic_shape;
    symbolic_shape_value_to_graph_output_[symbolic_shape] =
        stitched_shape_compute_graph->outputs().at(
            stitched_shape_compute_graph->outputs().size() - 1);
  }

  void joinPartialEvaluatedShapeGraphToLargeShapeGraph(
      Node* curr,
      const std::shared_ptr<Graph>& partial_eval_graph,
      const std::shared_ptr<Graph>& stitched_shape_compute_graph) {
    // we are building up the large shape compute graph by iteratively
    // combining partially evaluated individual node shape graphs.

    // We need to maintain two mappings: one from non-Tensor inputs in the
    // enclosing graph to their equivalent mappings within the large shape
    // compute graph, and one from symbolic shape dimension to new node output.

    // When we add a new tensor node, we do two things:
    // 1: record a mapping from the tensor node output to its shape in the
    //    partial eval graph
    // 2: add each symbolic shape dimension that we have not already added as
    //    an output to the large shape compute graph

    // Once we are done stitching together all partial eval'd graphs, we can
    // clean up the graph and remove the unneeded complete shapes as outputs,
    // leaving only the compute for calculating the runtime value of symbolic
    // dimensions.

    std::vector<Value*> node_inputs;
    // TODO: generalize logic
    if (curr->kind() == aten::cat) {
      TORCH_INTERNAL_ASSERT(
          curr->input(0)->node()->kind() == prim::ListConstruct);
      for (Value* v : curr->input(0)->node()->inputs()) {
        node_inputs.push_back(v);
      }
      node_inputs.push_back(curr->namedInput("dim"));
    } else {
      for (size_t i = 0; i < partial_eval_graph->inputs().size(); ++i) {
        node_inputs.push_back(curr->input(i));
      }
    }

    std::vector<Value*> partial_eval_inputs;
    for (size_t i = 0; i < node_inputs.size(); ++i) {
      auto node_input = node_inputs[i];
      auto existing_graph_mapping =
          enclosing_graph_value_to_shape_graph_input_.find(node_input);
      if (existing_graph_mapping !=
          enclosing_graph_value_to_shape_graph_input_.end()) {
        partial_eval_inputs.push_back(existing_graph_mapping->second);
      } else {
        Value* shape_graph_input =
            stitched_shape_compute_graph->addInput()->copyMetadata(
                partial_eval_graph->inputs().at(i));
        enclosing_graph_value_to_shape_graph_input_[node_input] =
            shape_graph_input;
        partial_eval_inputs.push_back(shape_graph_input);
      }
      // make sure all symbolic dimensions in the graph we are creating are
      // computed in the partial eval graph
      if (auto tt = node_input->type()->cast<TensorType>()) {
        if (!tt->symbolic_sizes().rank()) {
          continue;
        }
        auto rank = *tt->symbolic_sizes().rank();
        for (size_t j = 0; j < rank; ++j) {
          auto shape = tt->symbolic_sizes()[j];
          if (shape.is_static() ||
              symbolic_shape_value_to_graph_output_.count(shape.value())) {
            continue;
          }
          auto input = enclosing_graph_value_to_shape_graph_input_[node_input];
          WithInsertPoint guard(stitched_shape_compute_graph->block());
          auto index = stitched_shape_compute_graph->insertConstant(
              static_cast<int64_t>(j));
          auto li_index = stitched_shape_compute_graph->insert(
              aten::__getitem__, {input, index});
          registerStitchedComputeOutput(
              stitched_shape_compute_graph, li_index, shape.value());
        }
      }
    }

    WithInsertPoint guard(stitched_shape_compute_graph->block());
    std::unordered_map<Value*, Value*> value_map;
    insertGraph(
        *stitched_shape_compute_graph,
        *partial_eval_graph,
        partial_eval_inputs,
        value_map);

    for (size_t i = 0; i < curr->outputs().size(); ++i) {
      Value* new_list_output = value_map[partial_eval_graph->outputs().at(i)];
      enclosing_graph_value_to_shape_graph_input_[curr->output(i)] =
          new_list_output;

      TORCH_INTERNAL_ASSERT(
          new_list_output->node()->kind() == prim::ListConstruct ||
          new_list_output->node()->kind() == prim::Constant);
      TORCH_INTERNAL_ASSERT(!new_list_output->node()->hasUses());

      auto symbolic_sizes =
          curr->output(i)->type()->expect<TensorType>()->symbolic_sizes();
      TORCH_INTERNAL_ASSERT(symbolic_sizes.rank());

      for (size_t i = 0; i < *symbolic_sizes.rank(); i++) {
        if (symbolic_sizes[i].is_static()) {
          continue;
        }
        int64_t symbolic_shape = symbolic_sizes[i].value();
        if (symbolic_shape_value_to_graph_output_.count(symbolic_shape)) {
          continue;
        }
        registerStitchedComputeOutput(
            stitched_shape_compute_graph,
            new_list_output->node()->input(i),
            symbolic_shape);
      }
    }
  }

  std::unordered_map<Node*, std::shared_ptr<Graph>>
  propagateShapesAndGatherPartialEvalShapeGraphs(AliasDb& db) {
    std::unordered_map<Node*, std::shared_ptr<Graph>> partial_evaluated_graphs;
    for (auto it = beg_->iterator(); it != end_->iterator(); it++) {
      auto curr = *it;
      if (auto maybe_graph = PropagateShapesWithShapeFunction(curr, db)) {
        partial_evaluated_graphs[curr] = maybe_graph;
      }
    }
    return partial_evaluated_graphs;
  }

  std::unordered_map<Value*, Value*>
      enclosing_graph_value_to_shape_graph_input_;
  std::unordered_map<int64_t, Value*> symbolic_shape_value_to_graph_output_;
  std::unordered_map<size_t, int64_t> output_index_to_symbolic_shape_;

  std::shared_ptr<Graph>& graph_;
  Node* beg_;
  Node* end_;
};

void PropagateShapesOnBlock(Block* b, const AliasDb& db) {
  for (Node* n : b->nodes()) {
    // TODO: handle loop
    if (n->kind() == prim::If) {
      IfView if_v(n);
      PropagateShapesOnBlock(if_v.thenBlock(), db);
      PropagateShapesOnBlock(if_v.elseBlock(), db);
      mergeTypes(if_v.thenOutputs(), if_v.elseOutputs(), if_v.outputs());
    } else if (n->maybeSchema()) {
      PropagateShapesWithShapeFunction(n, db);
    } else if (n->kind() == prim::TupleConstruct) {
      auto orig_type = n->output()->type()->expect<TupleType>();
      auto new_types = fmap(n->inputs(), [](Value* v) { return v->type(); });
      n->output()->setType(
          orig_type->createWithContained(std::move(new_types)));
    }
  }
}
} // namespace

void PropagateShapesOnGraph(std::shared_ptr<Graph>& graph) {
  AliasDb db(graph);
  PropagateShapesOnBlock(graph->block(), db);
}
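
// Example usage (an illustrative sketch; the graph construction is
// hypothetical, and the symbolic annotation could equally come from
// profiling):
//   std::shared_ptr<Graph> graph = ...; // e.g. from a scripted function
//   auto sym = c10::ShapeSymbol::newSymbol();
//   graph->inputs().at(0)->setType(TensorType::get()->withSymbolicShapes(
//       c10::SymbolicShape(std::vector<c10::ShapeSymbol>{
//           sym, c10::ShapeSymbol::fromStaticSize(10)})));
//   PropagateShapesOnGraph(graph);
//   // Outputs of ops with registered shape functions now carry propagated
//   // symbolic sizes, e.g. an aten::relu of that input is typed (sym, 10).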

std::optional<ShapeComputeGraphMapping>
PropagateShapesAndBuildLargeShapeComputeGraph(
    std::shared_ptr<Graph>& graph,
    Node* beg,
    Node* end) {
  return SymbolicShapeGraphAnalyzer(graph, beg, end).run();
}

TORCH_API std::optional<std::vector<c10::SymbolicShape>>
calculateSymbolicShapesOnOp(
    const FunctionSchema* schema,
    const std::vector<SSAInput>& inputs) {
  auto bounded_graphs = boundedGraphsForSchema(*schema);
  auto has_shape_compute = shapeComputeGraphForSchema(*schema) != std::nullopt;
  if (!has_shape_compute && bounded_graphs == std::nullopt) {
    // Avoid doing all this work for functions that don't have a
    // supported schema
    return std::nullopt;
  }

  if (auto cached_ret_vec = get_cached_shape_function(schema, inputs)) {
    return cached_ret_vec;
  }

  std::vector<SSArgument> ssa_args;
  for (auto& arg : inputs) {
    if (const IValue* ival = std::get_if<IValue>(&arg)) {
      ssa_args.emplace_back(*ival);
    } else {
      const c10::SymbolicShape* ss = std::get_if<c10::SymbolicShape>(&arg);
      ssa_args.emplace_back(ShapeArguments(*ss));
    }
  }
  // Handle bounded shape option
  if (bounded_graphs) {
    auto lower_bound =
        SymbolicShapeOpAnalyzer(schema, bounded_graphs->lower_bound);
    auto lower_bound_res = lower_bound.run(ssa_args);
    auto upper_bound =
        SymbolicShapeOpAnalyzer(schema, bounded_graphs->upper_bound);
    auto upper_bound_res = upper_bound.run(ssa_args);
    // Stitch together the values
    if (lower_bound_res.has_value() && upper_bound_res.has_value()) {
      TORCH_INTERNAL_ASSERT(lower_bound_res->size() == upper_bound_res->size());
      auto merged_res = std::vector<c10::SymbolicShape>();
      for (size_t i = 0; i < lower_bound_res->size(); i++) {
        merged_res.push_back(
            combine_bounds(lower_bound_res->at(i), upper_bound_res->at(i)));
      }
      cache_shape_function(schema, inputs, merged_res);
      return merged_res;
    }
    return std::nullopt;
  }

  auto op_analyzer = SymbolicShapeOpAnalyzer(schema);
  auto res = op_analyzer.run(ssa_args);
  if (res.has_value()) {
    cache_shape_function(schema, inputs, res.value());
  }
  return res;
}
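
// Example usage (an illustrative sketch; assumes a registered shape function
// exists for the chosen schema):
//   const FunctionSchema& schema = getOperatorForLiteral(
//       "aten::mul.Tensor(Tensor self, Tensor other) -> Tensor")->schema();
//   std::vector<SSAInput> inputs{
//       c10::SymbolicShape(std::vector<std::optional<int64_t>>{std::nullopt, 8}),
//       c10::SymbolicShape(std::vector<std::optional<int64_t>>{1, 8})};
//   auto out = calculateSymbolicShapesOnOp(&schema, inputs);
//   // On success, out holds one SymbolicShape per op output; results are
//   // cached keyed on the schema and the canonicalized input shapes.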

} // namespace torch::jit