#include <torch/csrc/jit/tensorexpr/kernel.h>

#include <ATen/ExpandUtils.h>
#include <ATen/Parallel.h>
#include <ATen/TensorGeometry.h>
#include <c10/core/ScalarTypeToTypeMeta.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/mkldnn_rewrite.h>
#include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/graph_opt.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/loopnest_randomization.h>
#include <torch/csrc/jit/tensorexpr/operators/operators.h>

#include <utility>

using namespace torch::jit;
using namespace torch::jit::tensorexpr;

namespace torch::jit::tensorexpr {

std::string buildErrorMessage(const std::string& s) {
  static const std::string generic_error_message =
      "This error occurred in the fuser. You can turn off the fuser with "
      "torch.jit.enable_fusion(False).";
  if (s.empty()) {
    return generic_error_message;
  }
  if (s.back() == '.') {
    return s + " " + generic_error_message;
  }
  return s + ". " + generic_error_message;
}
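
// Example (illustrative): buildErrorMessage("Unsupported reduction") yields
// "Unsupported reduction. This error occurred in the fuser. You can turn off
// the fuser with torch.jit.enable_fusion(False)."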

static int te_cuda_pointwise_loop_levels = -1;
static int te_cuda_pointwise_block_count = -1;
static int te_cuda_pointwise_block_size = -1;
static bool fallback_allowed = false;
static bool te_generate_block_code = false;
static bool te_must_use_llvm_on_cpu = true;
static bool cat_wo_conditionals = true;
static bool opt_conditionals = false;

bool setFallbackAllowed(bool value) {
  bool old_value = fallback_allowed;
  fallback_allowed = value;
  return old_value;
}

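// PYTORCH_TENSOREXPR_FALLBACK controls fallback to the non-fused path:
// unset -> follow the fallback_allowed flag, "0" -> disallow fallback,
// "2" -> force fallback (see fallbackEnforced), any other value -> allow it.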
bool fallbackAllowed() {
  static const char* enable_c_str = std::getenv("PYTORCH_TENSOREXPR_FALLBACK");
  if (!enable_c_str) {
    return fallback_allowed;
  }
  if (std::string(enable_c_str) == "0") {
    return false;
  }
  return true;
}

static bool fallbackEnforced() {
  static const char* enable_c_str = std::getenv("PYTORCH_TENSOREXPR_FALLBACK");
  if (tensorexpr::getTEGenerateBlockCode()) {
    return false;
  }
  if (!enable_c_str) {
    return fallback_allowed;
  }
  if (std::string(enable_c_str) == "2") {
    return true;
  }
  return false;
}

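// Reads PYTORCH_TENSOREXPR_RANDOM_TRANSFORM_SEED. Returns 0 when the variable
// is unset (loop-nest randomization disabled); a value of -1 is later replaced
// with a time-based seed in transformLoops.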
static int64_t randomTransformsRequested() {
  const char* enable_c_str =
      std::getenv("PYTORCH_TENSOREXPR_RANDOM_TRANSFORM_SEED");
  if (!enable_c_str) {
    return 0;
  }
  return std::stoi(std::string(enable_c_str));
}

#ifdef TORCH_ENABLE_LLVM
static bool dontUseLLVMFlag() {
  static const char* enable_c_str =
      std::getenv("PYTORCH_TENSOREXPR_DONT_USE_LLVM");
  if (!enable_c_str) {
    return false;
  }
  return std::string(enable_c_str) == "1";
}
#endif

int& getTECudaPointwiseLoopLevels() {
  return te_cuda_pointwise_loop_levels;
}

int& getTECudaPointwiseBlockCount() {
  return te_cuda_pointwise_block_count;
}

int& getTECudaPointwiseBlockSize() {
  return te_cuda_pointwise_block_size;
}

// TODO: Remove this global var. Ideally, Block codegen should be decided
// based on the device type of the tensor.
bool& getTEGenerateBlockCode() {
  return te_generate_block_code;
}

bool& getTEMustUseLLVMOnCPU() {
  return te_must_use_llvm_on_cpu;
}

bool& getCatWoConditionals() {
  return cat_wo_conditionals;
}

bool& getOptConditionals() {
  return opt_conditionals;
}
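
// Example (illustrative): CUDA pointwise codegen can be tuned through these
// accessors, e.g. getTECudaPointwiseBlockSize() = 256; a value of -1 keeps
// the defaults chosen in transformLoops.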

std::optional<at::Device> pickDeviceType(
    const at::ArrayRef<torch::jit::Value*>& inputs) {
  std::optional<at::Device> device = std::nullopt;
  for (auto const& input : inputs) {
    auto tt = input->type()->cast<TensorType>();
    if (tt && tt->device()) {
      if (device && *device != *tt->device()) {
        return std::nullopt;
      }
      device = *tt->device();
    }
  }
  return device;
}

static std::optional<at::Device> pickDeviceType(
    const std::shared_ptr<Graph>& graph) {
  std::optional<at::Device> device = std::nullopt;
  for (auto const& node : graph->nodes()) {
    for (auto const& input : node->inputs()) {
      if (auto tt = input->type()->cast<TensorType>()) {
        if (auto inputDevice = tt->device()) {
          TORCH_INTERNAL_ASSERT(
              !device || *device == *inputDevice,
              buildErrorMessage(
                  "Different devices specified for inputs to the fuser."));
          device = inputDevice;
        }
      }
    }
  }
  for (auto const& input : graph->inputs()) {
    if (auto tt = input->type()->cast<TensorType>()) {
      if (auto inputDevice = tt->device()) {
        TORCH_INTERNAL_ASSERT(
            !device || *device == *inputDevice,
            buildErrorMessage(
                "Different devices specified for inputs to the fuser."));
        device = inputDevice;
      }
    }
  }
  if (!device) {
    // By default assume the device is CPU
    device = at::kCPU;
  }
  return device;
}

// If v is a Tensor with concretely-known sizes and dtype, return them, else
// nullopt.
static std::optional<TensorInfo> getTensorInfoJit(torch::jit::Value* v) {
  auto const& it = v->type()->cast<TensorType>();

  c10::ScalarType dtype = c10::ScalarType::Float;

  if (!it) {
    return std::nullopt;
  }
  if (!it->isComplete()) {
    return std::nullopt;
  }
  if (it->scalarType()) {
    // TODO: ideally we should be strict here and return nullopt if the dtype
    // is absent in the JIT IR. We're assuming a default Float dtype for now,
    // until dtype propagation is implemented.
    dtype = *it->scalarType();
  }
  auto concrete_sizes = it->sizes().concrete_sizes();
  if (!concrete_sizes) {
    return std::nullopt;
  }
  return TensorInfo{*concrete_sizes, dtype};
}
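
// Expands a scalar int argument into a {v, v} pair; passes int lists through
// unchanged.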
static std::vector<int64_t> _pair_int(const IValue& v) {
  if (v.isIntList()) {
    return v.toIntVector();
  } else {
    return {v.toInt(), v.toInt()};
  }
}

bool isContiguous(const torch::jit::Value* v, at::MemoryFormat memory_format) {
  auto const& tt = v->type()->cast<TensorType>();
  if (!tt) {
    return false;
  }
  if (!tt->isComplete()) {
    return false;
  }
  auto const& sizes = tt->sizes().concrete_sizes();
  auto const& strides = tt->strides().concrete_sizes();
  if (!sizes || !strides) {
    return false;
  }

  // Check dimension size first
  int ndims = (*sizes).size();
  if ((memory_format == at::MemoryFormat::ChannelsLast && ndims != 4) ||
      (memory_format == at::MemoryFormat::ChannelsLast3d && ndims != 5)) {
    return false;
  }

  return *strides == TensorType::contiguousStridesOf(*sizes, memory_format);
}

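// Returns the input index of the `groups` argument for the given conv node.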
static size_t get_conv_groups_index(const torch::jit::Node* node) {
  switch (node->kind()) {
    case aten::conv2d:
      return 6;
    case aten::_convolution:
      return 8;
    default:
      TORCH_CHECK(
          false,
          "mkldnnPrepackedConvIsSupportedJit expects node kind to be conv2d or _convolution but got ",
          node->kind());
  }
}

// The fuser only supports conv2d with very specific properties:
// - Static shapes: 4-d input and filter, 1-d bias.
// - Constant strides/padding/dilation/groups
// - Equal padding and strides, dilation == 1.
// - Depthwise (groups == in_channels == out_channels)
// - 3x3 kernel
bool conv2dIsSupportedJit(const torch::jit::Node* node) {
  auto const& input = getTensorInfoJit(node->input(0));
  auto const& weight = getTensorInfoJit(node->input(1));
  auto const& bias = getTensorInfoJit(node->input(2));
  auto const& stride = toIValue(node->input(3));
  auto const& pad = toIValue(node->input(4));
  auto const& dilation = toIValue(node->input(5));
  size_t groups_index = get_conv_groups_index(node);
  auto const& groups = toIValue(node->input(groups_index));

  // Everything should be statically known.
  if (!input || !weight || !bias || !stride || !pad || !dilation || !groups) {
    GRAPH_DEBUG("some params aren't static");
    return false;
  }

  // All inputs should be contiguous so no transposition is required.
  if (!isContiguous(node->input(0)) || !isContiguous(node->input(1)) ||
      !isContiguous(node->input(2))) {
    GRAPH_DEBUG("conv2dIsSupported: some inputs are not contiguous");
    return false;
  }

  return conv2dIsSupported(
      *input,
      *weight,
      *bias,
      _pair_int(*stride),
      _pair_int(*pad),
      _pair_int(*dilation),
      groups->toInt());
}

bool mkldnnPrepackedConvIsSupportedJit(const torch::jit::Node* node) {
#if AT_MKLDNN_ENABLED()
  auto const& input = getTensorInfoJit(node->input(0));
  auto const& weight = getTensorInfoJit(node->input(1));
  auto const& stride = toIValue(node->input(3));
  auto const& pad = toIValue(node->input(4));
  auto const& dilation = toIValue(node->input(5));
  size_t groups_index = get_conv_groups_index(node);
  auto const& groups = toIValue(node->input(groups_index));

  // Everything should be statically known (bias could be NoneType =
  // prim::Constant()).
  if (!input || !weight || !stride || !pad || !dilation || !groups) {
    GRAPH_DEBUG("some params aren't static");
    return false;
  }

  // Weights and bias should be Constant when using mkldnn backend
  if (node->input(1)->node()->kind() != prim::Constant ||
      node->input(2)->node()->kind() != prim::Constant) {
    GRAPH_DEBUG(
        "mkldnnPrepackedConvIsSupported: weight or bias is not Constant");
    return false;
  }

  // Input and weight should be NHWC contiguous.
  if (!(isContiguous(node->input(0), at::MemoryFormat::ChannelsLast) &&
        isContiguous(node->input(1), at::MemoryFormat::ChannelsLast))) {
    GRAPH_DEBUG(
        "mkldnnPrepackedConvIsSupported: input or weight is not ChannelsLast contiguous");
    return false;
  }

  return mkldnnPrepackedConvIsSupported(
      *input,
      *weight,
      _pair_int(*stride),
      _pair_int(*pad),
      _pair_int(*dilation),
      groups->toInt());
#endif
  return false;
}

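// Returns true if the node is a non-transposed 2-d aten::_convolution with
// statically known stride, padding, dilation, and output_padding.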
bool isConv2d(const Node* node) {
  if (node->kind() != aten::_convolution) {
    return false;
  }

  auto const& stride = toIValue(node->input(3));
  auto const& pad = toIValue(node->input(4));
  auto const& dilation = toIValue(node->input(5));
  auto const& transposed = toIValue(node->input(6));
  auto const& output_padding = toIValue(node->input(7));

  if (!stride || !pad || !dilation || !transposed || !output_padding) {
    GRAPH_DEBUG("some params aren't static");
    return false;
  }

  if (stride.value().toIntList().size() != 2 ||
      pad.value().toIntList().size() != 2 ||
      dilation.value().toIntList().size() != 2 ||
      output_padding.value().toIntList().size() != 2) {
    GRAPH_DEBUG("Conv not 2d");
    return false;
  }

  if (transposed.value().toBool()) {
    GRAPH_DEBUG("transposed Conv");
    return false;
  }
  return true;
}

// The fuser currently only supports matmul of 2D x 2D matrices
bool matmulIsSupported(const torch::jit::Node* node) {
  auto const& input0 = getTensorInfoJit(node->input(0));
  auto const& input1 = getTensorInfoJit(node->input(1));

  // Everything should be statically known.
  if (!input0 || !input1) {
    GRAPH_DEBUG("matmulIsSupported: Input shapes aren't static");
    return false;
  }

  // Proper ndim for tensor inputs.
  if (input0->dims.size() != 2 || input1->dims.size() != 2) {
    GRAPH_DEBUG("matmulIsSupported: Unsupported input sizes");
    return false;
  }

  // Inputs should be contiguous, or the TE will needlessly transpose them.
  if (!isContiguous(node->input(0)) || !isContiguous(node->input(1))) {
    GRAPH_DEBUG("matmulIsSupported: Input shapes are not contiguous");
    return false;
  }

  return true;
}

} // namespace torch::jit::tensorexpr

static at::ScalarType tensorType(const BufPtr& b) {
  return static_cast<at::ScalarType>(b->dtype().scalar_type());
}

ExprHandle TensorExprKernel::constant(const torch::jit::Value* v) {
  if (v->node()->kind() == prim::Constant) {
    auto val = toIValue(v).value();
    if (val.isDouble()) {
      return DoubleImm::make(val.toDouble());
    } else if (val.isInt()) {
      return LongImm::make(val.toInt());
    } else if (val.isBool()) {
      return BoolImm::make(val.toBool());
    } else if (val.isNone()) {
      // This is just a placeholder so we don't throw. None-handling
      // is operator-specific and should be handled properly in
      // the operator-specific lowering code.
      return IntImm::make(0);
    } else {
      throw unsupported_dtype();
    }
  }

  if (!scalars_.count(v)) {
    throw malformed_input("no scalar in Constant");
  }

  return scalars_.at(v);
}

ArgValue TensorExprKernel::toArg(const torch::jit::Value* v) const {
  auto vi = scalars_.find(v);
  if (vi != scalars_.end()) {
    return VarHandle(vi->second);
  }
  auto ti = bufs_.find(v);
  if (ti != bufs_.end()) {
    return BufHandle(ti->second);
  }
  if (v->node()->kind() == prim::ListConstruct) {
    std::vector<ArgValue> vec;
    for (auto el : v->node()->inputs()) {
      vec.push_back(toArg(el));
    }
    if (vec.empty()) {
      return BufList(); // Return arbitrarily typed vector
    } else if (std::get_if<BufHandle>(&vec[0])) {
      return convertVecArgValue<BufHandle>(vec);
    } else if (std::get_if<int64_t>(&vec[0])) {
      return convertVecArgValue<int64_t>(vec);
    }
    throw unsupported_dtype();
  }
  if (v->node()->kind() == prim::Constant) {
    auto val = toIValue(v).value();
    if (val.isDouble()) {
      return val.toDouble();
    } else if (val.isInt()) {
      return val.toInt();
    } else if (val.isBool()) {
      return val.toBool();
    } else if (val.isNone()) {
      // This is just a placeholder so we don't throw. None-handling
      // is operator-specific and should be handled properly in
      // the operator-specific lowering code.
      return ArgNone();
    } else if (val.isIntList()) {
      return val.toIntVector();
    } else if (val.isDoubleList()) {
      return val.toDoubleVector();
    } else if (val.isString()) {
      return val.toStringRef();
    } else {
      throw unsupported_dtype(val.type()->str());
    }
  }

  if (!scalars_.count(v)) {
    throw malformed_input("no scalar in Constant");
  }
  return scalars_.at(v);
}

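// Returns an ExprHandle for a shape symbol: a LongImm for static dimensions,
// or a kLong variable (cached per symbol) for symbolic dimensions.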
ExprHandle TensorExprKernel::getVarForShape(const c10::ShapeSymbol& ss) {
  if (ss.is_static()) {
    return LongImm::make(ss.static_size());
  }
  auto value = ss.value();
  auto it = shapeSymbolToVar_.find(value);
  if (it == shapeSymbolToVar_.end()) {
    VarHandle var("ss" + std::to_string(-value), kLong);
    shapeSymbolToVar_.emplace(value, var);
    return std::move(var);
  }
  return it->second;
}

std::vector<ExprHandle> TensorExprKernel::sizesFromSymbolicShape(
    const c10::SymbolicShape& shape) {
  std::vector<ExprHandle> dims;
  auto maybe_rank = shape.rank();
  TORCH_INTERNAL_ASSERT(maybe_rank);
  auto rank = *maybe_rank;
  for (const auto i : c10::irange(rank)) {
    dims.push_back(getVarForShape(shape[i]));
  }
  return dims;
}

std::vector<ExprHandle> TensorExprKernel::sizesForValue(
    const torch::jit::Value* v) {
  if (known_sizes_.count(v)) {
    return known_sizes_.at(v);
  }

  // If the shape is present in the type info, just extract it from here. No
  // need to infer it.
  if (v->type()->kind() == TypeKind::TensorType) {
    auto tt = v->type()->cast<TensorType>();
    return sizesFromSymbolicShape(tt->symbolic_sizes());
  }

  if (v->type()->isSubtypeOf(*FloatType::get()) ||
      v->type()->isSubtypeOf(*BoolType::get()) ||
      v->type()->isSubtypeOf(*IntType::get())) {
    return {};
  }
  if (v->type()->isSubtypeOf(*NoneType::get())) {
    return {};
  }
  GRAPH_DEBUG("Unknown sizes for the node: ", *v->node());
  GRAPH_DEBUG("Full fusion group graph:\n", *v->node()->owningGraph());
  std::string msg = std::string("Unhandled node kind (in sizesForValue): ") +
      v->node()->kind().toQualString();
  throw malformed_input(msg);
}

static std::optional<ScalarType> findDtypeForValue(const torch::jit::Value* v) {
  if (v->type()->kind() == TypeKind::TensorType) {
    auto tt = v->type()->cast<TensorType>();
    if (tt->scalarType()) {
      return static_cast<ScalarType>(*tt->scalarType());
    }
  }
  return tryScalarTypeFromJitType(*v->type());
}

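// If `v` is a constant 0-dim Float or Long tensor, appends its scalar value to
// `args` and returns true. Returns false if `v` is not such a constant; throws
// for other dtypes.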
static bool constZeroDimTensorAsScalarArg(
    const Value* v,
    std::vector<ArgValue>& args) {
  if (v->node()->kind() != prim::Constant || !v->type()->cast<TensorType>()) {
    return false;
  }

  const auto t = toIValue(v)->toTensor();
  if (!t.sizes().empty()) {
    return false;
  }

  c10::ScalarType dtype = c10::typeMetaToScalarType(t.dtype());
  switch (dtype) {
    case ScalarType::Float:
      args.emplace_back(t.item().toFloat());
      return true;
    case ScalarType::Long:
      args.emplace_back(t.item().toLong());
      return true;
    default:
      std::stringstream ss;
      ss << "Unsupported tensor dtype:" << dtype
         << " for converting constant 0-dim Tensor to scalar" << '\n';
      throw unsupported_dtype(ss.str());
  }
}

Tensor TensorExprKernel::computeValue(const torch::jit::Value* v) {
  auto inputs = v->node()->inputs();
  auto op = v->node()->kind();

  if (op == aten::rand_like) {
    hasRandom_ = true;
  }

  auto outputType = findDtypeForValue(v);
  std::vector<ExprHandle> outputShape = sizesForValue(v);
  std::vector<ExprHandle> outputStrides = {};
  if (memory_layout_policy_ == MemoryLayoutPolicy::kChannelsLastNdContiguous) {
    outputStrides =
        c10::fmap<ExprHandle>(make_channels_last_strides(outputShape));
  } else {
    // Default
    outputStrides = c10::fmap<ExprHandle>(make_contiguous_strides(outputShape));
  }

  std::vector<ArgValue> argInputs;
  if (op == prim::ConstantChunk) {
    auto const& n = v->node();
    argInputs.emplace_back(toArg(inputs[0]));
    argInputs.emplace_back(static_cast<int64_t>(v->offset()));
    argInputs.emplace_back(n->i(attr::dim));
    argInputs.emplace_back(n->i(attr::chunks));
  } else if (op == aten::to) {
    argInputs.emplace_back(toArg(inputs[0]));
  } else if (op == aten::quantize_per_tensor) {
    argInputs.emplace_back(toArg(inputs[0]));
    if (!constZeroDimTensorAsScalarArg(inputs[1], argInputs)) {
      argInputs.emplace_back(toArg(inputs[1]));
    }
    if (!constZeroDimTensorAsScalarArg(inputs[2], argInputs)) {
      argInputs.emplace_back(toArg(inputs[2]));
    }
    argInputs.emplace_back(toArg(inputs[3]));
  } else if (op == aten::conv2d) {
    for (auto inp : inputs) {
      argInputs.emplace_back(toArg(inp));
    }
    // handle optional bias
    if (std::get_if<ArgNone>(&argInputs[2])) {
      Dtype dtype = outputType ? Dtype(*outputType) : kFloat;
      std::vector<ExprHandle> biasShape;
      biasShape.push_back(outputShape[1]);
      auto bias_tensor = at::zeros({outputShape[1].AsNode<LongImm>()->value()});
      unpacked_constant_tensors_.push_back(bias_tensor);
      BufPtr buf = alloc<Buf>(
          "conv2d_bias_opt_" + sanitizeName(v->debugName()),
          ExprHandleVectorToExprVector(biasShape),
          dtype);
      constants_.push_back({buf, bias_tensor.data_ptr()});
      argInputs[2] = BufHandle(buf);
    }
  } else {
    for (auto inp : inputs) {
      argInputs.emplace_back(toArg(inp));
    }
  }

  if (NNCLoweringFunction custom_lowering = getCustomLoweringFor(op)) {
    return custom_lowering(
        argInputs, outputShape, outputStrides, outputType, device_);
  }
  if (v->node()->maybeSchema()) {
    if (NNCLoweringFunction lowering =
            getStandardLoweringFor(c10::toString(v->node()->schema()))) {
      return lowering(
          argInputs, outputShape, outputStrides, outputType, device_);
    }
  }
  std::string msg = std::string("Unhandled node kind (in computeValue): ") +
      op.toQualString();
  if (v->node()->maybeSchema()) {
    msg += std::string("\nSchema: ") + c10::toString(v->node()->schema());
  }
  throw malformed_input(msg);
}

// True if all the loops in this vector have equal bounds.
static bool loopBoundsAllEqual(const std::vector<ForPtr>& loops) {
  if (loops.size() <= 1) {
    return true;
  }
  const auto& start = loops.front()->start();
  const auto& stop = loops.front()->stop();
  for (size_t i = 1; i < loops.size(); ++i) {
    const auto& curr_start = loops[i]->start();
    const auto& curr_stop = loops[i]->stop();
    if (!exprEquals(start, curr_start) || !exprEquals(stop, curr_stop)) {
      return false;
    }
  }
  return true;
}

// Recursively fuse all the loops with matching bounds in `st`. Stops fusing
// at any level containing non-loops or non-matching bounds. The restriction
// on matching bounds exists to avoid inserting conditionals on the loop
// indices where none would be needed, which would significantly complicate
// vectorization.
static void fuseAllLoops(const StmtPtr& st) {
  auto block = to<tensorexpr::Block>(st);
  if (block == nullptr) {
    return;
  }

  std::vector<std::vector<ForPtr>> all_outer_loops;
  std::vector<ForPtr> outer_loops;
  for (const auto& stmt : *block) {
    auto loop = to<For>(stmt);
    auto hasReduction = !NodeFinder<ReduceOp>::find(stmt).empty();
    if (!loop || hasReduction) {
      all_outer_loops.push_back(outer_loops);
      outer_loops.clear();
    } else {
      outer_loops.push_back(loop);
    }
  }
  all_outer_loops.push_back(outer_loops);

  for (const auto& outer_loops : all_outer_loops) {
    if (outer_loops.empty()) {
      continue;
    }

    if (!loopBoundsAllEqual(outer_loops)) {
      continue;
    }

    ForPtr fusedLoop;
    if (!LoopNest::fuseLoops(outer_loops, &fusedLoop)) {
      continue;
    }

    fuseAllLoops(fusedLoop->body());
  }
}

// Compute the trip count of a loop if it is a constant.
static std::optional<int64_t> tripCount(const ForPtr& loop) {
  auto tc = IRSimplifier::simplify(
      cast<int64_t>(ExprHandle(loop->stop()) - ExprHandle(loop->start())));
  if (auto val = to<LongImm>(tc.node())) {
    return val->value();
  }
  return std::nullopt;
}

// Prune innermost loops until the total iteration count satisfies a minimum
// grain size.
static void pruneByGrainSize(std::vector<ForPtr>& loops) {
  constexpr int64_t minGrainSize = 32768;
  int64_t grainSize = 1;
  for (int64_t i = loops.size(); i > 0; i--) {
    auto tc = tripCount(loops[i - 1]);
    if (!tc) {
      break;
    }
    grainSize *= *tc;
    if (grainSize < minGrainSize) {
      loops.pop_back();
    }
  }
}

// Retain enough outermost loops to fill the number of threads.
static void pruneByThreadCount(std::vector<ForPtr>& loops) {
  int64_t trips = 1;
  auto threads = at::get_num_threads();
  auto it = loops.begin();
  for (; it != loops.end(); it++) {
    if (trips >= threads) {
      break;
    }
    auto tc = tripCount(*it);
    if (!tc) {
      break;
    }
    trips *= *tc;
  }
  loops.erase(it, loops.end());
}

// Flatten and parallelize outer loops, subject to a minimum number of elements
// in the inner loop, and a maximum level of thread-level parallelism in the
// outer loops.
template <typename Bufs>
static void parallelizeOuterLoops(LoopNest& l, Bufs&& bufs) {
  for (auto const& buf : bufs) {
    auto loops = l.getLoopStmtsFor(buf);
    pruneByGrainSize(loops);
    pruneByThreadCount(loops);

    // There are no loops to parallelize; give up.
    if (loops.size() == 0) {
      continue;
    }
    // The loop nest contains a reduction; give up.
    auto reductions = NodeFinder<ReduceOp>::find(loops[0]);
    if (reductions.size() > 0) {
      continue;
    }
    // The loop nest has loop carried dependences; give up.
    if (LoopNest::hasLoopCarriedDependence(loops[0])) {
      continue;
    }
    // Try to flatten the outer loops and parallelize them if successful.
    ForPtr flattened = nullptr;
    if (loops.size() == 1) {
      flattened = loops[0];
    } else {
      LoopNest::flatten(loops, &flattened);
    }
    if (flattened) {
      flattened->set_parallel();
    }
  }
}

StmtPtr TensorExprKernel::transformLoops(BackendType backendType, StmtPtr st) {
  torch::jit::tensorexpr::LoopNest l(std::move(st), bufOutputs_);
  LoopNest::sanitizeNames(l.root_stmt());
  GRAPH_DEBUG("Original Stmt:\n", std::to_string(l.root_stmt()), "\n");
  int64_t random_tr_seed = randomTransformsRequested();
  if (random_tr_seed) {
    if (random_tr_seed == -1)
      random_tr_seed = std::time(nullptr);
    loopnestRandomization(random_tr_seed, l);
    GRAPH_DEBUG(
        "After random transform:\n", std::to_string(l.root_stmt()), "\n");
  }

  bool hasReduction = !NodeFinder<ReduceOp>::find(l.root_stmt()).empty();

  // For Block codegen we create a map of tensor dims before inlining. Like GPU
  // codegen we need to inline, but the order in which this analysis is run
  // matters.
  auto block_analysis = std::make_unique<CreateBufferMap>();
  if (backendType == kBlockCodeGen) {
    // Run Block analysis to get multi dim buffer info
    auto root_stmt = l.root_stmt();
    root_stmt->accept(block_analysis.get());
  }
  l.simplify();
  GRAPH_DEBUG("after simplify", *l.root_stmt());

  // Inlining output & intermediate buffers can duplicate computation.
  // Duplicating work can slow down the program if it's not ameliorated in some
  // way, but we've empirically found that:
  // - On CPU, LLVM's CSE does a good job as long as you horizontally fuse
  //   output loops.
  // - On GPU, there's enough compute to hide the extra work, and inlining
  //   avoids synchronizing between kernels.
  l.inlineIntermediateBufs(/*allow_duplicated_work=*/true);
  GRAPH_DEBUG("after inline", *l.root_stmt());

  // Optimizing conditionals needs to be performed after inlining because
  // inlining wouldn't work once the loops are split. Also, it has to be
  // performed before loop fusion because loop fusion introduces cases where
  // multiple conditionals are in the same loop and this optimization does not
  // handle such cases yet.
  if (getOptConditionals()) {
    l.optimizeConditionals();
    GRAPH_DEBUG("after optimizing conditionals: ", *l.root_stmt());
  }

  // Fuse loops "horizontally". This pass allows us to combine loops that
  // write to different output buffers, as long as they have the same bounds.
  if (backendType == kLLVMCodeGen) {
    fuseAllLoops(l.root_stmt());
    GRAPH_DEBUG("after fuse", *l.root_stmt());
    parallelizeOuterLoops(l, bufsToBeParallelized_);
    GRAPH_DEBUG("after parallelize", *l.root_stmt());
  }

  if (backendType == kCudaCodeGen) {
    for (const auto& buf : bufOutputs_) {
      std::vector<ForPtr> loops = l.getLoopStmtsFor(buf);
      if (loops.empty()) {
        // This happens when Buf is 0-dim
        continue;
      }
      ForPtr flattened = nullptr;
      LoopNest::flatten(loops, &flattened);
      assert(flattened);

      int loopLevels = getTECudaPointwiseLoopLevels();
      const int kDefaultLoopLevels = 2;
      loopLevels = (loopLevels > 0) ? loopLevels : kDefaultLoopLevels;
      int blockCount = getTECudaPointwiseBlockCount();
      int blockSize = getTECudaPointwiseBlockSize();

      if (loopLevels == 2) {
        ForPtr inner;
        const int kDefaultBlockSize = 512;
        if (blockSize < 0) {
          blockSize = kDefaultBlockSize;
        }
        LoopNest::splitWithMask(flattened, blockSize, &inner);
        flattened->set_gpu_block_index(0);
        inner->set_gpu_thread_index(0);
      } else if (loopLevels == 3) {
        ForPtr inner;
        ForPtr inner1;
        // TODO: change the number of microprocessors
        const int kDefaultBlockCount = 1280;
        const int kDefaultBlockSize = 256;
        blockCount = (blockCount > 0) ? blockCount : kDefaultBlockCount;
        blockSize = (blockSize > 0) ? blockSize : kDefaultBlockSize;
        LoopNest::splitWithMask(flattened, blockCount * blockSize, &inner);
        LoopNest::splitWithMask(inner, blockSize, &inner1);
        inner->set_gpu_block_index(0);
        inner1->set_gpu_thread_index(0);
      } else {
        throw std::runtime_error(
            "Invalid loop-level: " + std::to_string(loopLevels));
      }
    }
  }

  if (backendType == kBlockCodeGen) {
    for (const auto& buf : bufOutputs_) {
      const int default_fp16_blocksize = 16;
      const int default_uint8_blocksize = 32;
      int blockSize = default_fp16_blocksize;
      // We only handle looplevels == 2 for now
      if (buf->dtype().scalar_type() == ScalarType::Byte) {
        blockSize = default_uint8_blocksize;
      }
      std::vector<ForPtr> loops = l.getLoopStmtsFor(buf);
      TORCH_INTERNAL_ASSERT(
          !loops.empty(),
          buildErrorMessage(
              "No loops found for the buffer " + buf->name_hint() +
              " in the fuser."));
      ForPtr flattened = nullptr;
      LoopNest::flatten(loops, &flattened);
      assert(flattened);

      ForPtr inner = nullptr;
      LoopNest::splitWithMask(flattened, blockSize, &inner);
      flattened->set_gpu_block_index(0);
      inner->set_gpu_thread_index(0);
      flattened->set_buffer_map(block_analysis->getBufferMap());
    }
  }

  if (pre_alloc_) {
    auto interm_bufs = l.getIntermediateBufs();
    preAllocIntermediateBufs(interm_bufs);
  }

  l.prepareForCodegen();

  GRAPH_DEBUG("after prepareForCodegen", *l.root_stmt());
  l.simplify();
  GRAPH_DEBUG("after simplification", *l.root_stmt());

  if (backendType == kLLVMCodeGen && !hasReduction) {
    l.vectorizeInnerLoops();
    GRAPH_DEBUG("after vectorization", *l.root_stmt());
  }

  StmtPtr stmt = l.root_stmt();
  // Arithmetic Simplification.
  stmt = IRSimplifier::simplify(stmt);
  GRAPH_DEBUG("Final Stmt:\n", std::to_string(stmt), "\n");
  return stmt;
}

std::string TensorExprKernel::getCodeGenName(BackendType backendType) {
  switch (backendType) {
    case kCudaCodeGen:
      return "cuda_codegen";
    case kLLVMCodeGen:
      return "llvm_codegen";
    case kSimpleIREval:
      return "simple_ir_eval";
    case kBlockCodeGen:
      return "block_codegen";
    default:
      throw std::runtime_error(
          "invalid backend type: " +
          std::to_string(static_cast<int>(backendType)));
  }
}

template <typename T>
static bool isValidPrimProperty(const std::optional<T>& a, T b) {
  return !a.has_value() || *a == b;
}

TensorExprKernel::BackendType TensorExprKernel::inferBackendTypeFromDevice(
    at::Device device) {
  BackendType backendType = BackendType::kUninitialized;
  if (device.type() == at::kCUDA) {
    backendType = kCudaCodeGen;
  } else if (device.type() == at::kCPU && getTEGenerateBlockCode()) {
    backendType = kBlockCodeGen;
  } else if (device.type() == at::kCPU) {
#ifdef TORCH_ENABLE_LLVM
    backendType = dontUseLLVMFlag() ? kSimpleIREval : kLLVMCodeGen;
#else
    backendType = kSimpleIREval;
#endif
    if (getTEMustUseLLVMOnCPU() && backendType == kSimpleIREval) {
      throw std::runtime_error("LLVM Backend not found");
    }
  } else {
    throw std::runtime_error("Invalid device type");
  }
  return backendType;
}

// We use the debug names when printing CUDA code; they need to be stripped of
// characters that can't be used in a variable identifier.
void TensorExprKernel::genInputDebugNames() {
  std::unordered_map<std::string, const torch::jit::Value*> name_to_value;
  std::unordered_set<std::string> name_set;
  std::unordered_map<const torch::jit::Value*, std::string> value_to_name;
  for (const torch::jit::Value* input : graph_->inputs()) {
    std::string sanitized_name = sanitizeName(input->debugName());
    // we could get fancier here, but name conflict is extremely unlikely
    while (name_set.count(sanitized_name)) {
      sanitized_name.append("_");
    }
    value_to_name[input] = sanitized_name;
    name_set.insert(sanitized_name);
  }
  input_name_map_ = std::move(value_to_name);
}

template <typename T>
static std::vector<ExprHandle> toExprHandles(const std::vector<T>& sizes) {
  std::vector<ExprHandle> dims;
  dims.reserve(sizes.size());
  for (auto const& size : sizes) {
    dims.emplace_back(size);
  }
  return dims;
}

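// Returns (creating and caching on first use) a kLong variable representing
// the runtime stride of tensor input `tensor_input_index` at dimension
// `stride_index`.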
ExprHandle TensorExprKernel::getStrideArg(
    size_t tensor_input_index,
    size_t stride_index) {
  auto it = strideArgToVar_.find(
      std::pair<size_t, size_t>(tensor_input_index, stride_index));
  if (it == strideArgToVar_.end()) {
    VarHandle var(
        "stride_arg" + std::to_string(tensor_input_index) + "_" +
            std::to_string(stride_index),
        kLong);
    strideArgToVar_[std::pair<size_t, size_t>(
        tensor_input_index, stride_index)] = var;
    return std::move(var);
  }
  return it->second;
}

std::vector<torch::jit::StrideInput>& TensorExprKernel::getSymbolicStrideDesc(
    const torch::jit::Value* value) {
  TORCH_INTERNAL_ASSERT(symbolic_strides_.count(value));
  return symbolic_strides_[value];
}

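// Builds stride expressions for `input`: concrete strides when fully known,
// canonical contiguous / channels-last strides for the corresponding symbolic
// descriptors, and otherwise a per-dimension mix of 1, stride arguments, and
// products of adjacent sizes for S_CONT / S_TRAN_CONT dimensions.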
std::vector<ExprHandle> TensorExprKernel::getInputStrides(
    const torch::jit::Value* input,
    const std::vector<ExprHandle>& inputTensorDims) {
  std::vector<ExprHandle> inputTensorStrides;
  if (input->isCompleteTensor()) {
    auto const strides =
        input->type()->expect<TensorType>()->strides().concrete_sizes();
    std::vector<ExprHandle> inputTensorStrides;
    for (size_t stride : *strides) {
      inputTensorStrides.push_back(LongImm::make(stride));
    }
    return inputTensorStrides;
  }

  size_t rank = inputTensorDims.size();
  std::vector<StrideInput>& stride_input = getSymbolicStrideDesc(input);
  if (stride_input.size() == 1 &&
      (stride_input[0] == StrideInput::TENSOR_CONT_CHANNELS_LAST ||
       stride_input[0] == StrideInput::TENSOR_CONT)) {
    auto strides = stride_input[0] == StrideInput::TENSOR_CONT
        ? make_contiguous_strides(inputTensorDims)
        : make_channels_last_strides(inputTensorDims);
    return fmap(
        strides, [&](ExprPtr stride) { return ExprHandle(std::move(stride)); });
  }

  inputTensorStrides.resize(rank);
  std::vector<bool> stride_set;
  for (size_t i = 0; i < rank; ++i) {
    stride_set.push_back(false);
  }
  // first, generate non-dependent values
  size_t generated_strides = 0;
  for (const auto i : c10::irange(rank)) {
    if (stride_input[i] == torch::jit::StrideInput::S_ONE) {
      inputTensorStrides[i] = LongImm::make(1);
      stride_set[i] = true;
      generated_strides++;
    } else if (stride_input[i] == torch::jit::StrideInput::S_AS_ARG) {
      size_t input_index = input->offset();
      inputTensorStrides[i] = getStrideArg(input_index, i);
      stride_set[i] = true;
      generated_strides++;
    }
  }
  // Contiguous and Transposed Contiguous depend on adjacent values
  while (generated_strides != rank) {
    for (int i = static_cast<int>(rank) - 1; i >= 0; i--) {
      if (stride_input[i] == torch::jit::StrideInput::S_CONT &&
          stride_set[i + 1]) {
        inputTensorStrides[i] =
            inputTensorStrides[i + 1] * inputTensorDims[i + 1];

        stride_set[i] = true;
        generated_strides++;
      }
    }
    for (int i = 0; i < static_cast<int>(rank); i++) {
      if (stride_input[i] == torch::jit::StrideInput::S_TRAN_CONT &&
          stride_set[i - 1]) {
        inputTensorStrides[i] =
            inputTensorStrides[i - 1] * inputTensorDims[i - 1];
        stride_set[i] = true;
        generated_strides++;
      }
    }
  }
  return inputTensorStrides;
}

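// Binds a graph input: a contiguous, non-output tensor is bound directly to an
// input buffer; non-contiguous tensors and tensors that are also outputs are
// read from a flat input buffer via a generated Compute that applies the input
// strides; scalar inputs become kernel arguments.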
Tensor TensorExprKernel::bindInput(const torch::jit::Value* input) {
  auto const& t = input->type();
  auto const& outputs = input->owningGraph()->outputs();
  std::unordered_set<const Value*> outputs_set(outputs.begin(), outputs.end());

  auto is_concrete_cont = [](const torch::jit::Value* input,
                             const MemoryLayoutPolicy& mem_layout_policy) {
    if (input->isCompleteTensor()) {
      auto mem_layout = (mem_layout_policy == MemoryLayoutPolicy::kContiguous)
          ? at::MemoryFormat::Contiguous
          : at::MemoryFormat::ChannelsLast;
      return isContiguous(input, mem_layout);
    } else {
      return false;
    }
  };

  auto is_symbolic_cont = [](std::vector<torch::jit::StrideInput> desc,
                             const MemoryLayoutPolicy& mem_layout_policy) {
    if (desc.size() == 1) {
      auto mem_layout = (mem_layout_policy == MemoryLayoutPolicy::kContiguous)
          ? torch::jit::StrideInput::TENSOR_CONT
          : torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST;
      return desc[0] == mem_layout;
    } else {
      return false;
    }
  };

  Tensor result(nullptr, nullptr);
  switch (t->kind()) {
    case TypeKind::TensorType: {
      auto tt = input->type()->cast<TensorType>();
      bool contiguous_concrete_tensor =
          is_concrete_cont(input, memory_layout_policy_);
      bool contiguous_symbolic_tensor = false;
      if (has_symbolic_shapes_) {
        auto desc = getSymbolicStrideDesc(input);
        contiguous_symbolic_tensor =
            is_symbolic_cont(desc, memory_layout_policy_);
      }

      // Get input size and strides
      auto size_handles = sizesFromSymbolicShape(tt->symbolic_sizes());
      auto inputTensorStrides = getInputStrides(input, size_handles);

      // We don't need to copy the input if:
      //  1) it is not an output AND
      //  2) it is contiguous
      bool contiguous =
          contiguous_concrete_tensor || contiguous_symbolic_tensor;
      if (!outputs_set.count(input) && contiguous) {
        BufHandle inBuffer(
            "t" + input_name_map_[input],
            sizesFromSymbolicShape(tt->symbolic_sizes()),
            inputTensorStrides,
            ToDtype(static_cast<ScalarType>(*tt->scalarType())));
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
            inBuffer.node()->is_contiguous() ||
            inBuffer.node()->is_channels_last_1d_contiguous() ||
            inBuffer.node()->is_contiguous(at::MemoryFormat::ChannelsLast) ||
            inBuffer.node()->is_contiguous(at::MemoryFormat::ChannelsLast3d));
        bufs_.emplace(input, inBuffer.node());
        bufferArgs_.emplace_back(inBuffer);
        break;
      }

      // If the input isn't contiguous or is an output, write the strided input
      // into a contiguous buffer that is then used in all further compute.
      ExprHandle flat_size = 1;
      for (size_t i = 0; i < size_handles.size(); ++i) {
        auto size = size_handles[i];
        if (size.AsNode<LongImm>() && immediateAs<int64_t>(size.node()) == 0) {
          flat_size = 0;
          break;
        }
        flat_size = flat_size + (size - 1) * inputTensorStrides[i];
      }
      flat_size = IRSimplifier::simplify(flat_size);
      BufHandle inBuffer(
          "t" + input_name_map_[input],
          {flat_size},
          ToDtype(static_cast<ScalarType>(*tt->scalarType())));

      result = Compute(
          "input" + std::to_string(bufs_.size() + 1),
          size_handles,
          [&](const std::vector<VarHandle>& axes) {
            ExprHandle idx = 0;
            for (size_t i = 0; i < axes.size(); i++) {
              idx = idx + axes[i] * inputTensorStrides[i];
            }
            return inBuffer.load(idx);
          });
      bufs_.emplace(input, result.buf());
      bufferArgs_.emplace_back(inBuffer);
      break;
    }
    case TypeKind::FloatType: {
      VarHandle v("v" + input_name_map_[input], kDouble);
      bufferArgs_.emplace_back(v);
      scalars_.emplace(input, v);
      break;
    }
    case TypeKind::BoolType: {
      VarHandle v("v" + input_name_map_[input], kBool);
      bufferArgs_.emplace_back(v);
      scalars_.emplace(input, v);
      break;
    }
    case TypeKind::IntType: {
      VarHandle v("v" + input_name_map_[input], kLong);
      bufferArgs_.emplace_back(v);
      scalars_.emplace(input, v);
      break;
    }
    default: {
      throw unsupported_dtype(t->repr_str());
      break;
    }
  }
  return result;
}

NNCLoweringFunction TensorExprKernel::getCustomLoweringFor(
    c10::Symbol op) const {
  if (custom_lowerings_.count(op))
    return custom_lowerings_.at(op);
  return nullptr;
}

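// Returns the indices that would sort `v` in descending order.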
template <typename T>
std::vector<size_t> reverse_sort_indices(const std::vector<T>& v) {
  // initialize original index locations
  std::vector<size_t> idx(v.size());
  iota(idx.begin(), idx.end(), 0);

  std::sort(idx.begin(), idx.end(), [&v](size_t i1, size_t i2) {
    return v[i1] > v[i2];
  });
  return idx;
}

static bool denseAndNonOverlapping(
    at::ArrayRef<int64_t> sizes,
    at::ArrayRef<int64_t> strides) {
  return (strides == at::infer_dense_strides(sizes, strides));
}

Tensor TensorExprKernel::convertSymbolicOutputToCorrectStrides(
    const std::vector<ExprHandle>& sizes,
    const std::vector<size_t>& sorted_stride_indices_descending,
    const std::vector<ExprPtr>& strides,
    BufPtr& buf) {
  // We need to convert the output tensor so that its values are laid out
  // such that, when viewed with the output strides, the values are correct.
  // A contiguous Tensor of size (2, 3) with values 0-5 is laid out as:
  //   [0] [1] [2] [3] [4] [5]
  // The same valued tensor with strides (1, 2) would be laid out like:
  //   [0] [3] [1] [4] [2] [5]
  // When we are doing the re-ordering of values into the output tensor,
  // we are iterating per-element of the input, and we are fixed
  // in indexing in to the output tensor at [i, j] = val.
  // The `val` we want here is equal to the indices for the output
  // tensor that would have given the same position as the output.
  // The position is equal to the sum of stride[i] * index[i],
  // and we can calculate the equivalent indices in the
  // output tensor strides by iteratively computing the index of
  // the biggest stride:
  //   absolute = ...
  //   for stride in strides_from_largest_to_smallest:
  //     cur_idx = absolute // stride
  //     absolute = absolute % stride
  std::vector<ExprPtr> default_strides = make_contiguous_strides(sizes);
  auto zero = LongImm::make(0);
  return Compute(
      "output_1", sizes, [&](const std::vector<VarHandle>& axes_input) {
        std::vector<ExprHandle> axes(axes_input.begin(), axes_input.end());
        auto absolute_position = ExprHandle(immLike(axes[0], 0));
        for (size_t i = 0; i < axes.size(); ++i) {
          ExprHandle stride(default_strides[i]);
          ExprHandle axis = axes[i];
          absolute_position = absolute_position + (stride * axis);
        }
        std::vector<ExprHandle> new_axes(
            sorted_stride_indices_descending.size());
        for (size_t stride_index : sorted_stride_indices_descending) {
          const auto& stride = strides[stride_index];
          auto index = absolute_position / ExprHandle(stride);
          // XXX, in symbolic output ordering, we do not handle the arbitrary
          // ordering of strides as in usual output ordering, just
          // channels last, so even in the presence of size == 1
          // we produce correct output here
          absolute_position = absolute_position % ExprHandle(stride);
          new_axes[stride_index] = index;
        }
        return BufHandle(buf).load(new_axes);
      });
}

Tensor TensorExprKernel::convertSymbolicOutputToCorrectStrides(
    torch::jit::Value* v) {
  const TensorTypePtr& tt = v->type()->expect<TensorType>();
  TORCH_INTERNAL_ASSERT(
      bufs_.count(v),
      buildErrorMessage(
          "Output tensor has no corresponding bufs in the fuser."));
  BufPtr buf = bufs_.at(v);
  TORCH_INTERNAL_ASSERT(buf != nullptr);
  TORCH_INTERNAL_ASSERT(tt != nullptr);
  TORCH_INTERNAL_ASSERT(tt->symbolic_sizes().rank() != std::nullopt);

  auto stride_desc = getSymbolicStrideDesc(v);
  TORCH_INTERNAL_ASSERT(stride_desc.size() == 1);
  auto memory_format = (stride_desc[0] == torch::jit::StrideInput::TENSOR_CONT)
      ? at::MemoryFormat::Contiguous
      : at::MemoryFormat::ChannelsLast;
  // output is contiguous with specified memory format, no work to do
  if (buf->is_contiguous(memory_format)) {
    return Tensor(buf, nullptr);
  }

  TORCH_INTERNAL_ASSERT(
      stride_desc[0] == torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST);
  auto sizes = sizesFromSymbolicShape(tt->symbolic_sizes());
  auto strides = make_channels_last_strides(sizes);
  // For a tensor with dimensions N C H W, the channels-last layout is
  // N H W C, so the order largest to smallest will be N, H, W, C.
  std::vector<size_t> sorted_stride_indices = {0, 2, 3, 1};
  auto zero = LongImm::make(0);
  std::vector<ExprPtr> default_strides = make_contiguous_strides(sizes);
  // See explanation in convertOutputToCorrectStrides
  return convertSymbolicOutputToCorrectStrides(
      sizes, sorted_stride_indices, strides, buf);
}

Tensor TensorExprKernel::convertStaticShapeOutputToCorrectStrides(
    torch::jit::Value* v) {
  const TensorTypePtr& tt = v->type()->expect<TensorType>();
  TORCH_INTERNAL_ASSERT(
      bufs_.count(v),
      buildErrorMessage(
          "Output tensor has no corresponding bufs in the fuser."));
  BufPtr buf = bufs_.at(v);

  // No shape info is present in the graph
  if (!tt->sizes().concrete_sizes()) {
    std::string msg =
        std::string("Shapes for output '%") + v->debugName() + "' are unknown";
    throw malformed_input(msg);
  }

  TORCH_INTERNAL_ASSERT(
      tt->sizes().concrete_sizes(),
      buildErrorMessage("Output shapes are unknown."));
  auto sizes = *tt->sizes().concrete_sizes();
  at::MemoryFormat memory_format =
      (memory_layout_policy_ == MemoryLayoutPolicy::kContiguous)
      ? c10::MemoryFormat::Contiguous
      : c10::MemoryFormat::ChannelsLast;
  std::vector<int64_t> default_strides =
      TensorType::contiguousStridesOf(sizes, memory_format);
  if (!tt->strides().concrete_sizes()) {
    return Tensor(buf, nullptr);
  }
  TORCH_INTERNAL_ASSERT(
      tt->strides().concrete_sizes(),
      buildErrorMessage("Output strides are unknown."));
  const std::vector<int64_t> strides = *tt->strides().concrete_sizes();
  // All Tensors in NNC are laid out in the default, contiguous layout.
  // If the output is also default contiguous we don't need to do anything.
  if (strides == default_strides) {
    return Tensor(buf, nullptr);
  }
  // If the tensor is not dense or overlaps, we have
  // no way of matching the profiled striding
  if (!denseAndNonOverlapping(sizes, strides)) {
    return Tensor(buf, nullptr);
  }

  auto dims = sizesForValue(v);
  auto zero = LongImm::make(0);
  std::vector<size_t> sorted_stride_indices = reverse_sort_indices(strides);

  // TODO: call into `convertOutputToCorrectStrides`. Currently this causes a
  // bug in IRSimplifier to occur. See explanation in
  // `convertOutputToCorrectStrides`
  return Compute(
      "output_1", dims, [&](const std::vector<VarHandle>& axes_input) {
        std::vector<ExprHandle> axes(axes_input.begin(), axes_input.end());
        auto absolute_position = ExprHandle(immLike(axes[0], 0));
        for (size_t i = 0; i < axes.size(); ++i) {
          absolute_position = absolute_position +
              (ExprHandle(immLike(axes[i], default_strides[i])) * axes[i]);
        }

        std::vector<ExprHandle> new_axes(sorted_stride_indices.size());
        for (size_t stride_index : sorted_stride_indices) {
          auto size = sizes[stride_index];
          auto index = zero;
          if (size != 1) {
            auto stride = strides[stride_index];
            index = absolute_position /
                ExprHandle(immLike(absolute_position, stride));
            absolute_position = absolute_position %
                ExprHandle(immLike(absolute_position, stride));
          }
          new_axes[stride_index] = index;
        }
        return BufHandle(buf).load(new_axes);
      });
}

void TensorExprKernel::bindConstant(const torch::jit::Value* v) {
  auto val = toIValue(v).value();
  if (torch::isCustomClass(val)) {
    auto name_hint = "const_" + sanitizeName(v->debugName());
    auto dtype = Dtype(ScalarType::Float);
    std::vector<ExprPtr> dims;
    BufPtr buf = alloc<Buf>(name_hint, dims, dtype);
    auto dataPtr = val.toObjectRef().getSlot(0).toCapsule().get();
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
    constants_.push_back({buf, dataPtr, const_cast<Node*>(v->node())});
    bufs_[v] = buf;
    return;
  }
  if (!v->type()->cast<TensorType>()) {
    // Only Tensor constants need to be bound, scalar constants will be turned
    // into immediates in TE IR
    return;
  }
  auto const_tensor = toIValue(v)->toTensor();
  auto scalar_type = c10::typeMetaToScalarType(const_tensor.options().dtype());
  auto sizes = const_tensor.sizes();
  std::vector<ExprHandle> te_sizes;
  te_sizes.reserve(sizes.size());
  for (auto s : sizes) {
    te_sizes.emplace_back(s);
  }
  BufPtr buf = alloc<Buf>(
      "const_" + sanitizeName(v->debugName()),
      ExprHandleVectorToExprVector(te_sizes),
      ToDtype(scalar_type));

  if (!const_tensor.is_contiguous()) {
    const_tensor = const_tensor.clone().contiguous();
    unpacked_constant_tensors_.push_back(const_tensor);
  }

  constants_.push_back({buf, const_tensor.data_ptr()});
  bufs_[v] = buf;
}

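// Pre-allocates memory for intermediate buffers with statically known sizes
// and returns the buffers that could not be pre-allocated.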
std::vector<BufPtr> TensorExprKernel::preAllocIntermediateBufs(
    const std::vector<BufPtr>& interm_bufs) {
  std::vector<BufPtr> remaining_interm_bufs;
  for (const auto& buf : interm_bufs) {
    // Check if buf shape is static and compute its size if static.
    bool is_static = true;
    size_t size =
        elementSize(buf->dtype().scalar_type()) * buf->dtype().lanes();
    for (auto& d : buf->dims()) {
      if (!d->isConstant()) {
        is_static = false;
        break;
      }
      size = size * (*intValue(d));
    }
    // Only allocate memory for static bufs.
    if (!is_static) {
      remaining_interm_bufs.push_back(buf);
      continue;
    }
    auto bp = (void*)malloc(size);
    if (!bp) {
      remaining_interm_bufs.push_back(buf);
      continue;
    }
    constants_.push_back({buf, bp});
  }
  return remaining_interm_bufs;
}

BlockPtr TensorExprKernel::bindAllInputs() {
  std::vector<CodeGen::BufferArg> symbolic_shape_args;
  std::vector<CodeGen::BufferArg> symbolic_stride_args;

  auto symbolic_shape_inputs_start_pos =
      nInputs_ - symbolic_shape_inputs_.size();
  if (has_symbolic_shapes_) {
    // The graph is supposed to have input params that represent the symbolic
    // dims at the end of the list of inputs. The number of such symbolic input
    // params is defined by the size of the `symbolic_shape_inputs_` vector.
    //
    // TODO: Check if the tensors with symbolic shapes are contiguous.
    TORCH_CHECK(
        nInputs_ > static_cast<int64_t>(symbolic_shape_inputs_.size()),
        "Symbolic dims not provided as inputs to the graph");

    // First, process the symbolic input params and create a new variable for
    // each of them.
    // NOTE: This has to be done before processing the tensor inputs, because
    // their symbolic sizes needs to be associated with these variables we
    // create for the symbolic input params.
    symbolic_shape_args.reserve(symbolic_shape_inputs_.size());

    for (size_t i = symbolic_shape_inputs_start_pos;
         i < static_cast<size_t>(nInputs_);
         ++i) {
      auto input = graph_->inputs()[i];
      if (input->type()->kind() != TypeKind::IntType) {
        throw std::runtime_error(
            "Expected integer type input to graph for symbolic dims.");
      }
      VarHandle v("v" + input_name_map_[input], kLong);
      symbolic_shape_args.emplace_back(v);
      scalars_.emplace(input, v);
      shapeSymbolInputPos_[scalars_[input].node()] = i;
    }
    // For every shape symbol, store a map to the corresponding var.
    for (size_t i = 0; i < symbolic_shape_inputs_.size(); ++i) {
      shapeSymbolToVar_[symbolic_shape_inputs_[i]] =
          scalars_[graph_->inputs()[symbolic_shape_inputs_start_pos + i]];
    }

    // Next, process the tensor inputs and create an argument for each stride
    // that is marked as S_AS_ARG in the symbolic stride descriptors.
    for (size_t i = 0; i < symbolic_shape_inputs_start_pos; ++i) {
      auto input = graph_->inputs()[i];
      auto tt = input->type()->cast<TensorType>();
      if (!tt) {
        continue;
      }
      auto symbolic_stride = getSymbolicStrideDesc(input);
      for (size_t j = 0; j < symbolic_stride.size(); ++j) {
        if (symbolic_stride[j] == torch::jit::StrideInput::S_AS_ARG) {
          VarHandle v("v" + input_name_map_[input], kLong);
          symbolic_stride_args.emplace_back(v);
          strideArgToVar_[{i, j}] = v;
          input_stride_args_.emplace_back(i, j);
        }
      }
    }
  }

  // Block to collect the Stmts corresponding to all tensors.
  auto block = alloc<Block>(std::vector<StmtPtr>({}));

  // Process the inputs before the symbolic input params.
  for (const auto i : c10::irange(symbolic_shape_inputs_start_pos)) {
    auto input = graph_->inputs()[i];
    Tensor t = bindInput(input);
    if (t.stmt()) {
      block->append_stmt(t.stmt());
    }
  }
  // Now, add all the variables corresponding to the symbolic input params.
  bufferArgs_.insert(
      bufferArgs_.end(),
      symbolic_shape_args.begin(),
      symbolic_shape_args.end());

  // Now, add all the variables corresponding to symbolic stride inputs.
  bufferArgs_.insert(
      bufferArgs_.end(),
      symbolic_stride_args.begin(),
      symbolic_stride_args.end());

  return block;
}
1576
deduceMemoryLayoutPolicy()1577 void TensorExprKernel::deduceMemoryLayoutPolicy() {
1578 // If the tensor is channels-last contiguous, the preferred memory layout
1579 // propagation policy is to use channels-last. Otherwise, the preferred policy
1580 // is to use contiguous.
1581 auto _prefer_symbolic_mem =
1582 [](const torch::jit::Value* val,
1583 const std::vector<torch::jit::StrideInput>& stride_desc_vec) {
1584 TORCH_INTERNAL_ASSERT(!stride_desc_vec.empty());
1585 // Has symbolic stride information
1586 auto cur_stride_desc = stride_desc_vec[0];
1587 return (cur_stride_desc ==
1588 torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST)
1589 ? MemoryLayoutPolicy::kChannelsLastNdContiguous
1590 : MemoryLayoutPolicy::kContiguous;
1591 };
1592
1593 auto _prefer_static_mem = [](const torch::jit::Value* val) {
1594 // Complete (static) shape info is present in the graph for this value.
1595 TORCH_INTERNAL_ASSERT(
1596 val->isCompleteTensor(),
1597 buildErrorMessage(val->debugName() + " is not a complete tensor."));
1598 const auto& tt = val->type()->expect<TensorType>();
1599 const auto sizes = *tt->sizes().concrete_sizes();
1600 const auto strides = *tt->strides().concrete_sizes();
1601 return (c10::is_channels_last_strides_2d(sizes, strides))
1602 ? MemoryLayoutPolicy::kChannelsLastNdContiguous
1603 : MemoryLayoutPolicy::kContiguous;
1604 };
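// Illustrative example: an NCHW tensor with sizes [2, 3, 4, 5] and strides
// [60, 1, 15, 3] is channels-last contiguous, so _prefer_static_mem returns
// kChannelsLastNdContiguous; with contiguous strides [60, 20, 5, 1] it
// returns kContiguous.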
1605
1606 // Collect the tensor values from the graph inputs and outputs to
1607 // deduce the memory layout propagation policy.
1608 auto _is_tensor = [](const jit::Value* el) {
1609 return el->type()->kind() == TypeKind::TensorType;
1610 };
1611 std::vector<torch::jit::Value*> graph_io_tensors;
1612 std::copy_if(
1613 graph_->inputs().begin(),
1614 graph_->inputs().end(),
1615 std::back_inserter(graph_io_tensors),
1616 _is_tensor);
1617 std::copy_if(
1618 graph_->outputs().begin(),
1619 graph_->outputs().end(),
1620 std::back_inserter(graph_io_tensors),
1621 _is_tensor);
1622 // An empty set of tensor inputs/outputs would trivially count as "all
1623 // channels-last", but in that case we prefer to keep the default
1624 // (contiguous) propagation policy, so require the set to be non-empty.
1625 auto prefer_channels_last = (!graph_io_tensors.empty());
1626 for (auto el : graph_io_tensors) {
1627 auto is_complete = el->isCompleteTensor();
1628 auto is_symbolic = symbolic_strides_.count(el);
1629
1630 auto preferred_mem_layout = is_complete
1631 ? _prefer_static_mem(el)
1632 : (is_symbolic ? _prefer_symbolic_mem(el, symbolic_strides_[el])
1633 : MemoryLayoutPolicy::kContiguous);
1634 if (preferred_mem_layout != MemoryLayoutPolicy::kChannelsLastNdContiguous) {
1635 prefer_channels_last = false;
1636 break;
1637 }
1638 }
1639
1640 // If all of the graph inputs and outputs are channels-last contiguous,
1641 // the propagated memory layout should be channels-last. Otherwise, the
1642 // propagated memory layout is contiguous, which matches the current
1643 // behavior.
1644 memory_layout_policy_ = prefer_channels_last
1645 ? MemoryLayoutPolicy::kChannelsLastNdContiguous
1646 : MemoryLayoutPolicy::kContiguous;
1647 }
1648
1649 void TensorExprKernel::optimizeOwningGraph() {
1650 GRAPH_DUMP("TensorExprKernel graph (Before graph optimization):", graph_);
1651
1652 // Graph optimization may replace the output values, so store the original
1653 // outputs in order to synchronize their symbolic stride information later.
1654 auto _original_graph_outputs = graph_->outputs().vec();
1655
1656 // Get the graph device information first. The graph optimization
1657 // might be device specific.
1658 device_ = *pickDeviceType(graph_);
1659
1660 // Determine the propagated memory layout
1661 deduceMemoryLayoutPolicy();
1662
1663 // Fuse Conv with Eltwise Op
1664 graph_rewrite_helper::replaceConvolutionWithAtenConv(graph_);
1665 FuseConvWithEltwise(graph_);
1666
1667 // Optimize the concatenation
1668 OptimizeCat(graph_);
1669
1670 // Synchronize the symbolic strides information
1671 auto graph_outputs = graph_->outputs();
1672 TORCH_INTERNAL_ASSERT(graph_outputs.size() == _original_graph_outputs.size());
1673 for (auto i : c10::irange(graph_outputs.size())) {
1674 auto el_orig = _original_graph_outputs.at(i);
1675 auto el_new = graph_outputs.at(i);
1676 if (symbolic_strides_.count(el_orig) && (el_orig != el_new)) {
1677 symbolic_strides_[el_new] = symbolic_strides_[el_orig];
1678 symbolic_strides_.erase(el_orig);
1679 }
1680 }
1681
1682 GRAPH_DUMP("TensorExprKernel graph (After graph optimization):", graph_);
1683 }
1684
1685 void TensorExprKernel::compile() {
1686 GRAPH_DUMP("TensorExprKernel graph:", graph_);
1687
1688 has_symbolic_shapes_ = !symbolic_shape_inputs_.empty();
1689 nInputs_ = graph_->inputs().size();
1690 nOutputs_ = graph_->outputs().size();
1691 genInputDebugNames();
1692
1693 // Bind inputs to buffers.
1694 auto block = bindAllInputs();
1695
1696 // Bind nodes to tensor compute expressions.
1697 for (auto const& n : graph_->nodes()) {
1698 if (n->kind() == prim::ListConstruct) {
1699 continue;
1700 } else if (n->kind() == prim::Constant) {
1701 bindConstant(n->output());
1702 continue;
1703 } else {
1704 for (auto const& output : n->outputs()) {
1705 if (output->hasUses()) {
1706 Tensor t = computeValue(output);
1707
1708 // If there are for-loops before an ExternalCall, as follows,
1709 //   stmt1: for:
1710 //   stmt2:   for:
1711 //   stmt3: ExternalCall
1712 // those for-loops would not be parallelized. So we mark the
1713 // buf args of the ExternalCall as to-be-parallelized to make sure
1714 // the preceding loops can still be parallelized.
1715 if (to<ExternalCall>(t.stmt())) {
1716 auto _external_call = to<ExternalCall>(t.stmt());
1717 for (const auto& _buf : _external_call->buf_args()) {
1718 bufsToBeParallelized_.insert(_buf);
1719 }
1720 }
1721
1722 if (output->type()->cast<TensorType>()) {
1723 // Value is tensor
1724 if (t.buf()) {
1725 bufs_.emplace(output, t.buf());
1726 }
1727 block->append_stmt(t.stmt());
1728 } else {
1729 // Value is scalar
1730 //
1731 // We represent scalar computations in TE with a pair of statements:
1732 // Let val = <compute_expression>
1733 // Store(buf_for_scalar[0], val)
1734 //
1735 // Subsequent computations will use val when they refer to the
1736 // given value, and the buffer will be used if we need to return
1737 // the computed value as an output of the kernel. If this is not an
1738 // output, the store will be removed later by DCE.
1739 //
1740 // NB: NNC's lowering functions return Tensor, which is a pair
1741 // <Buf, Stmt>, but here we also need Var. How can we obtain all of
1742 // Var, Buf, and Stmt?
1743 // We use the following trick: the lowering function creates the
1744 // Let-stmt and a "fake" buffer, whose only purpose is to hold the
1745 // Var. Then outside the lowering function (namely, right here) we
1746 // generate the store and the actual buffer.
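// Illustrative example (hypothetical names): for a scalar output %x the
// lowering yields `Let v_x = <expr>` plus a fake buf holding v_x; the code
// below then emits `scalar_x[] = v_x` into a real 0-dim buffer so the value
// can be returned as a kernel output.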
1747 VarPtr v = t.buf()->base_handle();
1748 scalars_[output] = VarHandle(v);
1749 block->append_stmt(t.stmt());
1750 std::vector<ExprPtr> dims;
1751 BufHandle buf(
1752 "scalar_" + sanitizeName(output->debugName()), {}, v->dtype());
1753 StmtPtr store = Store::make(buf, {}, ExprHandle(v));
1754 block->append_stmt(store);
1755 bufs_.emplace(output, buf.node());
1756 }
1757 }
1758 }
1759 }
1760 if (hasRandom_ && hasBroadcast_) {
1761 throw std::runtime_error(
1762 "Cannot support broadcast and random within one kernel");
1763 }
1764 }
1765
1766 // Move output operands from `bufs_` to `bufOutputs_`
1767 for (auto i : c10::irange(graph_->outputs().size())) {
1768 auto& output = graph_->outputs().at(i);
1769 if (!bufs_.count(output)) {
1770 throw malformed_input("cannot find output Tensor");
1771 }
1772 if (!output->type()->cast<TensorType>()) {
1773 // Scalar outputs are represented as 0-dim buffers.
1774 bufOutputs_.insert(bufs_.at(output));
1775 bufsToBeParallelized_.insert(bufs_.at(output));
1776 bufferArgs_.emplace_back(BufHandle(bufs_.at(output)));
1777 tensorOutputTensorOptions_.emplace_back(
1778 c10::TensorOptions(tensorType(bufs_.at(output))).device(device_));
1779 tensorOutputSizes_.emplace_back();
1780 tensorOutputStrides_.emplace_back();
1781 isOutputScalar_.push_back(true);
1782 bufs_.erase(output);
1783 continue;
1784 }
1785
1786 const auto& tt = output->type()->expect<TensorType>();
1787 if (has_symbolic_shapes_) {
1788 auto sizes = sizesFromSymbolicShape(tt->symbolic_sizes());
1789 tensorOutputSymbolicSizes_.push_back(sizes);
1790 TORCH_INTERNAL_ASSERT(symbolic_strides_.count(output));
1791 auto stride_desc_vec = symbolic_strides_[output];
1792 TORCH_INTERNAL_ASSERT(stride_desc_vec.size() == 1);
1793 auto stride_desc = stride_desc_vec[0];
1794 tensorOutputStrideDesc_.push_back(stride_desc);
1795 Tensor properly_strided_output =
1796 convertSymbolicOutputToCorrectStrides(output);
1797 if (properly_strided_output.stmt()) {
1798 block->append_stmt(properly_strided_output.stmt());
1799 }
1800 bufs_[output] = properly_strided_output.buf();
1801 } else {
1802 // The "strided" tensor will be incorrect if used in NNC,
1803 // since NNC views it as contiguous. Only convert it to the right
1804 // strides at the end of the kernel (if already contiguous it's a no-op)
1805 Tensor properly_strided_output =
1806 convertStaticShapeOutputToCorrectStrides(output);
1807 if (properly_strided_output.stmt()) {
1808 block->append_stmt(properly_strided_output.stmt());
1809 }
1810 bufs_[output] = properly_strided_output.buf();
1811 auto sizes = *tt->sizes().concrete_sizes();
1812 tensorOutputSizes_.push_back(sizes);
1813 auto strides = tt->strides().concrete_sizes();
1814
1815 // If the tensor is not dense or overlaps, we have
1816 // no way of matching the profiled striding
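// For example, sizes [2, 3] with strides [1, 2] (a transposed view) are
// dense and non-overlapping, so the profiled strides are kept, whereas
// strides [4, 1] (a padded row) are not dense, so contiguous strides [3, 1]
// are used instead.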
1817 if (strides && denseAndNonOverlapping(sizes, *strides)) {
1818 tensorOutputStrides_.push_back(*strides);
1819 } else {
1820 tensorOutputStrides_.push_back(TensorType::contiguousStridesOf(sizes));
1821 }
1822 }
1823
1824 bufOutputs_.insert(bufs_.at(output));
1825 bufsToBeParallelized_.insert(bufs_.at(output));
1826 bufferArgs_.emplace_back(BufHandle(bufs_.at(output)));
1827 tensorOutputTensorOptions_.emplace_back(
1828 c10::TensorOptions(tensorType(bufs_.at(output))).device(device_));
1829 isOutputScalar_.push_back(false);
1830 bufs_.erase(output);
1831 }
1832
1833 BackendType backendType = inferBackendTypeFromDevice(device_);
1834 stmt_ = transformLoops(backendType, block);
1835
1836 for (const auto& c : constants_) {
1837 bufferArgs_.emplace_back(BufHandle(c.buf));
1838 }
1839
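// With symbolic shapes, the concrete output sizes and strides are only known
// at run time; the resize below gives getStaticOutputSizesAndStrides(), which
// starts from a copy of these vectors, one slot per output to fill in.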
1840 if (has_symbolic_shapes_) {
1841 tensorOutputSizes_.resize(bufOutputs_.size());
1842 tensorOutputStrides_.resize(bufOutputs_.size());
1843 }
1844
1845 // Generate code.
1846 codegen_ = CreateCodeGen(
1847 getCodeGenName(backendType),
1848 stmt_,
1849 bufferArgs_,
1850 device_,
1851 kernel_func_name_);
1852 }
1853
1854 void TensorExprKernel::recompile() {
1855 codegen_ = CreateCodeGen(
1856 "llvm_codegen", stmt_, bufferArgs_, device_, kernel_func_name_);
1857 }
1858
1859 TensorExprKernel::TensorExprKernel(
1860 const std::shared_ptr<Graph>& subgraph,
1861 std::string kernel_func_name,
1862 std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings,
1863 std::vector<int64_t> symbolic_shape_inputs,
1864 bool pre_alloc /*= false*/,
1865 std::unordered_map<
1866 const torch::jit::Value*,
1867 std::vector<torch::jit::StrideInput>> symbolic_strides)
1868 : graph_(subgraph),
1869 code_(subgraph, ""),
1870 symbolic_shape_inputs_(std::move(symbolic_shape_inputs)),
1871 custom_lowerings_(std::move(custom_lowerings)),
1872 pre_alloc_(pre_alloc),
1873 kernel_func_name_(std::move(kernel_func_name)),
1874 symbolic_strides_(std::move(symbolic_strides)) {
1875 optimizeOwningGraph();
1876 // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
1877 allow_fallback_ = fallbackAllowed();
1878
1879 if (!allow_fallback_) {
1880 compile();
1881 return;
1882 }
1883
1884 use_fallback_ = fallbackEnforced();
1885 if (use_fallback_) {
1886 return;
1887 }
1888
1889 try {
1890 compile();
1891 } catch (...) {
1892 use_fallback_ = true;
1893 }
1894 }
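// Illustrative usage sketch for the constructor above (not part of this
// file; names and inputs are hypothetical):
//
//   std::shared_ptr<Graph> graph = ...;      // a fusible TorchScript subgraph
//   TensorExprKernel kernel(graph, "fused_kernel_0");
//   Stack stack;
//   stack.emplace_back(at::randn({8, 16}));  // push the graph inputs
//   kernel.run(stack);                        // outputs are left on the stack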
1895
1896 void TensorExprKernel::run(Stack& stack) const {
1897 if (!use_fallback_ && !allow_fallback_) {
1898 runKernel(stack);
1899 } else if (!use_fallback_ && allow_fallback_) {
1900 try {
1901 runKernel(stack);
1902 } catch (...) {
1903 fallback(stack);
1904 }
1905 } else {
1906 fallback(stack);
1907 }
1908 }
1909
1910 void TensorExprKernel::getStaticOutputSizesAndStrides(
1911 const at::ArrayRef<IValue>& inputs,
1912 std::vector<std::vector<int64_t>>* sizes,
1913 std::vector<std::vector<int64_t>>* strides) const {
1914 TORCH_INTERNAL_ASSERT(has_symbolic_shapes_);
1915 // If there are symbolic shapes, then the output tensor size wouldn't have
1916 // been computed at compile time. That has to be done here by using the
1917 // symbolic shape input params passed in to this call.
1918 TORCH_INTERNAL_ASSERT(
1919 tensorOutputSymbolicSizes_.size() == bufOutputs_.size());
1920
1921 TORCH_INTERNAL_ASSERT(sizes);
1922 TORCH_INTERNAL_ASSERT(strides);
1923 *sizes = tensorOutputSizes_;
1924 *strides = tensorOutputStrides_;
1925 auto& static_sizes = *sizes;
1926 auto& static_strides = *strides;
1927 for (size_t i = 0, e = bufOutputs_.size(); i < e; ++i) {
1928 static_sizes[i].clear();
1929 for (auto t : tensorOutputSymbolicSizes_[i]) {
1930 if (t.AsNode<LongImm>()) {
1931 static_sizes[i].emplace_back(immediateAs<int64_t>(t.node()));
1932 } else {
1933 auto input_pos = shapeSymbolInputPos_.at(t.node());
1934 TORCH_INTERNAL_ASSERT(input_pos < inputs.size());
1935 TORCH_INTERNAL_ASSERT(inputs[input_pos].isInt());
1936 static_sizes[i].emplace_back(inputs[input_pos].toInt());
1937 }
1938 }
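// e.g. (illustrative): symbolic sizes (SS(-2), 3) where the int input bound
// to SS(-2) is 5 resolve to static_sizes[i] = {5, 3}.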
1939
1940 if (tensorOutputStrideDesc_[i] == torch::jit::StrideInput::TENSOR_CONT) {
1941 static_strides[i] = TensorType::contiguousStridesOf(static_sizes[i]);
1942
1943 } else if (
1944 tensorOutputStrideDesc_[i] ==
1945 torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST) {
1946 static_strides[i] = at::get_channels_last_strides_2d(static_sizes[i]);
1947
1948 } else {
1949 std::string output_desc = toString(tensorOutputStrideDesc_[i]);
1950 TORCH_INTERNAL_ASSERT(
1951 false, "Expected contiguous or channels last, got ", output_desc);
1952 }
1953 }
1954 }
1955
1956 std::vector<CodeGen::CallArg> TensorExprKernel::prepareRunArgs(
1957 const at::ArrayRef<IValue>& inputs,
1958 std::vector<at::Tensor>& outputs) const {
1959 // TODO: preallocate `runArgs` during compilation and fill in values where
1960 // possible (e.g. for constant tensors)
1961 std::vector<CodeGen::CallArg> runArgs;
1962 runArgs.reserve(
1963 inputs.size() + input_stride_args_.size() + bufOutputs_.size());
1964
1965 for (auto& input : inputs) {
1966 if (input.isInt()) {
1967 runArgs.emplace_back(input.toInt());
1968 } else if (input.isBool()) {
1969 runArgs.emplace_back(input.toBool());
1970 } else if (input.isDouble()) {
1971 runArgs.emplace_back(input.toDouble());
1972 } else if (input.isTensor()) {
1973 runArgs.emplace_back(input.toTensor().data_ptr());
1974 }
1975 }
1976
1977 if (has_symbolic_shapes_) {
1978 std::vector<std::vector<int64_t>> static_sizes;
1979 std::vector<std::vector<int64_t>> static_strides;
1980 getStaticOutputSizesAndStrides(inputs, &static_sizes, &static_strides);
1981
1982 // add stride args
1983 for (const auto& input_stride_arg : input_stride_args_) {
1984 runArgs.emplace_back(
1985 inputs[input_stride_arg.first].toTensor().strides().at(
1986 input_stride_arg.second));
1987 }
1988
1989 for (size_t i = 0, e = bufOutputs_.size(); i < e; ++i) {
1990 auto const& opts = tensorOutputTensorOptions_[i];
1991 outputs.emplace_back(codegen_->empty_strided(
1992 static_sizes[i],
1993 static_strides[i],
1994 opts.dtype,
1995 opts.layout,
1996 opts.device,
1997 opts.pinned_memory));
1998 runArgs.emplace_back(outputs.back().data_ptr());
1999 }
2000 } else {
2001 for (size_t i = 0, e = bufOutputs_.size(); i < e; ++i) {
2002 auto const& opts = tensorOutputTensorOptions_[i];
2003 outputs.emplace_back(codegen_->empty_strided(
2004 tensorOutputSizes_[i],
2005 tensorOutputStrides_[i],
2006 opts.dtype,
2007 opts.layout,
2008 opts.device,
2009 opts.pinned_memory));
2010 runArgs.emplace_back(outputs.back().data_ptr());
2011 }
2012 }
2013
2014 for (const auto& c : constants_) {
2015 runArgs.emplace_back(c.ptr);
2016 }
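// The final ordering of runArgs mirrors bufferArgs_ as assembled in
// compile(): graph inputs (including the trailing symbolic shape ints),
// stride args, output buffers, and finally constants.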
2017
2018 return runArgs;
2019 }
2020
2021 StmtPtr TensorExprKernel::getCodeGenStmt() {
2022 return codegen_->stmt();
2023 }
2024
2025 void TensorExprKernel::runKernel(Stack& stack) const {
2026 // Set up arguments (inputs, then outputs) for kernel call.
2027 auto inputs = last(stack, nInputs_);
2028 std::vector<at::Tensor> outputs;
2029
2030 std::vector<CodeGen::CallArg> runArgs = prepareRunArgs(inputs, outputs);
2031
2032 // Call the kernel.
2033 codegen_->call(runArgs);
2034
2035 // Update the stack.
2036 drop(stack, nInputs_);
2037
2038 int64_t idx = 0;
2039 for (auto& o : outputs) {
2040 if (isOutputScalar_[idx++]) {
2041 // Scalar outputs are returned as 0-dim tensors; we need to extract
2042 // the scalar value from them.
2043 push_one(stack, o.item());
2044 } else {
2045 push_one(stack, std::move(o));
2046 }
2047 }
2048 }
2049
2050 void TensorExprKernel::runFast(
2051 const std::vector<void*>& inputs,
2052 const std::vector<void*>& outputs) const {
2053 std::vector<void*> args(inputs);
2054 args.reserve(inputs.size() + outputs.size() + constants_.size());
2055 args.insert(args.end(), outputs.begin(), outputs.end());
2056
2057 // TODO: we can consider preallocating and pre-filling the args vector.
2058 for (const auto& c : constants_) {
2059 args.push_back(c.ptr);
2060 }
2061
2062 // Call the kernel.
2063 codegen_->call_raw(args);
2064 }
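// Illustrative sketch of calling runFast above (hypothetical buffers): for a
// kernel with two tensor inputs and one output whose storage the caller
// already owns:
//
//   std::vector<void*> ins{a.data_ptr(), b.data_ptr()};
//   std::vector<void*> outs{out.data_ptr()};
//   kernel.runFast(ins, outs);
//
// No shape checking or output allocation happens on this path.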
2065
2066 void TensorExprKernel::runWithAllocatedOutputs(Stack& stack) const {
2067 TORCH_INTERNAL_ASSERT(
2068 device_ == at::kCPU,
2069 "Pre-allocated output tensors are supported only on CPUs.");
2070 std::vector<void*> args;
2071 args.reserve(nInputs_ + nOutputs_ + constants_.size());
2072
2073 // stack has inputs on the top and outputs right below them.
2074 auto stack_ivals = last(stack, nOutputs_ + nInputs_);
2075 auto stack_outputs = stack_ivals.slice(0, nOutputs_);
2076 auto stack_inputs = stack_ivals.slice(nOutputs_);
2077
2078 std::vector<int64_t> int_inputs(nInputs_);
2079 for (auto i : c10::irange(nInputs_)) {
2080 auto inp = stack_inputs[i];
2081 if (inp.isInt()) {
2082 int_inputs[i] = inp.toInt();
2083 args.emplace_back(&int_inputs[i]);
2084 } else if (inp.isTensor()) {
2085 args.emplace_back(inp.toTensor().data_ptr());
2086 } else {
2087 TORCH_INTERNAL_ASSERT(
2088 false, "Unhandled input type while calling TensorExprKernel");
2089 }
2090 }
2091
2092 std::vector<int64_t> stride_values(input_stride_args_.size());
2093 if (has_symbolic_shapes_) {
2094 std::vector<std::vector<int64_t>> static_sizes;
2095 std::vector<std::vector<int64_t>> static_strides;
2096 getStaticOutputSizesAndStrides(
2097 stack_inputs, &static_sizes, &static_strides);
2098
2099 // add stride args
2100 for (auto idx : c10::irange(input_stride_args_.size())) {
2101 const auto& input_stride_arg = input_stride_args_[idx];
2102 stride_values[idx] =
2103 stack_inputs[input_stride_arg.first].toTensor().strides().at(
2104 input_stride_arg.second);
2105 args.emplace_back(&stride_values[idx]);
2106 }
2107
2108 TORCH_INTERNAL_ASSERT(
2109 nOutputs_ == static_cast<int64_t>(bufOutputs_.size()));
2110 for (size_t i = 0, e = bufOutputs_.size(); i < e; ++i) {
2111 auto& out = stack_outputs[i].toTensor();
2112 // This has only been tested on CPUs.
2113 // TODO: Test on GPUs.
2114 out.resize_(static_sizes[i]);
2115 args.emplace_back(out.data_ptr());
2116 }
2117 } else {
2118 for (auto i : c10::irange(nOutputs_)) {
2119 args.emplace_back(stack_outputs[i].toTensor().data_ptr());
2120 }
2121 }
2122
2123 for (const auto& c : constants_) {
2124 args.emplace_back(c.ptr);
2125 }
2126
2127 // Call the kernel.
2128 codegen_->call_raw(args);
2129
2130 // Remove the inputs from the stack. The outputs are already below the inputs
2131 // in the stack.
2132 drop(stack, nInputs_);
2133 }
2134