#include <ATen/core/functional.h>
#include <ATen/core/interned_strings.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <torch/csrc/jit/runtime/register_ops_utils.h>
#include <torch/csrc/jit/runtime/static/ops.h>
#include <sstream>
#include <utility>

namespace torch::jit {

// Inserts the compute for each symbolic shape in the TensorExpr graph
// and returns a map from symbolic shape value to its runtime Value*.
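// Illustrative example (value names are made up): if the fusion group has an
// input typed Tensor(SS(-2), SS(-3)), the partially evaluated shape compute
// graph is inserted right before the fusion node, fed with the corresponding
// enclosing-graph tensors (or their aten::size), and the returned map is
// roughly { -2 -> %dim0_value, -3 -> %dim1_value }.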
static std::map<int64_t, Value*> InsertSymbolicShapesCompute(
    const ShapeComputeGraphMapping& shape_mapping,
    Node* tensorexpr_graph) {
  WithInsertPoint guard(tensorexpr_graph);
  auto enclosing_graph = tensorexpr_graph->owningGraph();

  std::map<Value*, Value*> shape_graph_input_to_enclosing_graph_value;
  for (const auto& pair :
       shape_mapping.enclosing_graph_value_to_shape_graph_input_) {
    shape_graph_input_to_enclosing_graph_value[pair.second] = pair.first;
  }
  std::vector<Value*> shape_compute_graph_inputs;
  for (Value* shape_graph_input :
       shape_mapping.partial_eval_shape_graph->inputs()) {
    auto enclosing_graph_input =
        shape_graph_input_to_enclosing_graph_value.find(shape_graph_input);
    TORCH_INTERNAL_ASSERT(
        enclosing_graph_input !=
        shape_graph_input_to_enclosing_graph_value.end());
    if (*enclosing_graph_input->second->type() == *shape_graph_input->type()) {
      shape_compute_graph_inputs.push_back(tensorexpr_graph->inputs().at(
          enclosing_graph_input->second->offset()));
    } else {
      TORCH_INTERNAL_ASSERT(
          enclosing_graph_input->second->type()->cast<TensorType>() &&
          shape_graph_input->type()->isSubtypeOf(ListType::ofInts()));
      shape_compute_graph_inputs.push_back(enclosing_graph->insert(
          aten::size,
          {tensorexpr_graph->inputs().at(
              enclosing_graph_input->second->offset())}));
    }
  }
  auto sym_shape_values = insertGraph(
      *enclosing_graph,
      *shape_mapping.partial_eval_shape_graph,
      shape_compute_graph_inputs);
  std::map<int64_t, Value*> sym_shape_to_enclosing_graph_value;
  for (size_t i = 0;
       i < shape_mapping.partial_eval_shape_graph->outputs().size();
       ++i) {
    Value* output = shape_mapping.partial_eval_shape_graph->outputs().at(i);
    auto sym_shape =
        shape_mapping.graph_output_to_symbolic_shape_dim_.find(output);
    TORCH_INTERNAL_ASSERT(
        sym_shape != shape_mapping.graph_output_to_symbolic_shape_dim_.end());
    sym_shape_to_enclosing_graph_value[sym_shape->second] = sym_shape_values[i];
  }
  return sym_shape_to_enclosing_graph_value;
}

void insertDynamicShapesGuard(
    const ShapeComputeGraphMapping& shape_mapping,
    Node* guarded_node,
    bool add_composed_op,
    std::vector<std::vector<StrideInput>>& input_info,
    std::vector<StrideInput>& output_strides);

std::string toString(StrideInput si) {
  switch (si) {
    case StrideInput::TENSOR_CONT:
      return "TENSOR_CONT";
    case StrideInput::TENSOR_CONT_CHANNELS_LAST:
      return "TENSOR_CONT_CHANNELS_LAST";
    case StrideInput::S_ONE:
      return "S_ONE";
    case StrideInput::S_CONT:
      return "S_CONT";
    case StrideInput::S_TRAN_CONT:
      return "S_TRAN_CONT";
    case StrideInput::S_AS_ARG:
      return "S_AS_ARG";
  }
  TORCH_INTERNAL_ASSERT(false);
}

StrideInput strideInputFromString(const std::string& si) {
  if (si == "TENSOR_CONT") {
    return StrideInput::TENSOR_CONT;
  } else if (si == "TENSOR_CONT_CHANNELS_LAST") {
    return StrideInput::TENSOR_CONT_CHANNELS_LAST;
  } else if (si == "S_ONE") {
    return StrideInput::S_ONE;
  } else if (si == "S_CONT") {
    return StrideInput::S_CONT;
  } else if (si == "S_TRAN_CONT") {
    return StrideInput::S_TRAN_CONT;
  } else if (si == "S_AS_ARG") {
    return StrideInput::S_AS_ARG;
  } else {
    TORCH_INTERNAL_ASSERT(false);
  }
}

// In the runtime guard, strides are serialized as one flat
// vector. stride_inputs_offset indexes into that vector
// where the strides of this tensor begin.
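// Illustrative example: with two guarded 2-D inputs whose per-dim summaries
// are {S_ONE, S_TRAN_CONT} and {S_AS_ARG, S_ONE}, the flat vector is
// [S_ONE, S_TRAN_CONT, S_AS_ARG, S_ONE] and the second tensor's
// stride_inputs_offset is 2. An (M, N) tensor with strides (1, M), M > 1,
// summarizes to {S_ONE, S_TRAN_CONT}: dim 1 satisfies
// strides[1] == strides[0] * sizes[0] and dim 0 was not summarized as S_CONT.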
inline StrideInput summarizeStrideDim(
    const c10::IntArrayRef sizes,
    const c10::IntArrayRef strides,
    size_t dim,
    const std::vector<StrideInput>& stride_inputs,
    size_t stride_inputs_offset) {
  if (strides[dim] == 1) {
    return StrideInput::S_ONE;
  } else if (
      dim + 1 < sizes.size() &&
      strides[dim] == strides[dim + 1] * sizes[dim + 1]) {
    return StrideInput::S_CONT;
    // Transposed Contiguous depends on the prior dim, and Contiguous depends
    // on the next dim; to avoid a mutual dependence we only classify a dim as
    // Transposed Contiguous if the prior dim was not summarized as Stride
    // Contiguous.
  } else if (
      dim > 0 && strides[dim] == strides[dim - 1] * sizes[dim - 1] &&
      (stride_inputs[dim - 1 + stride_inputs_offset] != StrideInput::S_CONT)) {
    return StrideInput::S_TRAN_CONT;
  } else {
    return StrideInput::S_AS_ARG;
  }
}

static std::vector<StrideInput> summarizeInputStrides(const TensorType& tt) {
  auto strides = *tt.strides().concrete_sizes();
  auto sizes = *tt.sizes().concrete_sizes();
  if (c10::is_contiguous_strides(sizes, strides)) {
    return {StrideInput::TENSOR_CONT};
    // TODO: channels last 3d
  } else if (c10::is_channels_last_strides_2d(sizes, strides)) {
    return {StrideInput::TENSOR_CONT_CHANNELS_LAST};
  }
  std::vector<StrideInput> stride_inputs;
  for (size_t dim = 0; dim < sizes.size(); ++dim) {
    stride_inputs.push_back(
        summarizeStrideDim(sizes, strides, dim, stride_inputs, 0));
  }
  return stride_inputs;
}

// TODO: incorporate in codegen
static StrideInput summarizeOutputStrides(const TensorType& tt) {
  auto strides = *tt.strides().concrete_sizes();
  auto sizes = *tt.sizes().concrete_sizes();
  // We only try to maintain output striding for channels-last tensors;
  // otherwise we defer to contiguous.
  // TODO: channels last 3d
  if (c10::is_channels_last_strides_2d(sizes, strides)) {
    return StrideInput::TENSOR_CONT_CHANNELS_LAST;
  }
  return StrideInput::TENSOR_CONT;
}

// Generalize complete shape inputs to symbolic shapes.
// Dimensions of value 1 will be preserved; otherwise,
// dimensions with the same value will be bucketed to the same
// symbolic shape.
// E.g. Tensor(5, 3), Tensor(3, 1) -> Tensor(SS(-1), SS(-2)), Tensor(SS(-2), 1)
// Also summarize input striding behavior. The size information is stored on
// the type; the striding is returned. See StrideInput for a description of
// stride specializations.
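// For instance (illustrative): contiguous inputs typed Tensor(5, 3) and
// Tensor(3, 1) are rewritten to Tensor(SS(-1), SS(-2)) and Tensor(SS(-2), 1),
// and the returned striding is {{TENSOR_CONT}, {TENSOR_CONT}}.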
static std::optional<std::vector<std::vector<StrideInput>>>
TryGeneralizeInputDimensionsToSymbolicShapes(
    const std::shared_ptr<Graph>& tensorexpr_graph) {
  std::map<size_t, int64_t> shape_to_sym_shape;
  std::vector<std::vector<StrideInput>> input_striding;

  for (Value* v : tensorexpr_graph->inputs()) {
    if (!v->type()->cast<TensorType>()) {
      continue;
    }
    auto tt = v->type()->expectRef<TensorType>();
    if (!tt.sizes().isComplete() || !tt.strides().isComplete()) {
      return std::nullopt;
    }
    input_striding.push_back(summarizeInputStrides(tt));
    std::vector<at::ShapeSymbol> shape_vec = *tt.symbolic_sizes().sizes();
    auto new_sizes = c10::fmap(shape_vec, [&](const at::ShapeSymbol& shape) {
      auto value = shape.value();
      TORCH_INTERNAL_ASSERT(value >= 0, "Expected complete tensor");
      if (value == 1) {
        return value;
      } else if (shape_to_sym_shape.count(static_cast<size_t>(value))) {
        return shape_to_sym_shape[value];
      } else {
        auto new_shape_symbol = at::ShapeSymbol::newSymbol().value();
        shape_to_sym_shape[static_cast<size_t>(value)] = new_shape_symbol;
        return new_shape_symbol;
      }
    });
    v->setType(tt.withSymbolicShapes(c10::SymbolicShape(new_sizes)));
  }
  return input_striding;
}

static void moveConstantTensorsOutOfSubgraph(
    Node* tensorexpr_graph_node,
    const std::shared_ptr<Graph>& tensorexpr_graph) {
  auto parent = tensorexpr_graph_node->owningGraph();

  auto env = [&](Value* v) {
    TORCH_INTERNAL_ASSERT(
        false,
        "this should never happen since constant nodes do not have any inputs",
        v->debugName());
    return v;
  };

  WithInsertPoint wip(tensorexpr_graph_node);
  std::vector<Node*> to_destroy;
  for (auto node : tensorexpr_graph->nodes()) {
    if (node->kind() == prim::Constant) {
      if (!node->output()->type()->cast<TensorType>()) {
        continue;
      }

      // Copy the constant and insert that copy into the parent graph.
      auto copy = parent->createClone(node, env);
      parent->insertNode(copy);

      // Add a new input to the TE subgraph and replace the uses of the
      // constant with this input.
      auto new_const = tensorexpr_graph->addInput();
      new_const->setType(node->output()->type());
      node->output()->replaceAllUsesWith(new_const);

      // Add the copy as an input to the TE node.
      tensorexpr_graph_node->addInput(copy->output());

      to_destroy.push_back(node);
    }
  }

  for (auto n : to_destroy) {
    n->destroy();
  }
}

bool GenerateGuard(Node* tensorexpr_graph_node, bool add_composed_op) {
  auto tensorexpr_graph = SubgraphUtils::getSubgraph(tensorexpr_graph_node);

  // Move constant tensors from the subgraph to the outer scope.
  // This is necessary because symbolic shape analysis does not handle the
  // case of broadcast(constant, symbolic_shape) well, and that results in
  // poor performance.
  moveConstantTensorsOutOfSubgraph(tensorexpr_graph_node, tensorexpr_graph);

  // Generalize inputs
  auto input_striding =
      TryGeneralizeInputDimensionsToSymbolicShapes(tensorexpr_graph);
  if (!input_striding) {
    return false;
  }

  // Get output striding behavior
  std::vector<StrideInput> output_striding;
  for (Value* v : tensorexpr_graph->outputs()) {
    if (!v->type()->cast<TensorType>()) {
      continue;
    }
    auto tt = v->type()->expectRef<TensorType>();
    if (!tt.sizes().isComplete() || !tt.strides().isComplete()) {
      return false;
    }
    output_striding.push_back(summarizeOutputStrides(tt));
  }

  // Try to propagate shapes
  auto maybe_shape_compute_mapping =
      PropagateShapesAndBuildLargeShapeComputeGraph(
          tensorexpr_graph,
          *tensorexpr_graph->nodes().begin(),
          *tensorexpr_graph->nodes().end());
  if (!maybe_shape_compute_mapping) {
    return false;
  }

  // Insert guard
  insertDynamicShapesGuard(
      *maybe_shape_compute_mapping,
      tensorexpr_graph_node,
      add_composed_op,
      *input_striding,
      output_striding);
  return true;
}

static void inlineFallbackGraphAndAddSRCopyOutOp(std::shared_ptr<Graph> graph) {
  DepthFirstGraphNodeIterator it(graph);

  Node* n = nullptr;
  while ((n = it.next()) != nullptr) {
    if (n->kind() == prim::FallbackGraph) {
      break;
    }
  }
  TORCH_INTERNAL_ASSERT(n != nullptr, "Expected to find fallback graph");

  auto if_node = n->owningBlock()->owningNode();
  IfView if_v(if_node);
  SubgraphUtils::unmergeSubgraph(n);

  auto false_block = if_v.elseBlock();
  std::vector<Value*> false_block_outputs(
      if_v.elseOutputs().begin(), if_v.elseOutputs().end());
  TORCH_INTERNAL_ASSERT(!false_block_outputs.empty());

  for (auto out : false_block_outputs) {
    TORCH_INTERNAL_ASSERT(out->type()->cast<TensorType>());
  }
  auto copy_node = graph->create(
      prim::StaticRuntimeCopyOuts,
      false_block_outputs,
      false_block_outputs.size());
  false_block->appendNode(copy_node);
  for (size_t i = 0; i < false_block_outputs.size(); ++i) {
    false_block->replaceOutput(i, copy_node->outputs().at(i));
  }
}

// TODO: share more logic with tensorexpr_fuser?
void insertDynamicShapesGuard(
    const ShapeComputeGraphMapping& shape_mapping,
    Node* guarded_node,
    bool add_composed_op,
    std::vector<std::vector<StrideInput>>& input_info,
    std::vector<StrideInput>& output_strides) {
  GRAPH_DEBUG(
      "Inserting a prim::TensorExprDynamicGuard guard for a node",
      *guarded_node);
  auto subgraph = SubgraphUtils::getSubgraph(guarded_node);

  // Fixup types of the subgraph inputs
  std::vector<Value*> inputs_to_check;
  std::vector<TypePtr> guard_types;
  for (const auto i : c10::irange(guarded_node->inputs().size())) {
    Value* node_input = guarded_node->inputs().at(i);
    // We only check inputs of the guarded node
    if (!node_input->type()->cast<TensorType>()) {
      continue;
    }
    inputs_to_check.push_back(node_input);
    guard_types.emplace_back(
        subgraph->inputs().at(i)->type()->expect<TensorType>()->withStrides(
            c10::VaryingShape<c10::Stride>()));
  }
  TORCH_INTERNAL_ASSERT(inputs_to_check.size());

  // prim::TensorExprDynamicGuard nodes look like the following:
  //   %types_match : bool = prim::TypeCheck[attr:types](%inp1 : Tensor,
  //                                                     %inp2 : Tensor)
  // The input tensors are checked against the expected types on attr::types.
  // We omit refining the input Tensors for now because they are not actually
  // used within tensorexpr/kernel.cpp (only the inputs to the Graph are, not
  // the inputs to the node) and we would have to redo the mapping to compute
  // symbolic shapes.
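  // Illustrative sketch of the structure built below (value names are made
  // up; the guarded fusion group ends up in the true block and a fallback
  // copy of the subgraph in the false block):
  //   %ok : bool = prim::TensorExprDynamicGuard[types=[...]](%inp1, %inp2)
  //   %out : Tensor = prim::If(%ok)
  //     block0():
  //       %fused : Tensor = <guarded fusion group>(%inp1, %inp2, %SS_2, ...)
  //       -> (%fused)
  //     block1():
  //       %fallback : Tensor = prim::FallbackGraph(%inp1, %inp2)
  //       -> (%fallback)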

  Node* typecheck_node =
      guarded_node->owningGraph()
          ->create(Symbol::prim("TensorExprDynamicGuard"), inputs_to_check, 1)
          ->insertBefore(guarded_node);

  typecheck_node->tys_(attr::types, std::move(guard_types));
  Value* typecheck_result = typecheck_node->output()->setType(BoolType::get());

  // Insert if
  auto versioning_if =
      guarded_node->owningGraph()
          ->create(
              prim::If, {typecheck_result}, guarded_node->outputs().size())
          ->insertAfter(typecheck_node);

  for (size_t idx = 0; idx < guarded_node->outputs().size(); ++idx) {
    versioning_if->output(idx)->setType(guarded_node->output(idx)->type());
    guarded_node->output(idx)->replaceAllUsesWith(versioning_if->output(idx));
  }
  auto true_block = versioning_if->addBlock();
  auto false_block = versioning_if->addBlock();

  // Fill in the false block. It should contain the unoptimized
  // copy of the fused subgraph.
  WithInsertPoint guard(false_block->return_node());
  const auto subgraph_outputs = insertGraph(
      *guarded_node->owningGraph(), *subgraph, guarded_node->inputs());
  for (Value* output : subgraph_outputs) {
    false_block->registerOutput(output);
  }

  // Types get copied to the fallback graph, so remove specializations before
  // replacing.
  removeTensorTypeSpecializations(false_block);
  replaceBlockWithFallbackGraph(false_block, guarded_node->inputs());

  // Fill in the true block. It has all inputs type-checked and its
  // body should be the fusion group node.
  guarded_node->moveBefore(true_block->return_node());

  for (Value* output : guarded_node->outputs()) {
    true_block->registerOutput(output);
  }

  // Insert the symbolic shapes compute and add its results as inputs to the
  // TE Node/Graph. symbolic_shape_inputs will be a list of each symbolic
  // shape, and the last N inputs to the TE Graph/Node will be the N
  // symbolic shape values.
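  // For example (illustrative), if the map contains { -3 -> %v3, -2 -> %v2 },
  // then %v3 and %v2 are appended as trailing inputs to the guarded node, the
  // subgraph gains int inputs named "SS_3" and "SS_2", and
  // attr::symbolic_shape_inputs is set to [-3, -2].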
  auto map = InsertSymbolicShapesCompute(shape_mapping, guarded_node);
  std::vector<int64_t> symbolic_shape_inputs;
  for (const auto& pair : map) {
    symbolic_shape_inputs.push_back(pair.first);
    guarded_node->addInput(pair.second);
    std::stringstream ss;
    ss << "SS_" << -pair.first;
    subgraph->addInput(ss.str())->setType(IntType::get());
  }
  guarded_node->is_(
      attr::symbolic_shape_inputs, std::move(symbolic_shape_inputs));

  std::vector<std::vector<std::string>> input_striding;
  for (auto& vec : input_info) {
    auto string_info =
        fmap(vec, [&](StrideInput inp) { return toString(inp); });
    input_striding.push_back(string_info);
  }
  auto ival = IValue(input_striding);
  guarded_node->ival_(attr::striding_inputs_desc, ival);
  typecheck_node->ival_(attr::striding_inputs_desc, std::move(ival));

  for (Value* v : subgraph->inputs()) {
    if (auto t = v->type()->cast<TensorType>()) {
      v->setType(t->withStrides(c10::VaryingShape<c10::Stride>()));
    }
  }
  for (Value* v : subgraph->outputs()) {
    if (auto t = v->type()->cast<TensorType>()) {
      v->setType(t->withStrides(c10::VaryingShape<c10::Stride>()));
    }
  }

  std::vector<std::string> output_striding =
      fmap(output_strides, [&](StrideInput inp) { return toString(inp); });
  auto output_ival = IValue(output_striding);
  guarded_node->ival_(attr::striding_outputs_desc, std::move(output_ival));

  if (add_composed_op) {
    // Only in the SR flow do we check for values on the stack and
    // forward them along as tensor outputs.
    // TODO: refactor and make this an explicit part of the TE Kernel API
    guarded_node->i_(attr::allow_stack_outputs, 1);

    // Create a TensorExprDynamicGroup node
    auto te_dyn_group = SubgraphUtils::createSingletonSubgraph(
        typecheck_node, prim::TensorExprDynamicGroup);
    SubgraphUtils::mergeNodeIntoSubgraph(versioning_if, te_dyn_group);
    inlineFallbackGraphAndAddSRCopyOutOp(
        SubgraphUtils::getSubgraph(te_dyn_group));
  }
}

// This operator is inserted at the end of the fallback block computing outputs
// for the fusion group. We convert
//   block1():
//     %14 : Tensor = aten::mul(%0, %1)
//     %15 : Tensor = aten::mul(%0, %14)
//     -> (%15, %14)
//     return (%3, %4)
// to
//   block1():
//     %14 : Tensor = aten::mul(%0, %1)
//     %15 : Tensor = aten::mul(%0, %14)
//     %16 : Tensor, %17 : Tensor = prim::StaticRuntimeCopyOuts(%15, %14)
//     -> (%16, %17)
// Every output of the block is added as an input, and for each input there is
// a StaticRuntimeCopyOuts output. SR invokes the composed operator first with
// no tensors on the stack, in which case the Op will just return back the
// inputs. Second, it invokes it with pre-allocated tensors, one for each
// output of the fusion group, which is the same number of outputs as the
// fallback block. In this case we copy over the values of the inputs to the
// pre-allocated tensors.
// Note: this logic is meant to reflect the invocation of the TE Kernel
// and `runWithAllocatedOutputs` in tensorexpr_fuser.cpp.
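// A concrete illustration (assuming two block outputs): on the first run the
// stack holds only the op inputs, [%15, %14], and the op pushes them back
// unchanged. On later runs the stack holds the pre-allocated outputs beneath
// the inputs, [out0, out1, %15, %14]; the op pops the inputs, resizes each
// out_i to match and copies into it, and leaves [out0, out1] on the stack.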
static Operation StaticRuntimeCopyOuts(const Node* node) {
  auto num_ten_inputs = node->inputs().size();
  return [num_ten_inputs](Stack& stack) {
    std::vector<IValue> inputs = pop(stack, num_ten_inputs);
    // Uncommon case - first run
    if (stack.empty()) {
      for (IValue elem : inputs) {
        push(stack, std::move(elem));
      }
    } else {
      at::ArrayRef<IValue> outputs = last(stack, num_ten_inputs);
      for (size_t i = 0; i < inputs.size(); ++i) {
        IValue out = outputs[i];
        at::Tensor& out_t = out.toTensor();
        fastResizeToZero(out_t);
        out_t.resize_as_(inputs[i].toTensor());
        out_t.copy_(inputs[i].toTensor());
      }
    }
    return 0;
  };
}

RegisterOperators SRCopyOuts({
    torch::jit::Operator(
        prim::StaticRuntimeCopyOuts,
        StaticRuntimeCopyOuts,
        AliasAnalysisKind::CONSERVATIVE),
});

// On each invocation of this guard, we need to check all of the static
// information (dtype/device/requires grad/contiguity/static dims),
// and also that the symbolic shape dimensions are observed.
// For any symbolic dimension we need to set its value on its first
// use and, for all subsequent uses, check that the values are equal.
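// For example (illustrative), for inputs guarded as Tensor(SS(-2), 3) and
// Tensor(SS(-2)), the guard records dim 0 of the first runtime tensor as the
// value of SS(-2), checks that the first tensor's dim 1 equals 3, and then
// checks that dim 0 of the second tensor equals the recorded value.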
RegisterOperators reg_guard({
    Operator(
        "prim::TensorExprDynamicGuard(...) -> bool",
        [](const Node* node) -> Operation {
          const auto& types = node->tys(attr::types);

          // Each input's expected number of dims
          std::vector<size_t> expected_dims;

          // A flattened vector of all the expected values for all
          // tensor dims. A positive value corresponds to a static
          // shape to check and a negative value corresponds to a symbolic
          // dimension index to check.
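          // For example (illustrative), for guarded types
          // [Tensor(5, SS(-2)), Tensor(SS(-2), SS(-3))] this vector is
          // [5, -1, -1, -2]: the first distinct symbolic dim seen is assigned
          // index -1, the next one -2, and so on.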
          std::vector<int64_t> flattened_input_dims;

          // Each input's expected scalar type
          std::vector<c10::ScalarType> expected_scalar_types;

          // Map from symbolic dimension value to its set's index
          std::map<int64_t, size_t> sym_dim_flat_index;
          TORCH_INTERNAL_ASSERT(!types.empty());

          // We should only be fusing fusion groups with a single device
          // and with tensors not requiring grad.
          auto maybe_device = types[0]->expect<TensorType>()->device();
          TORCH_INTERNAL_ASSERT(maybe_device);
          auto device = *maybe_device;

          // Flattened vector of each input's striding behavior
          std::vector<StrideInput> flattened_input_striding;
          const IValue& sym_strides = node->ival(attr::striding_inputs_desc);
          std::vector<std::vector<std::string>> sym_strides_strs =
              sym_strides.to<std::vector<std::vector<std::string>>>();
          for (const auto& vec : sym_strides_strs) {
            for (const std::string& str : vec) {
              flattened_input_striding.push_back(strideInputFromString(str));
            }
          }

          for (const auto& type : types) {
            auto tt = type->expect<TensorType>();
            auto ss = tt->symbolic_sizes();
            TORCH_INTERNAL_ASSERT(ss.rank());
            expected_dims.push_back(*ss.rank());
            TORCH_INTERNAL_ASSERT(tt->scalarType());
            expected_scalar_types.push_back(*tt->scalarType());
            TORCH_INTERNAL_ASSERT(tt->device() && *tt->device() == device);
            for (size_t i = 0; i < *ss.rank(); ++i) {
              auto sym_dim = ss[i];
              auto value = sym_dim.value();
              if (value >= 0) {
                flattened_input_dims.push_back(value);
              } else {
                // Use the set's index if it exists, otherwise extend the
                // vector of sym shapes by 1.
                size_t sym_dim_index = 0;
                if (sym_dim_flat_index.count(value)) {
                  sym_dim_index = sym_dim_flat_index[value];
                } else {
                  auto size = sym_dim_flat_index.size();
                  sym_dim_flat_index[value] = (-1) - size;
                  sym_dim_index = sym_dim_flat_index[value];
                }
                // TODO: potential optimization - if there is a symbolic
                // dim with only one use we don't need to test anything
                flattened_input_dims.push_back(
                    static_cast<int64_t>(sym_dim_index));
              }
            }
          }

          const auto num_inputs = types.size();
          const auto num_symbolic_dims = sym_dim_flat_index.size();
          return [num_inputs,
                  expected_dims,
                  device,
                  expected_scalar_types,
                  flattened_input_dims,
                  flattened_input_striding,
                  num_symbolic_dims](Stack& stack) {
            at::ArrayRef<IValue> inputs = last(stack, num_inputs);
            drop(stack, num_inputs);
            // On each invocation we need to reset the value of each symbolic
            // symbol.
            // TODO: could this be a reference and not allocated on
            // each invocation, or would that interfere with multithreaded
            // inference since we are writing to it?
            // TODO: smallvector here?
            bool grad_mode_enabled = at::GradMode::is_enabled();
            std::vector<int64_t> flattened_symbolic_dims(num_symbolic_dims, -1);
            size_t flattened_dim_offset = 0;
            size_t flattened_stride_offset = 0;
            for (const auto i : c10::irange(num_inputs)) {
              at::Tensor tensor = inputs[i].toTensor();
              if (C10_UNLIKELY(
                      tensor.device() != device ||
                      tensor.dtype() != expected_scalar_types[i])) {
                push(stack, false);
                return;
              }
              if (C10_UNLIKELY(grad_mode_enabled && tensor.requires_grad())) {
                push(stack, false);
                return;
              }
              const auto& sizes = tensor.sizes();
              const auto num_dims = sizes.size();
              if (C10_UNLIKELY(num_dims != expected_dims[i])) {
                push(stack, false);
                return;
              }
              auto striding = flattened_input_striding[flattened_stride_offset];
              // Tensors natively store whether they are contiguous
              // in the default memory format or in channels last,
              // so it is more efficient to query for this property
              // than to iterate over the dimensions and check ourselves.
              if (striding == StrideInput::TENSOR_CONT) {
                if (C10_UNLIKELY(
                        !tensor.is_contiguous(at::MemoryFormat::Contiguous))) {
                  push(stack, false);
                  return;
                }
                flattened_stride_offset += 1;
              } else if (striding == StrideInput::TENSOR_CONT_CHANNELS_LAST) {
                // TODO: 5D channels last
                if (C10_UNLIKELY(!tensor.is_contiguous(
                        at::MemoryFormat::ChannelsLast))) {
                  push(stack, false);
                  return;
                }
                flattened_stride_offset += 1;
              } else {
                auto strides = tensor.strides();
                for (size_t dim = 0; dim < num_dims; ++dim) {
                  auto summarized_dim = summarizeStrideDim(
                      sizes,
                      strides,
                      dim,
                      flattened_input_striding,
                      flattened_stride_offset);
                  if (C10_UNLIKELY(
                          summarized_dim !=
                          flattened_input_striding
                              [dim + flattened_stride_offset])) {
                    push(stack, false);
                    return;
                  }
                }
                flattened_stride_offset += num_dims;
              }
              for (const auto dim_index : c10::irange(num_dims)) {
                const auto dim_value =
                    flattened_input_dims[dim_index + flattened_dim_offset];
                const int64_t tensor_dim = sizes[dim_index];
                if (dim_value >= 0) {
                  if (C10_UNLIKELY(dim_value != tensor_dim)) {
                    push(stack, false);
                    return;
                  }
                } else {
                  // Flattened sym indices start at -1,
                  // so -1 -> index 0, -2 -> index 1
                  const auto flattened_sym_index = (-dim_value) - 1;
                  const auto flattened_sym_value =
                      flattened_symbolic_dims[flattened_sym_index];
                  // Sym symbol already seen, check value
                  if (flattened_symbolic_dims[flattened_sym_index] >= 0) {
                    if (C10_UNLIKELY(flattened_sym_value != tensor_dim)) {
                      push(stack, false);
                      return;
                    }
                  } else {
                    // Not seen yet, record the value
                    flattened_symbolic_dims[flattened_sym_index] = tensor_dim;
                  }
                }
              }
              flattened_dim_offset += num_dims;
            }

            push(stack, true);
            return;
          };
        },
        aliasAnalysisFromSchema()),
});

void runTensorExprDynamicGroup(const Code& code, Stack& stack) {
  InterpreterState interpreter{code};
  interpreter.run(stack);
}

static Operation createTensorExprDynamicGroup(const Node* node) {
  const auto& graph = node->g(attr::Subgraph);
  Code code(graph, "");
  // This implementation creates a Code object and an InterpreterState on every
  // call to TensorExprDynamicGroup, which affects performance. Ideally, we
  // should be reusing Code and InterpreterState across calls to this op.
  // But that is resulting in a "No frames found" error.
  // TODO: Improve the performance of this by figuring out a better approach.
  // NB: this is only run in SR, which is single-threaded
  return [code](Stack& stack) {
    runTensorExprDynamicGroup(code, stack);
    return 0;
  };
}

RegisterOperators TensorExprDynamicOp({
    torch::jit::Operator(
        prim::TensorExprDynamicGroup,
        createTensorExprDynamicGroup,
        AliasAnalysisKind::INTERNAL_SPECIAL_CASE),
});

} // namespace torch::jit