xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/functional.h>
2 #include <ATen/core/interned_strings.h>
3 #include <c10/core/MemoryFormat.h>
4 #include <c10/core/ScalarType.h>
5 #include <c10/util/Exception.h>
6 #include <torch/csrc/jit/ir/ir.h>
7 #include <torch/csrc/jit/ir/ir_views.h>
8 #include <torch/csrc/jit/jit_log.h>
9 #include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
10 #include <torch/csrc/jit/passes/tensorexpr_fuser.h>
11 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
12 #include <torch/csrc/jit/runtime/graph_iterator.h>
13 #include <torch/csrc/jit/runtime/register_ops_utils.h>
14 #include <torch/csrc/jit/runtime/static/ops.h>
15 #include <sstream>
16 #include <utility>
17 
18 namespace torch::jit {
19 
20 // Inserts the Compute for Each Symbolic Shape in the TensorExpr Graph
21 // and returns a map from Symbolic Shape Value to its runtime Value *
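//
// Illustrative sketch (hypothetical IR, not from an actual graph dump): if the
// partial eval shape graph computes a single symbolic dim SS(-2) from the
// sizes of a tensor input %x of the TensorExpr node, the code below inserts
// roughly the following into the enclosing graph, right before that node:
//   %x_sizes : int[] = aten::size(%x)
//   %ss_2 : int = ...            (the inlined partial eval shape graph)
// and the returned map is {-2 : %ss_2}.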
22 static std::map<int64_t, Value*> InsertSymbolicShapesCompute(
23     const ShapeComputeGraphMapping& shape_mapping,
24     Node* tensorexpr_graph) {
25   WithInsertPoint guard(tensorexpr_graph);
26   auto enclosing_graph = tensorexpr_graph->owningGraph();
27 
28   std::map<Value*, Value*> shape_graph_input_to_enclosing_graph_value;
29   for (const auto& pair :
30        shape_mapping.enclosing_graph_value_to_shape_graph_input_) {
31     shape_graph_input_to_enclosing_graph_value[pair.second] = pair.first;
32   }
33   std::vector<Value*> shape_compute_graph_inputs;
34   for (Value* shape_graph_input :
35        shape_mapping.partial_eval_shape_graph->inputs()) {
36     auto enclosing_graph_input =
37         shape_graph_input_to_enclosing_graph_value.find(shape_graph_input);
38     TORCH_INTERNAL_ASSERT(
39         enclosing_graph_input !=
40         shape_graph_input_to_enclosing_graph_value.end());
41     if (*enclosing_graph_input->second->type() == *shape_graph_input->type()) {
42       shape_compute_graph_inputs.push_back(tensorexpr_graph->inputs().at(
43           enclosing_graph_input->second->offset()));
44     } else {
45       TORCH_INTERNAL_ASSERT(
46           enclosing_graph_input->second->type()->cast<TensorType>() &&
47           shape_graph_input->type()->isSubtypeOf(ListType::ofInts()));
48       shape_compute_graph_inputs.push_back(enclosing_graph->insert(
49           aten::size,
50           {tensorexpr_graph->inputs().at(
51               enclosing_graph_input->second->offset())}));
52     }
53   }
54   auto sym_shape_values = insertGraph(
55       *enclosing_graph,
56       *shape_mapping.partial_eval_shape_graph,
57       shape_compute_graph_inputs);
58   std::map<int64_t, Value*> sym_shape_to_enclosing_graph_value;
59   for (size_t i = 0;
60        i < shape_mapping.partial_eval_shape_graph->outputs().size();
61        ++i) {
62     Value* output = shape_mapping.partial_eval_shape_graph->outputs().at(i);
63     auto sym_shape =
64         shape_mapping.graph_output_to_symbolic_shape_dim_.find(output);
65     TORCH_INTERNAL_ASSERT(
66         sym_shape != shape_mapping.graph_output_to_symbolic_shape_dim_.end());
67     sym_shape_to_enclosing_graph_value[sym_shape->second] = sym_shape_values[i];
68   }
69   return sym_shape_to_enclosing_graph_value;
70 }
71 
72 void insertDynamicShapesGuard(
73     const ShapeComputeGraphMapping& shape_mapping,
74     Node* guarded_node,
75     bool add_composed_op,
76     std::vector<std::vector<StrideInput>>& input_info,
77     std::vector<StrideInput>& output_strides);
78 
79 std::string toString(StrideInput si) {
80   switch (si) {
81     case StrideInput::TENSOR_CONT:
82       return "TENSOR_CONT";
83     case StrideInput::TENSOR_CONT_CHANNELS_LAST:
84       return "TENSOR_CONT_CHANNELS_LAST";
85     case StrideInput::S_ONE:
86       return "S_ONE";
87     case StrideInput::S_CONT:
88       return "S_CONT";
89     case StrideInput::S_TRAN_CONT:
90       return "S_TRAN_CONT";
91     case StrideInput::S_AS_ARG:
92       return "S_AS_ARG";
93   }
94   TORCH_INTERNAL_ASSERT(false);
95 }
96 
97 StrideInput strideInputFromString(const std::string& si) {
98   if (si == "TENSOR_CONT") {
99     return StrideInput::TENSOR_CONT;
100   } else if (si == "TENSOR_CONT_CHANNELS_LAST") {
101     return StrideInput::TENSOR_CONT_CHANNELS_LAST;
102   } else if (si == "S_ONE") {
103     return StrideInput::S_ONE;
104   } else if (si == "S_CONT") {
105     return StrideInput::S_CONT;
106   } else if (si == "S_TRAN_CONT") {
107     return StrideInput::S_TRAN_CONT;
108   } else if (si == "S_AS_ARG") {
109     return StrideInput::S_AS_ARG;
110   } else {
111     TORCH_INTERNAL_ASSERT(false);
112   }
113 }
114 
115 // in the runtime guard, strides are serialized as one flat
116 // vector. stride_inputs_offset indexes into that vector
117 // where the strides of this tensor begin
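//
// A worked example (illustrative, not from the upstream sources): for a
// transposed 2-D tensor with sizes = [3, 4] and strides = [1, 3], the
// function below classifies
//   dim 0: strides[0] == 1                                   -> S_ONE
//   dim 1: strides[1] == strides[0] * sizes[0] == 3, and dim 0
//          was not classified S_CONT                         -> S_TRAN_CONT
// so the per-dim summary is [S_ONE, S_TRAN_CONT]. A dim that matches none of
// the patterns is passed at runtime as an argument (S_AS_ARG).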
118 inline StrideInput summarizeStrideDim(
119     const c10::IntArrayRef sizes,
120     const c10::IntArrayRef strides,
121     size_t dim,
122     const std::vector<StrideInput>& stride_inputs,
123     size_t stride_inputs_offset) {
124   if (strides[dim] == 1) {
125     return StrideInput::S_ONE;
126   } else if (
127       dim + 1 < sizes.size() &&
128       strides[dim] == strides[dim + 1] * sizes[dim + 1]) {
129     return StrideInput::S_CONT;
130     // Transposed Contiguous depends on the prior dim and Contiguous depends
131     // on the next dim; to avoid a mutual dependence, only treat this dim as
132     // S_TRAN_CONT if the prior dim was not classified as Stride Contiguous
133   } else if (
134       dim > 0 && strides[dim] == strides[dim - 1] * sizes[dim - 1] &&
135       (stride_inputs[dim - 1 + stride_inputs_offset] != StrideInput::S_CONT)) {
136     return StrideInput::S_TRAN_CONT;
137   } else {
138     return StrideInput::S_AS_ARG;
139   }
140 }
141 
142 static std::vector<StrideInput> summarizeInputStrides(const TensorType& tt) {
143   auto strides = *tt.strides().concrete_sizes();
144   auto sizes = *tt.sizes().concrete_sizes();
145   if (c10::is_contiguous_strides(sizes, strides)) {
146     return {StrideInput::TENSOR_CONT};
147     // TODO: channels last 3d
148   } else if (c10::is_channels_last_strides_2d(sizes, strides)) {
149     return {StrideInput::TENSOR_CONT_CHANNELS_LAST};
150   }
151   std::vector<StrideInput> stride_inputs;
152   for (size_t dim = 0; dim < sizes.size(); ++dim) {
153     stride_inputs.push_back(
154         summarizeStrideDim(sizes, strides, dim, stride_inputs, 0));
155   }
156   return stride_inputs;
157 }
158 
159 // TODO: incorporate in codegen
160 static StrideInput summarizeOutputStrides(const TensorType& tt) {
161   auto strides = *tt.strides().concrete_sizes();
162   auto sizes = *tt.sizes().concrete_sizes();
163   // We only try to maintain output striding for channels last tensors,
164   // otherwise we defer to contiguous
165   // TODO: channels last 3d
166   if (c10::is_channels_last_strides_2d(sizes, strides)) {
167     return StrideInput::TENSOR_CONT_CHANNELS_LAST;
168   }
169   return StrideInput::TENSOR_CONT;
170 }
171 
172 // Generalize Complete Shapes inputs to Symbolic Shapes.
173 // Dimensions of value 1 will be preserved, otherwise
174 // dimensions with the same value will be bucketed to the same
175 // symbolic shape.
176 // E.g. Tensor(5, 3), Tensor(3, 1) -> Tensor(SS(-1), SS(-2)), Tensor(SS(-2), 1)
177 // Also summarize input striding behavior. The size information is stored on
178 // the type; the striding is returned. See StrideInput for a description of
179 // stride specializations.
180 static std::optional<std::vector<std::vector<StrideInput>>>
181 TryGeneralizeInputDimensionsToSymbolicShapes(
182     const std::shared_ptr<Graph>& tensorexpr_graph) {
183   std::map<size_t, int64_t> shape_to_sym_shape;
184   std::vector<std::vector<StrideInput>> input_striding;
185 
186   for (Value* v : tensorexpr_graph->inputs()) {
187     if (!v->type()->cast<TensorType>()) {
188       continue;
189     }
190     auto tt = v->type()->expectRef<TensorType>();
191     if (!tt.sizes().isComplete() || !tt.strides().isComplete()) {
192       return std::nullopt;
193     }
194     input_striding.push_back(summarizeInputStrides(tt));
195     std::vector<at::ShapeSymbol> shape_vec = *tt.symbolic_sizes().sizes();
196     auto new_sizes = c10::fmap(shape_vec, [&](const at::ShapeSymbol& shape) {
197       auto value = shape.value();
198       TORCH_INTERNAL_ASSERT(value >= 0, "Expected complete tensor");
199       if (value == 1) {
200         return value;
201       } else if (shape_to_sym_shape.count(static_cast<size_t>(value))) {
202         return shape_to_sym_shape[value];
203       } else {
204         auto new_shape_symbol = at::ShapeSymbol::newSymbol().value();
205         shape_to_sym_shape[static_cast<size_t>(value)] = new_shape_symbol;
206         return new_shape_symbol;
207       }
208     });
209     v->setType(tt.withSymbolicShapes(c10::SymbolicShape(new_sizes)));
210   }
211   return input_striding;
212 }
213 
214 static void moveConstantTensorsOutOfSubgraph(
215     Node* tensorexpr_graph_node,
216     const std::shared_ptr<Graph>& tensorexpr_graph) {
217   auto parent = tensorexpr_graph_node->owningGraph();
218 
219   auto env = [&](Value* v) {
220     TORCH_INTERNAL_ASSERT(
221         false,
222         "this should never happen since constant nodes do not have any inputs",
223         v->debugName());
224     return v;
225   };
226 
227   WithInsertPoint wip(tensorexpr_graph_node);
228   std::vector<Node*> to_destroy;
229   for (auto node : tensorexpr_graph->nodes()) {
230     if (node->kind() == prim::Constant) {
231       if (!node->output()->type()->cast<TensorType>()) {
232         continue;
233       }
234 
235       // copy the constant and insert that copy into the parent graph.
236       auto copy = parent->createClone(node, env);
237       parent->insertNode(copy);
238 
239       // add a new input to the te subgraph and replace the uses of the
240       // constant with this input.
241       auto new_const = tensorexpr_graph->addInput();
242       new_const->setType(node->output()->type());
243       node->output()->replaceAllUsesWith(new_const);
244 
245       // add the copy as input to the te node
246       tensorexpr_graph_node->addInput(copy->output());
247 
248       to_destroy.push_back(node);
249     }
250   }
251 
252   for (auto n : to_destroy) {
253     n->destroy();
254   }
255 }
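
// Illustrative effect of the function above (hypothetical, simplified): a
// tensor constant that previously lived inside the fusion subgraph,
//   %c : Tensor = prim::Constant[value=<some weight>]()
// is cloned into the parent graph, a fresh subgraph input with the same type
// is added, all uses of %c inside the subgraph are redirected to that input,
// the clone's output is appended to the inputs of the TensorExpr node, and
// the original constant node is destroyed.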
256 
257 bool GenerateGuard(Node* tensorexpr_graph_node, bool add_composed_op) {
258   auto tensorexpr_graph = SubgraphUtils::getSubgraph(tensorexpr_graph_node);
259 
260   // Move constant tensors from the subgraph to the outer scope.
261   // This is necessary because symbolic shape analysis does not handle the
262   // case of broadcast(constant, symbolic_shape) well and that results in poor
263   // performance.
264   moveConstantTensorsOutOfSubgraph(tensorexpr_graph_node, tensorexpr_graph);
265 
266   // Generalize Inputs
267   auto input_striding =
268       TryGeneralizeInputDimensionsToSymbolicShapes(tensorexpr_graph);
269   if (!input_striding) {
270     return false;
271   }
272 
273   // Get output striding behavior
274   std::vector<StrideInput> output_striding;
275   for (Value* v : tensorexpr_graph->outputs()) {
276     if (!v->type()->cast<TensorType>()) {
277       continue;
278     }
279     auto tt = v->type()->expectRef<TensorType>();
280     if (!tt.sizes().isComplete() || !tt.strides().isComplete()) {
281       return false;
282     }
283     output_striding.push_back(summarizeOutputStrides(tt));
284   }
285 
286   // Try To Propagate Shapes
287   auto maybe_shape_compute_mapping =
288       PropagateShapesAndBuildLargeShapeComputeGraph(
289           tensorexpr_graph,
290           *tensorexpr_graph->nodes().begin(),
291           *tensorexpr_graph->nodes().end());
292   if (!maybe_shape_compute_mapping) {
293     return false;
294   }
295 
296   // Insert Guard
297   insertDynamicShapesGuard(
298       *maybe_shape_compute_mapping,
299       tensorexpr_graph_node,
300       add_composed_op,
301       *input_striding,
302       output_striding);
303   return true;
304 }
305 
306 static void inlineFallbackGraphAndAddSRCopyOutOp(std::shared_ptr<Graph> graph) {
307   DepthFirstGraphNodeIterator it(graph);
308 
309   Node* n = nullptr;
310   while ((n = it.next()) != nullptr) {
311     if (n->kind() == prim::FallbackGraph) {
312       break;
313     }
314   }
315   TORCH_INTERNAL_ASSERT(n != nullptr, "Expected to find fallback graph");
316 
317   auto if_node = n->owningBlock()->owningNode();
318   IfView if_v(if_node);
319   SubgraphUtils::unmergeSubgraph(n);
320 
321   auto false_block = if_v.elseBlock();
322   std::vector<Value*> false_block_outputs(
323       if_v.elseOutputs().begin(), if_v.elseOutputs().end());
324   TORCH_INTERNAL_ASSERT(!false_block_outputs.empty());
325 
326   for (auto out : false_block_outputs) {
327     TORCH_INTERNAL_ASSERT(out->type()->cast<TensorType>());
328   }
329   auto copy_node = graph->create(
330       prim::StaticRuntimeCopyOuts,
331       false_block_outputs,
332       false_block_outputs.size());
333   false_block->appendNode(copy_node);
334   for (size_t i = 0; i < false_block_outputs.size(); ++i) {
335     false_block->replaceOutput(i, copy_node->outputs().at(i));
336   }
337 }
338 
339 // TODO: share more logic with tensorexpr_fuser ?
340 void insertDynamicShapesGuard(
341     const ShapeComputeGraphMapping& shape_mapping,
342     Node* guarded_node,
343     bool add_composed_op,
344     std::vector<std::vector<StrideInput>>& input_info,
345     std::vector<StrideInput>& output_strides) {
346   GRAPH_DEBUG(
347       "Inserting a prim::TensorExprDynamicGuard guard for a node",
348       *guarded_node);
349   auto subgraph = SubgraphUtils::getSubgraph(guarded_node);
350 
351   // Collect the tensor inputs to check and their expected (stride-erased) types
352   std::vector<Value*> inputs_to_check;
353   std::vector<TypePtr> guard_types;
354   for (const auto i : c10::irange(guarded_node->inputs().size())) {
355     Value* node_input = guarded_node->inputs().at(i);
356     // We only check inputs of the guarded node
357     if (!node_input->type()->cast<TensorType>()) {
358       continue;
359     }
360     inputs_to_check.push_back(node_input);
361     guard_types.emplace_back(
362         subgraph->inputs().at(i)->type()->expect<TensorType>()->withStrides(
363             c10::VaryingShape<c10::Stride>()));
364   }
365   TORCH_INTERNAL_ASSERT(inputs_to_check.size());
366 
367   // prim::TensorExprDynamicGuard nodes look like the following:
368   //   %types_match : bool = prim::TensorExprDynamicGuard[attr:types](%inp1 :
369   //   Tensor, %inp2 : Tensor)
370   // The input tensors are checked against the expected types in attr::types.
371   // We omit refining the input Tensors for now because they are not actually
372   // used within tensorexpr/kernel.cpp (only the inputs to the Graph are, not
373   // the inputs to the node), and we would have to redo the mapping to compute
374   // symbolic shapes.
375 
376   Node* typecheck_node =
377       guarded_node->owningGraph()
378           ->create(Symbol::prim("TensorExprDynamicGuard"), inputs_to_check, 1)
379           ->insertBefore(guarded_node);
380 
381   typecheck_node->tys_(attr::types, std::move(guard_types));
382   Value* typecheck_result = typecheck_node->output()->setType(BoolType::get());
383 
384   // Insert if
385   auto versioning_if =
386       guarded_node->owningGraph()
387           ->create(prim::If, {typecheck_result}, guarded_node->outputs().size())
388           ->insertAfter(typecheck_node);
389 
390   for (size_t idx = 0; idx < guarded_node->outputs().size(); ++idx) {
391     versioning_if->output(idx)->setType(guarded_node->output(idx)->type());
392     guarded_node->output(idx)->replaceAllUsesWith(versioning_if->output(idx));
393   }
394   auto true_block = versioning_if->addBlock();
395   auto false_block = versioning_if->addBlock();
396 
397   // Fill in the false block. It should contain the unoptimized
398   // copy of the fused subgraph.
399   WithInsertPoint guard(false_block->return_node());
400   const auto subgraph_outputs = insertGraph(
401       *guarded_node->owningGraph(), *subgraph, guarded_node->inputs());
402   for (Value* output : subgraph_outputs) {
403     false_block->registerOutput(output);
404   }
405 
406   // types get copied to the fallback graph, so remove specializations before
407   // replacing
408   removeTensorTypeSpecializations(false_block);
409   replaceBlockWithFallbackGraph(false_block, guarded_node->inputs());
410 
411   // Fill in the true block. It has all inputs type-checked and its
412   // body should be the fusion group node.
413   guarded_node->moveBefore(true_block->return_node());
414 
415   for (Value* output : guarded_node->outputs()) {
416     true_block->registerOutput(output);
417   }
418 
419   // Insert the symbolic shapes compute and add its results as inputs to the
420   // TE Node/Graph. symbolic_shape_inputs will be a list of each symbolic
421   // shape, and the last N inputs to the TE Graph/Node will be the N
422   // symbolic shape values.
423   auto map = InsertSymbolicShapesCompute(shape_mapping, guarded_node);
424   std::vector<int64_t> symbolic_shape_inputs;
425   for (const auto& pair : map) {
426     symbolic_shape_inputs.push_back(pair.first);
427     guarded_node->addInput(pair.second);
428     std::stringstream ss;
429     ss << "SS_" << -pair.first;
430     subgraph->addInput(ss.str())->setType(IntType::get());
431   }
432   guarded_node->is_(
433       attr::symbolic_shape_inputs, std::move(symbolic_shape_inputs));
434 
435   std::vector<std::vector<std::string>> input_striding;
436   for (auto& vec : input_info) {
437     auto string_info =
438         fmap(vec, [&](StrideInput inp) { return toString(inp); });
439     input_striding.push_back(string_info);
440   }
441   auto ival = IValue(input_striding);
442   guarded_node->ival_(attr::striding_inputs_desc, ival);
443   typecheck_node->ival_(attr::striding_inputs_desc, std::move(ival));
444 
445   for (Value* v : subgraph->inputs()) {
446     if (auto t = v->type()->cast<TensorType>()) {
447       v->setType(t->withStrides(c10::VaryingShape<c10::Stride>()));
448     }
449   }
450   for (Value* v : subgraph->outputs()) {
451     if (auto t = v->type()->cast<TensorType>()) {
452       v->setType(t->withStrides(c10::VaryingShape<c10::Stride>()));
453     }
454   }
455 
456   std::vector<std::string> output_striding =
457       fmap(output_strides, [&](StrideInput inp) { return toString(inp); });
458   auto output_ival = IValue(output_striding);
459   guarded_node->ival_(attr::striding_outputs_desc, std::move(output_ival));
460 
461   if (add_composed_op) {
462     // only in SR flow do we check for values on the stack and
463     // forward them along as tensor outputs
464     // TODO: - refactor and make explicit part of TE Kernel api
465     guarded_node->i_(attr::allow_stack_outputs, 1);
466 
467     // Create a TensorExprDynamicGroup node
468     auto te_dyn_group = SubgraphUtils::createSingletonSubgraph(
469         typecheck_node, prim::TensorExprDynamicGroup);
470     SubgraphUtils::mergeNodeIntoSubgraph(versioning_if, te_dyn_group);
471     inlineFallbackGraphAndAddSRCopyOutOp(
472         SubgraphUtils::getSubgraph(te_dyn_group));
473   }
474 }
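
// With add_composed_op == false, the region around the original fusion node
// ends up looking roughly like the following (illustrative sketch; node and
// value names are approximate):
//   %ok : bool = prim::TensorExprDynamicGuard[types=[...]](%a, %b)
//   %out : Tensor = prim::If(%ok)
//     block0():
//       %ss_2 : int = ...  (inserted symbolic shape compute)
//       %v : Tensor = prim::TensorExprGroup_0[symbolic_shape_inputs=[-2]](
//           %a, %b, %ss_2)
//       -> (%v)
//     block1():
//       %w : Tensor = prim::FallbackGraph_0(%a, %b)
//       -> (%w)
// With add_composed_op == true, the guard, the If, and the inlined fallback
// are additionally packaged into a single prim::TensorExprDynamicGroup node.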
475 
476 // This operator is inserted at the end of the fallback block computing outputs
477 // for the fusion group. We convert block1():
478 //   %14 : Tensor = aten::mul(%0, %1)
479 //   %15 : Tensor = aten::mul(%0, %14)
480 //   -> (%15, %14)
481 // return (%3, %4)
482 // to
483 // block1():
484 //   %14 : Tensor = aten::mul(%0, %1)
485 //   %15 : Tensor = aten::mul(%0, %14)
486 //   %16 : Tensor, %17 : Tensor = prim::StaticRuntimeCopyOuts(%15, %14)
487 //   -> (%16, %17)
488 // Every output of the block is added as an input, and for each input there is
489 // a StaticRuntimeCopyOuts output. SR invokes the composed operator first with
490 // no tensors on the stack, in which case the Op will just return the inputs.
491 // Second, it invokes it with pre-allocated tensors, one for each output of the
492 // Fusion group, which is the same as the number of outputs of the fallback
493 // block. In this case we copy the values of the inputs into the pre-allocated
494 // tensors.
495 // Note: this logic is meant to reflect the invocation of the TE Kernel
496 // and `runWithAllocatedOutputs` in tensorexpr_fuser.cpp
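// Illustrative stack behavior for two outputs (hypothetical values, where
// o1/o2 are the tensors pre-allocated by Static Runtime and t1/t2 are the
// values produced by the fallback block):
//   first invocation:  stack = [t1, t2]         -> outputs are t1, t2
//   later invocations: stack = [o1, o2, t1, t2] -> o1.copy_(t1), o2.copy_(t2),
//                      and o1, o2 are left on the stack as the outputs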
497 static Operation StaticRuntimeCopyOuts(const Node* node) {
498   auto num_ten_inputs = node->inputs().size();
499   return [num_ten_inputs](Stack& stack) {
500     std::vector<IValue> inputs = pop(stack, num_ten_inputs);
501     // uncommon case - first run
502     if (stack.empty()) {
503       for (IValue elem : inputs) {
504         push(stack, std::move(elem));
505       }
506     } else {
507       at::ArrayRef<IValue> outputs = last(stack, num_ten_inputs);
508       for (size_t i = 0; i < inputs.size(); ++i) {
509         IValue out = outputs[i];
510         at::Tensor& out_t = out.toTensor();
511         fastResizeToZero(out_t);
512         out_t.resize_as_(inputs[i].toTensor());
513         out_t.copy_(inputs[i].toTensor());
514       }
515     }
516     return 0;
517   };
518 }
519 
520 RegisterOperators SRCopyOuts({
521     torch::jit::Operator(
522         prim::StaticRuntimeCopyOuts,
523         StaticRuntimeCopyOuts,
524         AliasAnalysisKind::CONSERVATIVE),
525 });
526 
527 // On each invocation of this guard, we need to check all of the static
528 // information (dtype/device/requires grad/contiguity/static dims),
529 // and also that the symbolic shape dimensions are observed.
530 // For any symbolic dimension we need to set its value on its first
531 // use and for all subsequent uses check that the values are equal
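// For example (illustrative guard types, assuming dtype/device/striding all
// match): with expected types [Float(SS(-2), SS(-3)), Float(SS(-3))], inputs
// of sizes [5, 3] and [3] pass (SS(-2) binds to 5, SS(-3) binds to 3 on its
// first use and is re-checked on the second), while inputs of sizes [5, 3]
// and [4] fail because SS(-3) was already bound to 3 and 4 does not match.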
532 RegisterOperators reg_guard({
533     Operator(
534         "prim::TensorExprDynamicGuard(...) -> bool",
535         [](const Node* node) -> Operation {
536           const auto& types = node->tys(attr::types);
537 
538           // Each input's expected # of dims
539           std::vector<size_t> expected_dims;
540 
541           // A flattened vector of all the expected values for all
542           // tensor dims. A positive value corresponds to a static
543           // shape to check and a negative value corresponds to a symbolic
544           // dimension index to check
545           std::vector<int64_t> flattened_input_dims;
546 
547           // Each input's expected scalar type
548           std::vector<c10::ScalarType> expected_scalar_types;
549 
550           // Map from symbolic dimension value to its set's index
551           std::map<int64_t, size_t> sym_dim_flat_index;
552           TORCH_INTERNAL_ASSERT(!types.empty());
553 
554           // we should just be fusing fusion groups with a single device
555           // and with tensors not requiring grad
556           auto maybe_device = types[0]->expect<TensorType>()->device();
557           TORCH_INTERNAL_ASSERT(maybe_device);
558           auto device = *maybe_device;
559 
560           // flattened vector of each input's striding behavior
561           std::vector<StrideInput> flattened_input_striding;
562           const IValue& sym_strides = node->ival(attr::striding_inputs_desc);
563           std::vector<std::vector<std::string>> sym_strides_strs =
564               sym_strides.to<std::vector<std::vector<std::string>>>();
565           for (const auto& vec : sym_strides_strs) {
566             std::vector<StrideInput> input_desc;
567             for (const std::string& str : vec) {
568               flattened_input_striding.push_back(strideInputFromString(str));
569             }
570           }
571 
572           for (const auto& type : types) {
573             auto tt = type->expect<TensorType>();
574             auto ss = tt->symbolic_sizes();
575             TORCH_INTERNAL_ASSERT(ss.rank());
576             expected_dims.push_back(*ss.rank());
577             TORCH_INTERNAL_ASSERT(tt->scalarType());
578             expected_scalar_types.push_back(*tt->scalarType());
579             TORCH_INTERNAL_ASSERT(tt->device() && *tt->device() == device);
580             for (size_t i = 0; i < *ss.rank(); ++i) {
581               auto sym_dim = ss[i];
582               auto value = sym_dim.value();
583               if (value >= 0) {
584                 flattened_input_dims.push_back(value);
585               } else {
586                 // use index for set if it exists, otherwise extend the vector
587                 // of sym shapes by 1
588                 size_t sym_dim_index = 0;
589                 if (sym_dim_flat_index.count(value)) {
590                   sym_dim_index = sym_dim_flat_index[value];
591                 } else {
592                   auto size = sym_dim_flat_index.size();
593                   sym_dim_flat_index[value] = (-1) - size;
594                   sym_dim_index = sym_dim_flat_index[value];
595                 }
596                 // TODO: potential optimization - if there is a symbolic
597                 // dim with only one use we don't need to test anything
598                 flattened_input_dims.push_back(
599                     static_cast<int64_t>(sym_dim_index));
600               }
601             }
602           }
603 
604           const auto num_inputs = types.size();
605           const auto num_symbolic_dims = sym_dim_flat_index.size();
606           return [num_inputs,
607                   expected_dims,
608                   device,
609                   expected_scalar_types,
610                   flattened_input_dims,
611                   flattened_input_striding,
612                   num_symbolic_dims](Stack& stack) {
613             at::ArrayRef<IValue> inputs = last(stack, num_inputs);
614             drop(stack, num_inputs);
615             // On each invocation we need to reset what the value of each
616             // symbolic symbol is.
617             // TODO: could this be a reference and not allocated on
618             // each invocation, or would that interfere with multithreaded
619             // inference since we are writing to it?
620             // TODO - smallvector here ?
621             bool grad_mode_enabled = at::GradMode::is_enabled();
622             std::vector<int64_t> flattened_symbolic_dims(num_symbolic_dims, -1);
623             size_t flattened_dim_offset = 0;
624             size_t flattened_stride_offset = 0;
625             for (const auto i : c10::irange(num_inputs)) {
626               at::Tensor tensor = inputs[i].toTensor();
627               if (C10_UNLIKELY(
628                       tensor.device() != device ||
629                       tensor.dtype() != expected_scalar_types[i])) {
630                 push(stack, false);
631                 return;
632               }
633               if (C10_UNLIKELY(grad_mode_enabled && tensor.requires_grad())) {
634                 push(stack, false);
635                 return;
636               }
637               const auto& sizes = tensor.sizes();
638               const auto num_dims = sizes.size();
639               if (C10_UNLIKELY(num_dims != expected_dims[i])) {
640                 push(stack, false);
641                 return;
642               }
643               auto striding = flattened_input_striding[flattened_stride_offset];
644               // Tensors natively store whether they are contiguous
645               // in the default memory format or in channels last,
646               // so it is more efficient to query for this property than to
647               // iterate over the dimensions and check ourselves
648               if (striding == StrideInput::TENSOR_CONT) {
649                 if (C10_UNLIKELY(
650                         !tensor.is_contiguous(at::MemoryFormat::Contiguous))) {
651                   push(stack, false);
652                   return;
653                 }
654                 flattened_stride_offset += 1;
655               } else if (striding == StrideInput::TENSOR_CONT_CHANNELS_LAST) {
656                 // TODO: 5D channels last
657                 if (C10_UNLIKELY(!tensor.is_contiguous(
658                         at::MemoryFormat::ChannelsLast))) {
659                   push(stack, false);
660                   return;
661                 }
662                 flattened_stride_offset += 1;
663               } else {
664                 auto strides = tensor.strides();
665                 for (size_t dim = 0; dim < num_dims; ++dim) {
666                   auto summarized_dim = summarizeStrideDim(
667                       sizes,
668                       strides,
669                       dim,
670                       flattened_input_striding,
671                       flattened_stride_offset);
672                   if (C10_UNLIKELY(
673                           summarized_dim !=
674                           flattened_input_striding
675                               [dim + flattened_stride_offset])) {
676                     push(stack, false);
677                     return;
678                   }
679                 }
680                 flattened_stride_offset += num_dims;
681               }
682               for (const auto dim_index : c10::irange(num_dims)) {
683                 const auto dim_value =
684                     flattened_input_dims[dim_index + flattened_dim_offset];
685                 const int64_t tensor_dim = sizes[dim_index];
686                 if (dim_value >= 0) {
687                   if (C10_UNLIKELY(dim_value != tensor_dim)) {
688                     push(stack, false);
689                     return;
690                   }
691                 } else {
692                   // flattened sym indices start at -1,
693                   // so -1 -> index 0, -2 -> index 1
694                   const auto flattened_sym_index = (-dim_value) - 1;
695                   const auto flattened_sym_value =
696                       flattened_symbolic_dims[flattened_sym_index];
697                   // sym symbol already seen, check value
698                   if (flattened_symbolic_dims[flattened_sym_index] >= 0) {
699                     if (C10_UNLIKELY(flattened_sym_value != tensor_dim)) {
700                       push(stack, false);
701                       return;
702                     }
703                   } else {
704                     // not seen, write value
705                     flattened_symbolic_dims[flattened_sym_index] = tensor_dim;
706                   }
707                 }
708               }
709               flattened_dim_offset += num_dims;
710             }
711 
712             push(stack, true);
713             return;
714           };
715         },
716         aliasAnalysisFromSchema()),
717 });
718 
719 void runTensorExprDynamicGroup(const Code& code, Stack& stack) {
720   InterpreterState interpreter{code};
721   interpreter.run(stack);
722 }
723 
724 static Operation createTensorExprDynamicGroup(const Node* node) {
725   const auto& graph = node->g(attr::Subgraph);
726   Code code(graph, "");
727   // This implementation creates a Code object and InterpreterState on every
728   // call to TensorExprDynamicGroup, which affects performance. Ideally, we
729   // should be reusing Code and InterpreterState across calls to this op.
730   // But that is resulting in a "No frames found" error.
731   // TODO: Improve the performance of this by figuring out a better approach.
732   // NB: this is only run in SR, which is single-threaded
733   return [code](Stack& stack) {
734     runTensorExprDynamicGroup(code, stack);
735     return 0;
736   };
737 }
738 
739 RegisterOperators TensorExprDynamicOp({
740     torch::jit::Operator(
741         prim::TensorExprDynamicGroup,
742         createTensorExprDynamicGroup,
743         AliasAnalysisKind::INTERNAL_SPECIAL_CASE),
744 });
745 
746 } // namespace torch::jit
747