1 #include <ATen/core/symbol.h>
2 #include <c10/util/Exception.h>
3 #include <c10/util/irange.h>
4 #include <torch/csrc/jit/ir/alias_analysis.h>
5 #include <torch/csrc/jit/ir/constants.h>
6 #include <torch/csrc/jit/ir/ir.h>
7 #include <torch/csrc/jit/ir/ir_views.h>
8 #include <torch/csrc/jit/jit_log.h>
9 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
10 #include <torch/csrc/jit/passes/constant_pooling.h>
11 #include <torch/csrc/jit/passes/constant_propagation.h>
12 #include <torch/csrc/jit/passes/dead_code_elimination.h>
13 #include <torch/csrc/jit/passes/integer_value_refinement.h>
14 #include <torch/csrc/jit/passes/loop_unrolling.h>
15 #include <torch/csrc/jit/passes/lower_tuples.h>
16 #include <torch/csrc/jit/passes/peephole.h>
17 #include <torch/csrc/jit/passes/peephole_list_idioms.h>
18 #include <torch/csrc/jit/passes/peephole_non_tensor.h>
19 #include <torch/csrc/jit/passes/remove_mutation.h>
20 #include <torch/csrc/jit/passes/shape_analysis.h>
21 #include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
22 #include <torch/csrc/jit/passes/symbolic_shape_cache.h>
23 #include <torch/csrc/jit/passes/tensorexpr_fuser.h>
24 #include <torch/csrc/jit/runtime/exception_message.h>
25 #include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
26 #include <algorithm>
27 #include <memory>
28 #include <unordered_map>
29 #include <utility>
30 #include <vector>
31 
32 /*
33 XXX: this is still in prototype phase and has much work left to do, including
34 but not limited to:
35 - Refactor APIs
36 - Add decent coverage of common ops
37 - Add shape analysis pass on Graph that handles Loops
38 - Allow concurrent reads to the operator map
39 - Support returning partially evaluated shape compute graphs
40 */
41 
42 static bool symbolic_shape_analysis_test_mode = false;
43 
44 namespace torch::jit {
45 
46 // This is similar to c10::SymbolicShape, but instead of either having
47 // a concrete dimension or a symbolic dimension, an argument may be:
48 // - A Symbolic Dimension
49 // - A Constant Integer
50 // - Neither of the above. The third case can occur due to inputs to
51 // ops like view that accept negative values. Maintaining the distinction
52 // between an unknown symbolic dimension and an unknown integer allows
53 // us to optimize out comparisons to values < 0 (symbolic shapes are always >=
54 // 0). For example, a call like graph(%y: Tensor(SS(-1), 10, 10), %inp: int):
55 //   %five: int = prim::Constant[value=5]()
56 //   %zero: int = prim::Constant[value=0]()
57 //   %1 : int = aten::size(%y, %zero)
58 //   %2 : int[] = prim::ListConstruct(%five, %1, %inp)
59 //   %y.2: Tensor(5, SS(-1), (New Symbolic Shape)) = aten::view(%y, %2)
60 //
61 // y.view([5, y.size(0), inp])
62 // will have inputs equal to [5, SS(-1), std::nullopt]
63 
64 struct ShapeArg
65     : public std::
66           pair<std::optional<c10::ShapeSymbol>, std::optional<int64_t>> {
67   using pair::pair;
68 
69   static ShapeArg unknownInteger() {
70     return ShapeArg();
71   }
72 
73   ShapeArg(int64_t int_value) {
74     this->first = std::nullopt;
75     this->second = int_value;
76   }
77 
78   ShapeArg(c10::ShapeSymbol ss) {
79     if (ss.is_static()) {
80       this->first = std::nullopt;
81       this->second = ss.value();
82     } else {
83       this->first = ss;
84       this->second = std::nullopt;
85     }
86   }
87 
88   std::optional<int64_t> asConstantInt() const {
89     return this->second;
90   }
91 
92   std::optional<c10::ShapeSymbol> asShapeSymbol() const {
93     return this->first;
94   }
95 
96  private:
97   ShapeArg() {
98     this->first = std::nullopt;
99     this->second = std::nullopt;
100   }
101 };
102 
103 static std::ostream& operator<<(std::ostream& out, const ShapeArg& sa) {
104   if (auto val = sa.asConstantInt()) {
105     out << *val;
106   } else if (auto ss = sa.asShapeSymbol()) {
107     out << *ss;
108   } else {
109     out << "UNK";
110   }
111   return out;
112 }
113 
114 struct ShapeArguments {
115   // Superset of SymbolicShape, with additional support for unknown, nonsymbolic
116   // vals
117  public:
118   ShapeArguments(const c10::SymbolicShape& ss) {
119     has_dim_ = ss.rank().has_value();
120     if (has_dim_) {
121       for (size_t i = 0; i < *ss.rank(); ++i) {
122         maybe_shape_symbols_.emplace_back(ss.at(i));
123       }
124     }
125   }
126 
127   ShapeArguments(std::vector<ShapeArg> ss)
128       : has_dim_(true), maybe_shape_symbols_(std::move(ss)) {}
129 
130   bool has_dim() const {
131     return has_dim_;
132   }
133 
134   int64_t len() const {
135     TORCH_INTERNAL_ASSERT(has_dim_, "ShapeArguments has no known dim")
136     return (int64_t)maybe_shape_symbols_.size();
137   }
138 
139   const ShapeArg at(size_t i) const {
140     TORCH_INTERNAL_ASSERT(has_dim_, "ShapeArguments has no known dim")
141     return maybe_shape_symbols_.at(i);
142   }
143 
144  private:
145   bool has_dim_;
146   std::vector<ShapeArg> maybe_shape_symbols_;
147 };
148 
149 static std::ostream& operator<<(std::ostream& os, const ShapeArguments& sa) {
150   if (!sa.has_dim()) {
151     os << "(UNKNOWN DIM)";
152     return os;
153   }
154 
155   os << "(";
156   for (const auto i : c10::irange(sa.len())) {
157     os << sa.at(i);
158   }
159   os << ")";
160 
161   return os;
162 }
163 
164 bool setSymbolicShapeAnalysisTestMode(bool value) {
165   bool old_value = symbolic_shape_analysis_test_mode;
166   symbolic_shape_analysis_test_mode = value;
167   return old_value;
168 }
169 
170 bool symbolicShapeAnalysisTestModeEnabled() {
171   return symbolic_shape_analysis_test_mode;
172 }
173 
174 using SSArgument = std::variant<ShapeArguments, IValue>;
175 
176 static std::ostream& operator<<(std::ostream& out, const SSArgument& sa) {
177   if (const IValue* iv = std::get_if<IValue>(&sa)) {
178     out << *iv;
179   } else {
180     out << std::get<ShapeArguments>(sa);
181   }
182   return out;
183 }
184 
185 namespace {
186 
187 bool isListOfInts(const TypePtr& type) {
188   return type->cast<ListType>() &&
189       type->cast<ListType>()->getElementType()->cast<IntType>();
190 }
191 
192 bool isListOfListOfInts(const TypePtr& type) {
193   // Allows List[Optional[List[Int]]]
194   if (!type->cast<ListType>()) {
195     return false;
196   }
197   TypePtr element_type = type->cast<ListType>()->getElementType();
198   if (element_type->cast<OptionalType>()) {
199     element_type = element_type->cast<OptionalType>()->getElementType();
200   }
201   return isListOfInts(element_type);
202 }
203 
204 bool isListOfTensors(const TypePtr& type) {
205   return type->cast<ListType>() &&
206       type->cast<ListType>()->getElementType()->cast<TensorType>();
207 }
208 
209 std::optional<size_t> normIndex(int64_t index, size_t len) {
210   if (index < 0) {
211     index = index + static_cast<int64_t>(len);
212   }
213   if (index >= 0 && index < static_cast<int64_t>(len)) {
214     return index;
215   } else {
216     return std::nullopt;
217   }
218 }
219 
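// Cleanup/simplification passes run over a shape compute graph after input
// properties have been substituted in. Returns whether any pass reported a
// change so that callers can iterate to a fixed point.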
220 bool shapeGraphCleanupPasses(std::shared_ptr<Graph> graph) {
221   // TODO: lower simple tuples ?
222   bool made_change = RemoveListMutation(graph);
223   made_change |= UnrollConstantLoops(graph);
224   made_change |= ConstantPropagation(graph);
225   made_change |= PeepholeOptimizeNonTensor(graph);
226   made_change |= PeepholeOptimizeListIdioms(graph, /*refine_list_len*/ true);
227   made_change |= RefineIntegerValues(graph);
228   made_change |= ConstantPropagation(graph);
229   // todo add return change for constant pooling
230   ConstantPooling(graph);
231   made_change |= EliminateCommonSubexpression(graph);
232   EliminateDeadCode(graph);
233   return made_change;
234 }
235 
236 void replaceWithIValue(Value* v, const IValue& val) {
237   WithInsertPoint guard(*v->node()->owningBlock()->nodes().begin());
238   v->replaceAllUsesWith(v->owningGraph()->insertConstant(val));
239 }
240 
241 c10::SymbolicShape extractListShape(
242     Value* list,
243     std::unordered_map<Value*, int64_t>& symbolic_shape_values,
244     const AliasDb& db) {
245   if (list->node()->kind() == prim::Constant) {
246     auto int_list = toIValue(list)->toIntVector();
247     return c10::SymbolicShape(int_list);
248   }
249   // We need a list construct or a constant output
250   // that is not written to in order to analyze the output shape
251   if (list->node()->kind() != prim::ListConstruct || db.hasWriters(list)) {
252     GRAPH_DEBUG("Could not extract shape");
253     return c10::SymbolicShape();
254   }
255   Node* list_construct = list->node();
256   std::vector<std::optional<int64_t>> output_shape;
257   for (Value* input : list_construct->inputs()) {
258     if (symbolic_shape_values.count(input)) {
259       output_shape.emplace_back(symbolic_shape_values[input]);
260     } else {
261       output_shape.push_back(constant_as<int64_t>(input));
262     }
263   }
264   return c10::SymbolicShape(output_shape);
265 }
266 
267 // Symbolic Shape Analysis works by iteratively partially evaluating
268 // a TorchScript shape compute graph, substituting in properties from input
269 // Tensors. We can substitute in properties like `len(x)` and `x[1]`
270 // if they are statically known on the input Tensors. We can also use
271 // assertions like `assert len(x) == 4` in order to refine the input
272 // length and unroll loops over its elements. We iteratively optimize and
273 // substitute in properties until we are unable to make any further
274 // optimizations. Finally, we try to extract Tensor properties from the output.
275 // For instance, for `return [1, 2, inp[2] + 1, inp[3]]` we know that the
276 // output will be length 4 with the first two dimensions equal to 1 and 2. We
277 // can also deduce that the 4th dimension has the same symbolic shape as
278 // inp[3], which means that we do not know its concrete value statically, but
279 // we can assign sets of tensor dimensions which must be equal at runtime.
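//
// As an illustrative sketch only (real shape compute graphs are looked up via
// shapeComputeGraphForSchema; the function below is hypothetical): given
//   def unary(self: List[int]) -> List[int]:
//       out: List[int] = []
//       for elem in self:
//           out.append(elem)
//       return out
// and an input of type Tensor(SS(-2), 5), we can substitute len(self) == 2,
// unroll the loop, and resolve self[1] to the constant 5, while recording that
// self[0] carries the symbolic dimension SS(-2). Extracting the output list
// then yields the symbolic shape (SS(-2), 5).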
280 
281 struct SymbolicShapeOpAnalyzer {
282   std::shared_ptr<Graph> shape_compute_graph_;
283   const FunctionSchema* schema_;
284   std::vector<SSArgument> inputs_;
285 
286   // For the case where we have a JIT graph,
287   // substitute optional types for their component types
288   // if the type is known. This doesn't need to be done
289   // for known IValues.
290   void refineInputUnionTypes(const Node* parent_graph_node) {
291     for (size_t op_in_index = 0;
292          op_in_index < shape_compute_graph_->inputs().size();
293          op_in_index++) {
294       auto type = parent_graph_node->input(op_in_index)->type();
295       if (auto opt_type = shape_compute_graph_->inputs()
296                               .at(op_in_index)
297                               ->type()
298                               ->cast<OptionalType>()) {
299         // None will get handled with constant substitution later
300         if (!type->cast<OptionalType>() &&
301             !NoneType::get()->isSubtypeOf(*type)) {
302           shape_compute_graph_->inputs()
303               .at(op_in_index)
304               ->setType(opt_type->getElementType());
305         }
306       } else if (shape_compute_graph_->inputs()
307                      .at(op_in_index)
308                      ->type()
309                      ->cast<NumberType>()) {
310         shape_compute_graph_->inputs().at(op_in_index)->setType(type);
311       }
312     }
313   }
314 
315   // We handle non-constant values in the shape propagation step
316   void substituteConstantInputs() {
317     if (shape_compute_graph_->inputs().empty()) {
318       return;
319     }
320 
321     bool seen_tensor_list = false;
322 
323     size_t op_in_index = 0;
324     while (op_in_index < shape_compute_graph_->inputs().size()) {
325       Value* graph_in_var = shape_compute_graph_->inputs().at(op_in_index);
326       if (!isListOfListOfInts(graph_in_var->type())) {
327         op_in_index++;
328         continue;
329       }
330 
331       // Modify the graph that _node is part of so that it does not use the
332       // tensor list construct.
333 
334       // When we partially evaluate an op taking a list of Tensors, like
335       // cat(Tensor[]), we have a few problems:
336       // - optimizing out calls to the length of the list: len(tensors)
337       // - resolving accesses of the list to the symbolic sizes of the
338       //   corresponding list element
339       // We can solve both by replacing the partial evaluation of cat([x, y])
340       // def cat(tensors: List[List[int]], dim: int)
341       //    body
342       // with
343       // def cat(x, y, dim: int)
344       //     tensors = [x, y]
345       //     body
346       TORCH_INTERNAL_ASSERT(
347           !seen_tensor_list,
348           "SSA doesn't handle case with multiple tensor lists")
349       seen_tensor_list = true;
350 
351       uint64_t li_length = inputs_.size() - (schema_->arguments().size() - 1);
352       std::vector<Value*> li_inputs;
353 
354       TypePtr element_type =
355           graph_in_var->type()->cast<ListType>()->getElementType();
356       for (size_t j = op_in_index; j < op_in_index + li_length; ++j) {
357         auto new_inp = shape_compute_graph_->insertInput(op_in_index + j);
358         new_inp->setType(element_type);
359         li_inputs.push_back(new_inp);
360       }
361       WithInsertPoint guard(*shape_compute_graph_->block()->nodes().begin());
362       auto new_li = shape_compute_graph_->insertNode(
363           shape_compute_graph_->createList(element_type, li_inputs));
364       graph_in_var->replaceAllUsesWith(new_li->output());
365       shape_compute_graph_->eraseInput(op_in_index + li_length);
366     }
367 
368     TORCH_INTERNAL_ASSERT(
369         shape_compute_graph_->inputs().size() <= inputs_.size(),
370         "Shape Compute Graph expected to have no more inputs than actual inputs"); //?
371 
372     for (size_t op_in_index = 0;
373          op_in_index < shape_compute_graph_->inputs().size();
374          op_in_index++) {
375       SSArgument& argument = inputs_[op_in_index];
376       Value* graph_in_var = shape_compute_graph_->inputs().at(op_in_index);
377 
378       if (IValue* cur_val = std::get_if<IValue>(&argument)) {
379         GRAPH_DEBUG("Substituting constant input ", *cur_val);
380         replaceWithIValue(graph_in_var, *cur_val);
381       } else {
382         auto cur_arg = std::get<ShapeArguments>(argument);
383         if (cur_arg.has_dim()) {
384           graph_in_var->setType(ListType::ofInts());
385         }
386       }
387     }
388   }
389 
390   void substituteSymbolicProperties(
391       std::unordered_map<Value*, int64_t>* symbolic_shape_values) {
392     // clang-format off
393     // here we iteratively substitute properties of the node's input tensors
394     // into the shape compute graph. we can substitute in constants
395     // like len(inp) or inp[0] if the tensor has a fixed length or a fixed
396     // first dimension. we also try to resolve symbolic shapes of the same
397     // symbolic value to the same Value * in the shape compute graph.
398     // for the shape logic:
399     // dim1 = inp1[0]
400     // dim2 = inp2[0]
401     // return dim1 if dim2 == 1 else dim2
402     // if we see that inp1[0] and inp2[0] both have the same symbolic shape
403     // value, then it is a valid transformation to replace dim2 with dim1 or
404     // vice versa. to do this we collect all Value * for a particular symbolic
405     // shape. Then, we replace all Value * within that set with their dominator.
406     // In the example above, this allows us to infer  that the output will be the
407     // symbolic dimension value of dim1.
408 
409     // if `symbolic_shape_values` is not null, record list accesses
410     // which resolve to symbolic dimension values with their concrete symbolic
411     // shape value. Because symbolic dimensions are represented as negative numbers and
412     // are not real values, inserting them as constants in the graph would invalidate
413     // the graph for further use. Instead, we keep track of what their value would be
414     // for extracting output shapes.
415     // clang-format on
416 
417     std::unordered_map<int64_t, std::vector<Value*>> symbolic_shape_map;
418 
419     TORCH_INTERNAL_ASSERT(
420         inputs_.size() >= shape_compute_graph_->inputs().size(),
421         "Missing Arg for Shape Graph");
422     for (const auto index :
423          c10::irange(shape_compute_graph_->inputs().size())) {
424       auto shape_arguments = std::get_if<ShapeArguments>(&inputs_[index]);
425       if (!shape_arguments || !shape_arguments->has_dim()) {
426         continue;
427       }
428       // Add support for testing symbolic shapes with dynamic dims
429 
430       for (const Use& use : shape_compute_graph_->inputs().at(index)->uses()) {
431         // TODO: either decompose composite ops like slice or add handling here
432         switch (use.user->kind()) {
433           case aten::len: {
434             size_t len = shape_arguments->len();
435             replaceWithIValue(use.user->output(), static_cast<int64_t>(len));
436           } break;
437           case aten::__getitem__: {
438             auto index = constant_as<int64_t>(use.user->inputs().at(1));
439             if (!index) {
440               continue;
441             }
442             auto norm_index = normIndex(*index, shape_arguments->len());
443             if (!norm_index) {
444               continue;
445             }
446             auto shape_arg = shape_arguments->at(*norm_index);
447             if (auto const_int = shape_arg.asConstantInt()) {
448               replaceWithIValue(use.user->output(), const_int);
449               continue;
450             }
451             auto maybe_shape_symbol = shape_arg.asShapeSymbol();
452             if (!maybe_shape_symbol) {
453               continue;
454             }
455             auto shape_symbol = *maybe_shape_symbol;
456             if (symbolic_shape_values) {
457               symbolic_shape_values->emplace(
458                   use.user->output(), shape_symbol.value());
459             } else {
460               int64_t symbolic_index = shape_symbol.value();
461               symbolic_shape_map[symbolic_index].push_back(use.user->output());
462             }
463             for (const auto& sym_uses : use.user->output()->uses()) {
464               auto k = sym_uses.user->kind();
465               if (k != aten::ge && k != aten::le && k != aten::ne &&
466                   k != aten::eq && k != aten::lt && k != aten::gt) {
467                 break;
468               }
469               auto other_index = 1 - sym_uses.offset;
470               auto other_value =
471                   constant_as<int64_t>(sym_uses.user->input(other_index));
472               if (!other_value) {
473                 continue;
474               }
475 
476               // check for dim >= 0, 0 <= dim
477               // dim >= 0
478               if (k == aten::ge && *other_value == 0 && other_index == 1) {
479                 replaceWithIValue(sym_uses.user->output(), true);
480                 continue;
481               }
482               // 0 <= dim
483               if (k == aten::le && *other_value == 0 && other_index == 0) {
484                 replaceWithIValue(sym_uses.user->output(), true);
485                 continue;
486               }
487 
488               // check for dim comparisons to negative number
489               if (*other_value >= 0) {
490                 continue;
491               }
492               if (k == aten::eq || k == aten::ne) {
493                 // True if:
494                 // -2 != {Positive}
495                 replaceWithIValue(sym_uses.user->output(), k == aten::ne);
496               } else {
497                 // True if:
498                 // -2 <= / < {Positive}
499                 // {Positive} >= / > {-2}
500                 bool true_val =
501                     ((other_index == 0 && (k == aten::le || k == aten::lt)) ||
502                      (other_index == 1 && (k == aten::ge || k == aten::gt)));
503                 replaceWithIValue(sym_uses.user->output(), true_val);
504               }
505             }
506           }
507         }
508       }
509 
510       for (const auto& symbolic_set : symbolic_shape_map) {
511         mergeSymbolicShapeSets(symbolic_set.second);
512       }
513     }
514   }
515 
516   void mergeSymbolicShapeSets(const std::vector<Value*>& symbolic_set) {
517     // `symbolic_set` represents a set of Value * which are all equal
518     // to each other. Here, we optimize the graph by replacing values
519     // in the set with other dominating values.
520     // in the following example, where a, b and c are all in the same
521     // symbolic set:
522     // if cond:
523     //    a = li[0]
524     //    b = li[1]
525     //    return [a, b]
526     // else:
527     //    c = li[0]
528     //    return [c, c]
529     // we can replace `b` with `a` because it is dominated by `a`,
530     // but we cannot replace `c` with another dominating value
531 
532     // there are ways to compute this more efficiently but typically the
533     // number of Values for each symbolic set is low and this is cheap to run
534     for (const auto i : c10::irange(symbolic_set.size())) {
535       Value* v = symbolic_set[i];
536       Value* dominating_value = v;
537       for (const auto& sym_set : symbolic_set) {
538         if (dominating_value->node()->isDominatedBy(sym_set->node())) {
539           dominating_value = sym_set;
540         }
541       }
542       if (dominating_value != v) {
543         v->replaceAllUsesWith(dominating_value);
544       }
545     }
546   }
547 
548   std::vector<c10::SymbolicShape> propagateShapesInGraph() {
549     bool made_change = true;
550     constexpr size_t MAX_ATTEMPTS = 8;
551     for (unsigned attempt_num = 0; made_change && attempt_num < MAX_ATTEMPTS;
552          attempt_num++) {
553       // symbolic shape concrete values are only used in final shape extraction
554       GRAPH_DUMP("Before substitution: ", shape_compute_graph_);
555       substituteSymbolicProperties(/*symbolic_shape_values*/ nullptr);
556       GRAPH_DUMP("Before Opt: ", shape_compute_graph_);
557       made_change = shapeGraphCleanupPasses(shape_compute_graph_);
558     }
559     std::unordered_map<Value*, int64_t> symbolic_shape_values;
560     substituteSymbolicProperties(&symbolic_shape_values);
561     GRAPH_DUMP("Done with partial evaluation", shape_compute_graph_);
562 
563     return extractOutputShape(symbolic_shape_values);
564   }
565 
566   std::vector<c10::SymbolicShape> extractOutputShape(
567       std::unordered_map<Value*, int64_t>& symbolic_shape_values) {
568     TORCH_INTERNAL_ASSERT(
569         shape_compute_graph_->outputs().size() == schema_->returns().size());
570     // TODO: would be nice if there were an easy facility to look at uses and see
571     // if they are all pure instead of instantiating db.
572     auto res = std::vector<c10::SymbolicShape>();
573     AliasDb db(shape_compute_graph_);
574     for (size_t i = 0; i < shape_compute_graph_->outputs().size(); ++i) {
575       auto output = shape_compute_graph_->outputs().at(i);
576       auto type = output->type();
577       TORCH_INTERNAL_ASSERT(isListOfInts(type));
578       c10::SymbolicShape ss =
579           extractListShape(output, symbolic_shape_values, db);
580       GRAPH_DEBUG("Extracted Output: ", ss);
581       res.push_back(ss);
582     }
583     return res;
584   }
585 
586  public:
587   SymbolicShapeOpAnalyzer(const FunctionSchema* schema) : schema_(schema) {
588     shape_compute_graph_ = nullptr;
589     if (!schema_) {
590       return;
591     }
592     auto maybe_graph = shapeComputeGraphForSchema(*schema_);
593     if (!maybe_graph) {
594       return;
595     }
596     shape_compute_graph_ = (*maybe_graph)->copy();
597   }
598 
599   SymbolicShapeOpAnalyzer(
600       const FunctionSchema* schema,
601       const std::shared_ptr<Graph>& graph)
602       : schema_(schema) {
603     shape_compute_graph_ = graph->copy();
604   }
605 
606   std::optional<std::vector<c10::SymbolicShape>> run(
607       std::vector<SSArgument>& inputs) {
608     if (!shape_compute_graph_) {
609       return std::nullopt;
610     }
611     inputs_ = inputs;
612     substituteConstantInputs();
613     GRAPH_DEBUG(inputs_)
614     return propagateShapesInGraph();
615   }
616 
617   std::shared_ptr<Graph> getShapeComputeGraph() {
618     return shape_compute_graph_;
619   }
620 };
621 
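// Converts a Tensor-typed Value into an SSArgument: a concrete IValue size
// list when the sizes are fully known (and we are not in test mode), otherwise
// the tensor's (possibly partial) symbolic shape.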
622 SSArgument tensorShapeArg(Value* tensor_v) {
623   auto tt = tensor_v->type()->expect<TensorType>();
624   c10::SymbolicShape symbolic_shapes = tt->symbolic_sizes();
625 
626   // for testing, we don't insert complete tensor shapes and rely on our
627   // partial evaluation pipeline to propagate information.
628   // this is a good proxy for our ability to propagate non-complete shape
629   // information.
630   if (symbolic_shapes.isComplete() && !symbolic_shape_analysis_test_mode) {
631     return IValue(tt->sizes().concrete_sizes());
632   }
633   if (toIValue(tensor_v)) {
634     auto size = constant_as<at::Tensor>(tensor_v)->sizes();
635     if (!symbolic_shape_analysis_test_mode) {
636       return IValue(size);
637     } else {
638       return c10::SymbolicShape(size);
639     }
640   }
641   return symbolic_shapes;
642 }
643 
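// Collects one SSArgument per shape-graph input of node `n`: Tensor inputs
// become their (symbolic) sizes, Tensor lists that are constants or
// non-mutated ListConstructs are expanded element-wise, integer lists built
// from constants and aten::size calls become ShapeArguments, and anything we
// cannot analyze falls back to an unknown SymbolicShape.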
644 std::vector<SSArgument> getNodeInputShapes(Node* n, const AliasDb& db) {
645   // TODO: fix the List of integers implementation, and
646   // extract out the shape changes, otherwise this is complete
647   // NB: shape compute graphs may have fewer inputs than their node
648   // counterparts to allow e.g. sharing one single unary definition,
649   // so iterate on # of shape inputs
650   // We make lists of Tensor inputs variadic, which results in an
651   // offset between a node index and its corresponding graph index
652   std::vector<SSArgument> input_shapes = std::vector<SSArgument>();
653 
654   for (size_t node_index = 0; node_index < n->inputs().size(); ++node_index) {
655     auto type = n->input(node_index)->type();
656 
657     if (type->castRaw<TensorType>()) {
658       input_shapes.push_back(tensorShapeArg(n->input(node_index)));
659       continue;
660     }
661     if (isListOfTensors(type)) {
662       // waiting for more use cases to decide on best generalization
663       if (n->input(node_index)->node()->kind() == prim::Constant) {
664         auto ival = toIValue(n->input(node_index));
665         for (const auto& ten : ival->toTensorVector()) {
666           input_shapes.emplace_back(c10::List<int64_t>(ten.sizes()));
667         }
668       } else if (
669           n->input(node_index)->node()->kind() == prim::ListConstruct &&
670           !db.hasWriters(n->input(node_index))) {
671         auto li_construct_node = n->input(node_index)->node();
672         for (size_t j = 0; j < li_construct_node->inputs().size(); ++j) {
673           input_shapes.push_back(tensorShapeArg(li_construct_node->input(j)));
674         }
675       } else {
676         TORCH_INTERNAL_ASSERT(false, "Unhandled List, we shouldn't get here");
677       }
678       continue;
679     }
680     if (auto ival = toIValue(n->input(node_index))) {
681       input_shapes.emplace_back(*ival);
682       continue;
683     }
684     if (type->cast<ListType>() &&
685         type->cast<ListType>()->getElementType()->cast<IntType>()) {
686       auto input_src_node = n->input(node_index)->node();
687       if (input_src_node->kind() == prim::ListConstruct &&
688           !db.hasWriters(n->input(node_index))) {
689         // it is very common in graphs to see patterns like:
690         // z = x.view(y.size())
691         // or:
692         // z = x.view(1, 10, y.size(0), y.size(1))
693         // We want to propagate symbolic dimensions and concrete sizes
694         // from y to z. To do this we try to associate symbolic dimensions
695         // or concrete sizes with the integer list inputs that have a
696         // constructor taken from constants or y.size() or y.size(0)
697         auto list_construct = n->input(node_index)->node();
698         std::vector<ShapeArg> shape;
699         for (Value* v : list_construct->inputs()) {
700           if (auto constant = constant_as<int64_t>(v)) {
701             shape.emplace_back(*constant);
702           } else if (v->node()->kind() == aten::size) {
703             auto const_index = constant_as<int64_t>(v->node()->input(1));
704             auto tt = v->node()->input(0)->type()->expect<TensorType>();
705             auto ss = tt->symbolic_sizes();
706             if (!ss.rank() || !const_index) {
707               // if we are getting a size of a tensor, it is an unknown
708               // symbolic dimension instead of an unknown integer (must be
709               // >=0)
710               shape.emplace_back(at::ShapeSymbol::newSymbol());
711               continue;
712             }
713             auto norm_index = normIndex(*const_index, *ss.rank());
714             if (!norm_index) {
715               shape.emplace_back(at::ShapeSymbol::newSymbol());
716               continue;
717             }
718             shape.emplace_back(ss[*norm_index]);
719           } else {
720             shape.emplace_back(ShapeArg::unknownInteger());
721           }
722         }
723         input_shapes.emplace_back(ShapeArguments(shape));
724         continue;
725       }
726       if (input_src_node->kind() == aten::size &&
727           !db.hasWriters(n->input(node_index))) {
728         auto ten_inp = input_src_node->input();
729         auto ss = ten_inp->type()->expect<TensorType>()->symbolic_sizes();
730         input_shapes.emplace_back(ss);
731         continue;
732       }
733     }
734     GRAPH_DEBUG(
735         "Unhandled input: ",
736         n->kind().toDisplayString(),
737         " arg num: ",
738         node_index);
739     input_shapes.emplace_back(c10::SymbolicShape());
740   }
741   TORCH_INTERNAL_ASSERT(
742       input_shapes.size() >= n->inputs().size(),
743       "input_shapes size: ",
744       input_shapes.size(),
745       " n inputs size: ",
746       n->inputs().size());
747   return input_shapes;
748 }
749 
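// Writes the computed symbolic shapes back onto the Tensor outputs of `node`.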
750 void applyOutputShapeToGraph(
751     Node* node,
752     const std::vector<c10::SymbolicShape>& output_shapes) {
753   TORCH_INTERNAL_ASSERT(
754       node->outputs().size() == output_shapes.size(),
755       "Output shape size mismatch");
756   for (size_t i = 0; i < output_shapes.size(); ++i) {
757     auto& ss = output_shapes.at(i);
758     node->output(i)->setType(
759         node->output(i)->type()->expect<TensorType>()->withSymbolicShapes(ss));
760   }
761 }
762 
763 std::shared_ptr<Graph> PropagateShapesWithShapeFunction(
764     Node* n,
765     const AliasDb& db) {
766   const FunctionSchema* func_schema = n->maybeSchema();
767   if (!func_schema) {
768     return nullptr;
769   }
770   auto op_analyzer = SymbolicShapeOpAnalyzer(func_schema);
771   if (!op_analyzer.getShapeComputeGraph()) {
772     return nullptr;
773   }
774   auto input_shapes = getNodeInputShapes(n, db);
775   op_analyzer.refineInputUnionTypes(n);
776 
777   if (auto output_shapes = op_analyzer.run(input_shapes)) {
778     applyOutputShapeToGraph(n, *output_shapes);
779   }
780 
781   return op_analyzer.getShapeComputeGraph();
782 }
783 
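// Merges the results of a lower-bound and an upper-bound shape graph into a
// single SymbolicShape: dimensions on which both bounds agree are kept, and
// dimensions that differ become fresh symbols. For example, lower (SS(-2), 4)
// and upper (SS(-2), 8) merge to (SS(-2), SS(new)).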
784 c10::SymbolicShape combine_bounds(
785     c10::SymbolicShape& lower_bound,
786     c10::SymbolicShape& upper_bound) {
787   // TODO: At some point we might want to add support for dynamic dims
788   TORCH_INTERNAL_ASSERT(lower_bound.rank() == upper_bound.rank());
789   if (lower_bound.rank() == std::nullopt) {
790     return c10::SymbolicShape();
791   }
792   std::vector<c10::ShapeSymbol> merged_shapes;
793   for (const auto i : c10::irange(*lower_bound.rank())) {
794     // TODO: Merge equivalent expressions (not needed for current use case)
795     if (lower_bound[i] == upper_bound[i]) {
796       merged_shapes.push_back(lower_bound[i]);
797     } else {
798       merged_shapes.push_back(c10::ShapeSymbol::newSymbol());
799     }
800   }
801   return c10::SymbolicShape(std::move(merged_shapes));
802 }
803 
804 struct SymbolicShapeGraphAnalyzer {
805   SymbolicShapeGraphAnalyzer(
806       std::shared_ptr<Graph>& graph,
807       Node* beg,
808       Node* end)
809       : graph_(graph), beg_(beg), end_(end) {
810     TORCH_INTERNAL_ASSERT(
811         beg_->owningBlock() == end_->owningBlock() && end_->isAfter(beg_));
812   }
813 
814   std::optional<ShapeComputeGraphMapping> run() {
815     AliasDb db(graph_);
816     std::unordered_map<Node*, std::shared_ptr<Graph>> partial_evaluated_graphs =
817         propagateShapesAndGatherPartialEvalShapeGraphs(db);
818 
819     auto stitched_shape_compute_graph = std::make_shared<Graph>();
820     // We want to build up a computational graph which computes all shapes
821     // we don't know statically - that is, all symbolic shapes within
822     // the region [beg, end). It must be executable before beg.
823     // TODO: don't require dimensions of tensors to be set AOT ?
824 
825     for (auto it = beg_->iterator(); it != end_->iterator(); it++) {
826       auto curr = *it;
827       if (curr->kind() == prim::Constant) {
828         continue;
829       }
830       // TODO: generalize logic for other tensor input ops when they are
831       // added
832       if (curr->kind() == prim::ListConstruct) {
833         auto uses = curr->output()->uses();
834         if (!std::all_of(uses.begin(), uses.end(), [](const Use& use) {
835               return use.user->kind() == aten::cat;
836             })) {
837           GRAPH_DEBUG("Non cat list use ", getHeader(curr));
838           return std::nullopt;
839         }
840         continue;
841       }
842 
843       if (!partial_evaluated_graphs.count(curr)) {
844         GRAPH_DEBUG("No graph ", getHeader(curr));
845         return std::nullopt;
846       }
847 
848       auto outputs = curr->outputs();
849       for (Value* v : outputs) {
850         auto tt = v->type()->cast<TensorType>();
851         if (!tt) {
852           GRAPH_DEBUG("Non tensor node", getHeader(curr));
853           return std::nullopt;
854         }
855         auto symbolic_sizes = tt->symbolic_sizes();
856         // TODO: don't require # of dimensions of tensors to be set ?
857         if (!symbolic_sizes.rank()) {
858           GRAPH_DEBUG("No rank on output ", getHeader(curr));
859           return std::nullopt;
860         }
861       }
862       auto partial_eval_graph = partial_evaluated_graphs[curr];
863       joinPartialEvaluatedShapeGraphToLargeShapeGraph(
864           curr, partial_eval_graph, stitched_shape_compute_graph);
865     }
866 
867     size_t MAX_ITER = 8;
868     bool made_change = true;
869     size_t i = 0;
870     while (i < MAX_ITER && made_change) {
871       i++;
872       made_change = shapeGraphCleanupPasses(stitched_shape_compute_graph);
873     }
874 
875     // for any output that is duplicated, the symbolic shapes must be equal:
876     // take the symbolic shape generated first and record the rest as equal to it
877     std::unordered_map<int64_t, int64_t> discovered_sym_shape_equalities;
878     std::unordered_map<Value*, int64_t> graph_output_to_symbolic_shape_dim;
879     std::vector<size_t> erase_indices;
880 
881     for (size_t i = 0; i < stitched_shape_compute_graph->outputs().size();
882          ++i) {
883       Value* output = stitched_shape_compute_graph->outputs().at(i);
884       // this Value is already contained, so the symbolic shape for i must be
885       // equal to the symbolic shape at the existing index
886       if (graph_output_to_symbolic_shape_dim.count(output)) {
887         auto curr_sym_shape = output_index_to_symbolic_shape_[i];
888         auto existing_sym_shape = graph_output_to_symbolic_shape_dim[output];
889         discovered_sym_shape_equalities[curr_sym_shape] = existing_sym_shape;
890         erase_indices.push_back(i);
891       } else {
892         graph_output_to_symbolic_shape_dim[output] =
893             output_index_to_symbolic_shape_[i];
894       }
895     }
896     for (auto i = static_cast<int64_t>(erase_indices.size()) - 1; i >= 0; i--) {
897       stitched_shape_compute_graph->eraseOutput(erase_indices[i]);
898     }
899     for (size_t i = 0; i < stitched_shape_compute_graph->inputs().size();) {
900       if (!stitched_shape_compute_graph->inputs().at(i)->hasUses()) {
901         enclosing_graph_value_to_shape_graph_input_.erase(
902             stitched_shape_compute_graph->inputs().at(i));
903         stitched_shape_compute_graph->eraseInput(i);
904       } else {
905         ++i;
906       }
907     }
908 
909     updateGraphWithSymbolicShapeEqualities(discovered_sym_shape_equalities);
910     return ShapeComputeGraphMapping(
911         std::move(stitched_shape_compute_graph),
912         enclosing_graph_value_to_shape_graph_input_,
913         std::move(graph_output_to_symbolic_shape_dim));
914   }
915 
916   void updateGraphWithSymbolicShapeEqualities(
917       std::unordered_map<int64_t, int64_t>& sym_shape_equalities) {
918     for (auto it = beg_->iterator(); it != end_->iterator(); it++) {
919       auto curr = *it;
920       for (size_t i = 0; i < curr->outputs().size(); ++i) {
921         auto output = curr->output(i);
922         auto tt = output->type()->cast<TensorType>();
923         if (!tt || !tt->symbolic_sizes().rank()) {
924           continue;
925         }
926         bool changed = false;
927         std::vector<at::ShapeSymbol> shape_vec = *tt->symbolic_sizes().sizes();
928         auto new_sizes =
929             c10::fmap(shape_vec, [&](const at::ShapeSymbol& shape) {
930               auto value = shape.value();
931               if (sym_shape_equalities.count(value)) {
932                 changed = true;
933                 return sym_shape_equalities[value];
934               }
935               return value;
936             });
937         if (changed) {
938           output->setType(
939               tt->withSymbolicShapes(c10::SymbolicShape(new_sizes)));
940         }
941       }
942     }
943   }
944 
945   void registerStitchedComputeOutput(
946       const std::shared_ptr<Graph>& stitched_shape_compute_graph,
947       Value* output,
948       int64_t symbolic_shape) {
949     stitched_shape_compute_graph->registerOutput(output);
950     output_index_to_symbolic_shape_
951         [stitched_shape_compute_graph->outputs().size() - 1] = symbolic_shape;
952     symbolic_shape_value_to_graph_output_[symbolic_shape] =
953         stitched_shape_compute_graph->outputs().at(
954             stitched_shape_compute_graph->outputs().size() - 1);
955   }
956 
957   void joinPartialEvaluatedShapeGraphToLargeShapeGraph(
958       Node* curr,
959       const std::shared_ptr<Graph>& partial_eval_graph,
960       const std::shared_ptr<Graph>& stitched_shape_compute_graph) {
961     // we are building up the large shape compute graph by iteratively
962     // combining partially evaluated individual node shape graphs.
963 
964     // We need to maintain two mappings, one from non-Tensor inputs in the
965     // enclosing graph to their equivalent mappings within the large shape
966     // compute graph, and one from symbolic shape dimension to new node output
967 
968     // When we add a new tensor node, we do two things:
969     // 1: record a mapping from the tensor node output to its shape in the
970     //    partial eval graph
971     // 2: add each symbolic shape dimension that we have not already added
972     //    as an output to the large shape compute graph
973 
974     // Once we are done stitching together all partial eval'd graphs, we can
975     // clean up the graph and remove the unneeded complete shapes as outputs,
976     // leaving us only the compute for calculating the runtime value of
977     // symbolic dimensions
979 
980     std::vector<Value*> node_inputs;
981     // TODO: generalize logic
982     if (curr->kind() == aten::cat) {
983       TORCH_INTERNAL_ASSERT(
984           curr->input(0)->node()->kind() == prim::ListConstruct);
985       for (Value* v : curr->input(0)->node()->inputs()) {
986         node_inputs.push_back(v);
987       }
988       node_inputs.push_back(curr->namedInput("dim"));
989     } else {
990       for (size_t i = 0; i < partial_eval_graph->inputs().size(); ++i) {
991         node_inputs.push_back(curr->input(i));
992       }
993     }
994 
995     std::vector<Value*> partial_eval_inputs;
996     for (size_t i = 0; i < node_inputs.size(); ++i) {
997       auto node_input = node_inputs[i];
998       auto existing_graph_mapping =
999           enclosing_graph_value_to_shape_graph_input_.find(node_input);
1000       if (existing_graph_mapping !=
1001           enclosing_graph_value_to_shape_graph_input_.end()) {
1002         partial_eval_inputs.push_back(existing_graph_mapping->second);
1003       } else {
1004         Value* shape_graph_input =
1005             stitched_shape_compute_graph->addInput()->copyMetadata(
1006                 partial_eval_graph->inputs().at(i));
1007         enclosing_graph_value_to_shape_graph_input_[node_input] =
1008             shape_graph_input;
1009         partial_eval_inputs.push_back(shape_graph_input);
1010       }
1011       // make sure all symbolic dimensions in the graph we are creating are
1012       // computed in the partial eval graph
1013       if (auto tt = node_input->type()->cast<TensorType>()) {
1014         if (!tt->symbolic_sizes().rank()) {
1015           continue;
1016         }
1017         auto rank = *tt->symbolic_sizes().rank();
1018         for (size_t j = 0; j < rank; ++j) {
1019           auto shape = tt->symbolic_sizes()[j];
1020           if (shape.is_static() ||
1021               symbolic_shape_value_to_graph_output_.count(shape.value())) {
1022             continue;
1023           }
1024           auto input = enclosing_graph_value_to_shape_graph_input_[node_input];
1025           WithInsertPoint guard(stitched_shape_compute_graph->block());
1026           auto index = stitched_shape_compute_graph->insertConstant(
1027               static_cast<int64_t>(j));
1028           auto li_index = stitched_shape_compute_graph->insert(
1029               aten::__getitem__, {input, index});
1030           registerStitchedComputeOutput(
1031               stitched_shape_compute_graph, li_index, shape.value());
1032         }
1033       }
1034     }
1035 
1036     WithInsertPoint guard(stitched_shape_compute_graph->block());
1037     std::unordered_map<Value*, Value*> value_map;
1038     insertGraph(
1039         *stitched_shape_compute_graph,
1040         *partial_eval_graph,
1041         partial_eval_inputs,
1042         value_map);
1043 
1044     for (size_t i = 0; i < curr->outputs().size(); ++i) {
1045       Value* new_list_output = value_map[partial_eval_graph->outputs().at(i)];
1046       enclosing_graph_value_to_shape_graph_input_[curr->output(i)] =
1047           new_list_output;
1048 
1049       TORCH_INTERNAL_ASSERT(
1050           new_list_output->node()->kind() == prim::ListConstruct ||
1051           new_list_output->node()->kind() == prim::Constant);
1052       TORCH_INTERNAL_ASSERT(!new_list_output->node()->hasUses());
1053 
1054       auto symbolic_sizes =
1055           curr->output(i)->type()->expect<TensorType>()->symbolic_sizes();
1056       TORCH_INTERNAL_ASSERT(symbolic_sizes.rank());
1057 
1058       for (size_t i = 0; i < *symbolic_sizes.rank(); i++) {
1059         if (symbolic_sizes[i].is_static()) {
1060           continue;
1061         }
1062         int64_t symbolic_shape = symbolic_sizes[i].value();
1063         if (symbolic_shape_value_to_graph_output_.count(symbolic_shape)) {
1064           continue;
1065         }
1066         registerStitchedComputeOutput(
1067             stitched_shape_compute_graph,
1068             new_list_output->node()->input(i),
1069             symbolic_shape);
1070       }
1071     }
1072   }
1073 
1074   std::unordered_map<Node*, std::shared_ptr<Graph>>
1075   propagateShapesAndGatherPartialEvalShapeGraphs(AliasDb& db) {
1076     std::unordered_map<Node*, std::shared_ptr<Graph>> partial_evaluated_graphs;
1077     for (auto it = beg_->iterator(); it != end_->iterator(); it++) {
1078       auto curr = *it;
1079       if (auto maybe_graph = PropagateShapesWithShapeFunction(curr, db)) {
1080         partial_evaluated_graphs[curr] = maybe_graph;
1081       }
1082     }
1083     return partial_evaluated_graphs;
1084   }
1085 
1086   std::unordered_map<Value*, Value*>
1087       enclosing_graph_value_to_shape_graph_input_;
1088   std::unordered_map<int64_t, Value*> symbolic_shape_value_to_graph_output_;
1089   std::unordered_map<size_t, int64_t> output_index_to_symbolic_shape_;
1090 
1091   std::shared_ptr<Graph>& graph_;
1092   Node* beg_;
1093   Node* end_;
1094 };
1095 
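// Recursively propagates shapes through a block: prim::If outputs are merged
// from both branches, nodes with a schema go through the shape-function path,
// and prim::TupleConstruct output types are refreshed from their refined
// inputs.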
1096 void PropagateShapesOnBlock(Block* b, const AliasDb& db) {
1097   for (Node* n : b->nodes()) {
1098     // TODO: handle loop
1099     if (n->kind() == prim::If) {
1100       IfView if_v(n);
1101       PropagateShapesOnBlock(if_v.thenBlock(), db);
1102       PropagateShapesOnBlock(if_v.elseBlock(), db);
1103       mergeTypes(if_v.thenOutputs(), if_v.elseOutputs(), if_v.outputs());
1104     } else if (n->maybeSchema()) {
1105       PropagateShapesWithShapeFunction(n, db);
1106     } else if (n->kind() == prim::TupleConstruct) {
1107       auto orig_type = n->output()->type()->expect<TupleType>();
1108       auto new_types = fmap(n->inputs(), [](Value* v) { return v->type(); });
1109       n->output()->setType(
1110           orig_type->createWithContained(std::move(new_types)));
1111     }
1112   }
1113 }
1114 } // namespace
1115 
1116 void PropagateShapesOnGraph(std::shared_ptr<Graph>& graph) {
1117   AliasDb db(graph);
1118   PropagateShapesOnBlock(graph->block(), db);
1119 }
1120 
1121 std::optional<ShapeComputeGraphMapping>
1122 PropagateShapesAndBuildLargeShapeComputeGraph(
1123     std::shared_ptr<Graph>& graph,
1124     Node* beg,
1125     Node* end) {
1126   return SymbolicShapeGraphAnalyzer(graph, beg, end).run();
1127 }
1128 
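// Entry point for computing output symbolic shapes of a single op outside of a
// Graph. Inputs are given as IValues or SymbolicShapes, results are cached via
// the symbolic shape cache, and when bounded shape graphs are registered for
// the schema the lower/upper bound results are merged with combine_bounds.
// A minimal usage sketch (the schema pointer and shapes are placeholders):
//   auto ss = c10::SymbolicShape(
//       std::vector<std::optional<int64_t>>{std::nullopt, 4});
//   std::vector<SSAInput> args{ss, ss};
//   auto result = calculateSymbolicShapesOnOp(&schema, args);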
1129 TORCH_API std::optional<std::vector<c10::SymbolicShape>>
1130 calculateSymbolicShapesOnOp(
1131     const FunctionSchema* schema,
1132     const std::vector<SSAInput>& inputs) {
1133   auto bounded_graphs = boundedGraphsForSchema(*schema);
1134   auto has_shape_compute = shapeComputeGraphForSchema(*schema) != std::nullopt;
1135   if (!has_shape_compute && bounded_graphs == std::nullopt) {
1136     // Avoid doing all this work for functions that don't have a
1137     // supported schema
1138     return std::nullopt;
1139   }
1140 
1141   if (auto cached_ret_vec = get_cached_shape_function(schema, inputs)) {
1142     return cached_ret_vec;
1143   }
1144 
1145   std::vector<SSArgument> ssa_args;
1146   for (auto& arg : inputs) {
1147     if (const IValue* ival = std::get_if<IValue>(&arg)) {
1148       ssa_args.emplace_back(*ival);
1149     } else {
1150       const c10::SymbolicShape* ss = std::get_if<c10::SymbolicShape>(&arg);
1151       ssa_args.emplace_back(ShapeArguments(*ss));
1152     }
1153   }
1154   // Handle bounded shape option
1155   if (bounded_graphs) {
1156     auto lower_bound =
1157         SymbolicShapeOpAnalyzer(schema, bounded_graphs->lower_bound);
1158     auto lower_bound_res = lower_bound.run(ssa_args);
1159     auto upper_bound =
1160         SymbolicShapeOpAnalyzer(schema, bounded_graphs->upper_bound);
1161     auto upper_bound_res = upper_bound.run(ssa_args);
1162     // Stitch together the values
1163     if (lower_bound_res.has_value() && upper_bound_res.has_value()) {
1164       TORCH_INTERNAL_ASSERT(lower_bound_res->size() == upper_bound_res->size());
1165       auto merged_res = std::vector<c10::SymbolicShape>();
1166       for (size_t i = 0; i < lower_bound_res->size(); i++) {
1167         merged_res.push_back(
1168             combine_bounds(lower_bound_res->at(i), upper_bound_res->at(i)));
1169       }
1170       cache_shape_function(schema, inputs, merged_res);
1171       return merged_res;
1172     }
1173     return std::nullopt;
1174   }
1175 
1176   auto op_analyzer = SymbolicShapeOpAnalyzer(schema);
1177   auto res = op_analyzer.run(ssa_args);
1178   if (res.has_value()) {
1179     cache_shape_function(schema, inputs, res.value());
1180   }
1181   return res;
1182 }
1183 
1184 } // namespace torch::jit
1185