1 #include <torch/csrc/jit/tensorexpr/kernel.h>
2 
3 #include <ATen/ExpandUtils.h>
4 #include <ATen/Parallel.h>
5 #include <ATen/TensorGeometry.h>
6 #include <c10/core/ScalarTypeToTypeMeta.h>
7 #include <c10/util/irange.h>
8 #include <torch/csrc/jit/jit_log.h>
9 #include <torch/csrc/jit/passes/graph_rewrite_helper.h>
10 #include <torch/csrc/jit/passes/mkldnn_rewrite.h>
11 #include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
12 #include <torch/csrc/jit/tensorexpr/analysis.h>
13 #include <torch/csrc/jit/tensorexpr/expr.h>
14 #include <torch/csrc/jit/tensorexpr/graph_opt.h>
15 #include <torch/csrc/jit/tensorexpr/ir_printer.h>
16 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
17 #include <torch/csrc/jit/tensorexpr/loopnest.h>
18 #include <torch/csrc/jit/tensorexpr/loopnest_randomization.h>
19 #include <torch/csrc/jit/tensorexpr/operators/operators.h>
20 
21 #include <utility>
22 
23 using namespace torch::jit;
24 using namespace torch::jit::tensorexpr;
25 
26 namespace torch::jit::tensorexpr {
27 
28 std::string buildErrorMessage(const std::string& s) {
29   static const std::string generic_error_message =
30       "This error occurred in the fuser. You can turn off the fuser with "
31       "torch.jit.enable_fusion(False).";
32   if (s.empty()) {
33     return generic_error_message;
34   }
35   if (s.back() == '.') {
36     return s + " " + generic_error_message;
37   }
38   return s + ". " + generic_error_message;
39 }
40 
41 static int te_cuda_pointwise_loop_levels = -1;
42 static int te_cuda_pointwise_block_count = -1;
43 static int te_cuda_pointwise_block_size = -1;
44 static bool fallback_allowed = false;
45 static bool te_generate_block_code = false;
46 static bool te_must_use_llvm_on_cpu = true;
47 static bool cat_wo_conditionals = true;
48 static bool opt_conditionals = false;
49 
50 bool setFallbackAllowed(bool value) {
51   bool old_value = fallback_allowed;
52   fallback_allowed = value;
53   return old_value;
54 }
55 
56 bool fallbackAllowed() {
57   static const char* enable_c_str = std::getenv("PYTORCH_TENSOREXPR_FALLBACK");
58   if (!enable_c_str) {
59     return fallback_allowed;
60   }
61   if (std::string(enable_c_str) == "0") {
62     return false;
63   }
64   return true;
65 }
66 
67 static bool fallbackEnforced() {
68   static const char* enable_c_str = std::getenv("PYTORCH_TENSOREXPR_FALLBACK");
69   if (tensorexpr::getTEGenerateBlockCode()) {
70     return false;
71   }
72   if (!enable_c_str) {
73     return fallback_allowed;
74   }
75   if (std::string(enable_c_str) == "2") {
76     return true;
77   }
78   return false;
79 }
80 
81 static int64_t randomTransformsRequested() {
82   const char* enable_c_str =
83       std::getenv("PYTORCH_TENSOREXPR_RANDOM_TRANSFORM_SEED");
84   if (!enable_c_str) {
85     return 0;
86   }
87   return std::stoi(std::string(enable_c_str));
88 }
89 
90 #ifdef TORCH_ENABLE_LLVM
91 static bool dontUseLLVMFlag() {
92   static const char* enable_c_str =
93       std::getenv("PYTORCH_TENSOREXPR_DONT_USE_LLVM");
94   if (!enable_c_str) {
95     return false;
96   }
97   return std::string(enable_c_str) == "1";
98 }
99 #endif
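// For illustration, the environment variables consulted above behave roughly
// as follows, given the defaults defined in this file:
//   PYTORCH_TENSOREXPR_FALLBACK=0        -> never fall back to the interpreter
//   PYTORCH_TENSOREXPR_FALLBACK=<other>  -> allow fallback; the value "2"
//                                           additionally makes
//                                           fallbackEnforced() return true
//   PYTORCH_TENSOREXPR_RANDOM_TRANSFORM_SEED=<n> -> apply randomized loop
//       transforms with seed n (n == -1 picks a time-based seed later, in
//       transformLoops)
//   PYTORCH_TENSOREXPR_DONT_USE_LLVM=1   -> prefer the simple IR evaluator
//                                           over the LLVM codegen on CPU
// e.g. running `PYTORCH_TENSOREXPR_FALLBACK=2 python script.py` requests
// enforced fallback (unless Block codegen is enabled).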
100 
101 int& getTECudaPointwiseLoopLevels() {
102   return te_cuda_pointwise_loop_levels;
103 }
104 
105 int& getTECudaPointwiseBlockCount() {
106   return te_cuda_pointwise_block_count;
107 }
108 
109 int& getTECudaPointwiseBlockSize() {
110   return te_cuda_pointwise_block_size;
111 }
112 
113 // TODO: Remove this global var
114 // Ideally, Block codegen should be decided
115 // based on the device type of the tensor.
116 bool& getTEGenerateBlockCode() {
117   return te_generate_block_code;
118 }
119 
120 bool& getTEMustUseLLVMOnCPU() {
121   return te_must_use_llvm_on_cpu;
122 }
123 
124 bool& getCatWoConditionals() {
125   return cat_wo_conditionals;
126 }
127 
128 bool& getOptConditionals() {
129   return opt_conditionals;
130 }
131 
132 std::optional<at::Device> pickDeviceType(
133     const at::ArrayRef<torch::jit::Value*>& inputs) {
134   std::optional<at::Device> device = std::nullopt;
135   for (auto const& input : inputs) {
136     auto tt = input->type()->cast<TensorType>();
137     if (tt && tt->device()) {
138       if (device && *device != *tt->device()) {
139         return std::nullopt;
140       }
141       device = *tt->device();
142     }
143   }
144   return device;
145 }
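// For illustration: if `inputs` mixes a CUDA tensor and a CPU tensor, this
// overload returns std::nullopt; if every tensor input agrees on a device,
// that device is returned; and if no input carries device info, the result
// is also std::nullopt (unlike the graph overload below, no CPU default is
// assumed here).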
146 
147 static std::optional<at::Device> pickDeviceType(
148     const std::shared_ptr<Graph>& graph) {
149   std::optional<at::Device> device = std::nullopt;
150   for (auto const& node : graph->nodes()) {
151     for (auto const& input : node->inputs()) {
152       if (auto tt = input->type()->cast<TensorType>()) {
153         if (auto inputDevice = tt->device()) {
154           TORCH_INTERNAL_ASSERT(
155               !device || *device == *inputDevice,
156               buildErrorMessage(
157                   "Different devices specified for inputs to the fuser."));
158           device = inputDevice;
159         }
160       }
161     }
162   }
163   for (auto const& input : graph->inputs()) {
164     if (auto tt = input->type()->cast<TensorType>()) {
165       if (auto inputDevice = tt->device()) {
166         TORCH_INTERNAL_ASSERT(
167             !device || *device == *inputDevice,
168             buildErrorMessage(
169                 "Different devices specified for inputs to the fuser."));
170         device = inputDevice;
171       }
172     }
173   }
174   if (!device) {
175     // By default assume the device is CPU
176     device = at::kCPU;
177   }
178   return device;
179 }
180 
181 // If v is a Tensor with concretely-known sizes and dtype, return them, else
182 // nullopt.
183 static std::optional<TensorInfo> getTensorInfoJit(torch::jit::Value* v) {
184   auto const& it = v->type()->cast<TensorType>();
185 
186   c10::ScalarType dtype = c10::ScalarType::Float;
187 
188   if (!it) {
189     return std::nullopt;
190   }
191   if (!it->isComplete()) {
192     return std::nullopt;
193   }
194   if (it->scalarType()) {
195     // TODO: ideally we should be strict here and return nullopt if the dtype is
196     // absent in the JIT IR. We're assuming a default Float dtype for now, until
197     // dtype propagation is implemented.
198     dtype = *it->scalarType();
199   }
200   auto concrete_sizes = it->sizes().concrete_sizes();
201   if (!concrete_sizes) {
202     return std::nullopt;
203   }
204   return TensorInfo{*concrete_sizes, dtype};
205 }
206 static std::vector<int64_t> _pair_int(const IValue& v) {
207   if (v.isIntList()) {
208     return v.toIntVector();
209   } else {
210     return {v.toInt(), v.toInt()};
211   }
212 }
213 
214 bool isContiguous(const torch::jit::Value* v, at::MemoryFormat memory_format) {
215   auto const& tt = v->type()->cast<TensorType>();
216   if (!tt) {
217     return false;
218   }
219   if (!tt->isComplete()) {
220     return false;
221   }
222   auto const& sizes = tt->sizes().concrete_sizes();
223   auto const& strides = tt->strides().concrete_sizes();
224   if (!sizes || !strides) {
225     return false;
226   }
227 
228   // Check dimension size first
229   int ndims = (*sizes).size();
230   if ((memory_format == at::MemoryFormat::ChannelsLast && ndims != 4) ||
231       (memory_format == at::MemoryFormat::ChannelsLast3d && ndims != 5)) {
232     return false;
233   }
234 
235   return *strides == TensorType::contiguousStridesOf(*sizes, memory_format);
236 }
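// For illustration (assuming the usual TensorType::contiguousStridesOf
// semantics): sizes {2, 3, 4} with strides {12, 4, 1} are Contiguous, while
// an NCHW tensor of sizes {1, 3, 4, 4} is ChannelsLast-contiguous with
// strides {48, 1, 12, 3} but not with the default strides {48, 16, 4, 1}.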
237 
238 static size_t get_conv_groups_index(const torch::jit::Node* node) {
239   switch (node->kind()) {
240     case aten::conv2d:
241       return 6;
242     case aten::_convolution:
243       return 8;
244     default:
245       TORCH_CHECK(
246           false,
247           "mkldnnPrepackedConvIsSupportedJit expects node kind to be conv2d or _convolution but got ",
248           node->kind());
249   }
250 }
251 
252 // The fuser only supports conv2d with very specific properties:
253 // - Static shapes: 4-d input and filter, 1-d bias.
254 // - Constant strides/padding/dilation/groups
255 // - Equal padding and strides, dilation == 1.
256 // - Depthwise (groups == in_channels == out_channels)
257 // - 3x3 kernel
258 bool conv2dIsSupportedJit(const torch::jit::Node* node) {
259   auto const& input = getTensorInfoJit(node->input(0));
260   auto const& weight = getTensorInfoJit(node->input(1));
261   auto const& bias = getTensorInfoJit(node->input(2));
262   auto const& stride = toIValue(node->input(3));
263   auto const& pad = toIValue(node->input(4));
264   auto const& dilation = toIValue(node->input(5));
265   size_t groups_index = get_conv_groups_index(node);
266   auto const& groups = toIValue(node->input(groups_index));
267 
268   // Everything should be statically known.
269   if (!input || !weight || !bias || !stride || !pad || !dilation || !groups) {
270     GRAPH_DEBUG("some params aren't static");
271     return false;
272   }
273 
274   // All inputs should be contiguous so no transposition is required.
275   if (!isContiguous(node->input(0)) || !isContiguous(node->input(1)) ||
276       !isContiguous(node->input(2))) {
277     GRAPH_DEBUG("conv2dIsSupported: some inputs are not contiguous");
278     return false;
279   }
280 
281   return conv2dIsSupported(
282       *input,
283       *weight,
284       *bias,
285       _pair_int(*stride),
286       _pair_int(*pad),
287       _pair_int(*dilation),
288       groups->toInt());
289 }
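// For illustration, a hypothetical depthwise configuration consistent with
// the constraints listed above: input [1, 72, 56, 56], weight [72, 1, 3, 3],
// bias [72], stride [1, 1], padding [1, 1], dilation [1, 1], groups 72, all
// contiguous and statically shaped. Whether it is actually fused still
// depends on the checks inside conv2dIsSupported.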
290 
291 bool mkldnnPrepackedConvIsSupportedJit(const torch::jit::Node* node) {
292 #if AT_MKLDNN_ENABLED()
293   auto const& input = getTensorInfoJit(node->input(0));
294   auto const& weight = getTensorInfoJit(node->input(1));
295   auto const& stride = toIValue(node->input(3));
296   auto const& pad = toIValue(node->input(4));
297   auto const& dilation = toIValue(node->input(5));
298   size_t groups_index = get_conv_groups_index(node);
299   auto const& groups = toIValue(node->input(groups_index));
300 
301   // Everything should be statically known (bias could be NoneType =
302   // prim::Constant()).
303   if (!input || !weight || !stride || !pad || !dilation || !groups) {
304     GRAPH_DEBUG("some params aren't static");
305     return false;
306   }
307 
308   // Weights and bias should be Constant when using mkldnn backend
309   if (node->input(1)->node()->kind() != prim::Constant ||
310       node->input(2)->node()->kind() != prim::Constant) {
311     GRAPH_DEBUG(
312         "mkldnnPrepackedConvIsSupported: weight or bias is not Constant");
313     return false;
314   }
315 
316   // Input and weight should be NHWC contiguous.
317   if (!(isContiguous(node->input(0), at::MemoryFormat::ChannelsLast) &&
318         isContiguous(node->input(1), at::MemoryFormat::ChannelsLast))) {
319     GRAPH_DEBUG(
320         "mkldnnPrepackedConvIsSupported: input or weight is not ChannelsLast contiguous");
321     return false;
322   }
323 
324   return mkldnnPrepackedConvIsSupported(
325       *input,
326       *weight,
327       _pair_int(*stride),
328       _pair_int(*pad),
329       _pair_int(*dilation),
330       groups->toInt());
331 #endif
332   return false;
333 }
334 
335 bool isConv2d(const Node* node) {
336   if (node->kind() != aten::_convolution) {
337     return false;
338   }
339 
340   auto const& stride = toIValue(node->input(3));
341   auto const& pad = toIValue(node->input(4));
342   auto const& dilation = toIValue(node->input(5));
343   auto const& transposed = toIValue(node->input(6));
344   auto const& output_padding = toIValue(node->input(7));
345 
346   if (!stride || !pad || !dilation || !transposed || !output_padding) {
347     GRAPH_DEBUG("some params aren't static");
348     return false;
349   }
350 
351   if (stride.value().toIntList().size() != 2 ||
352       pad.value().toIntList().size() != 2 ||
353       dilation.value().toIntList().size() != 2 ||
354       output_padding.value().toIntList().size() != 2) {
355     GRAPH_DEBUG("Conv not 2d");
356     return false;
357   }
358 
359   if (transposed.value().toBool()) {
360     GRAPH_DEBUG("transposed Conv");
361     return false;
362   }
363   return true;
364 }
365 
366 // The fuser currently only supports matmul of 2D x 2D matrices
367 bool matmulIsSupported(const torch::jit::Node* node) {
368   auto const& input0 = getTensorInfoJit(node->input(0));
369   auto const& input1 = getTensorInfoJit(node->input(1));
370 
371   // Everything should be statically known.
372   if (!input0 || !input1) {
373     GRAPH_DEBUG("matmulIsSupported: Input shapes aren't static");
374     return false;
375   }
376 
377   // Proper ndim for tensor inputs.
378   if (input0->dims.size() != 2 || input1->dims.size() != 2) {
379     GRAPH_DEBUG("matmulIsSupported: Unsupported input sizes");
380     return false;
381   }
382 
383   // Inputs should be contiguous, or the TE will needlessly transpose them.
384   if (!isContiguous(node->input(0)) || !isContiguous(node->input(1))) {
385     GRAPH_DEBUG("matmulIsSupported: Input shapes are not contiguous");
386     return false;
387   }
388 
389   return true;
390 }
391 
392 } // namespace torch::jit::tensorexpr
393 
394 static at::ScalarType tensorType(const BufPtr& b) {
395   return static_cast<at::ScalarType>(b->dtype().scalar_type());
396 }
397 
398 ExprHandle TensorExprKernel::constant(const torch::jit::Value* v) {
399   if (v->node()->kind() == prim::Constant) {
400     auto val = toIValue(v).value();
401     if (val.isDouble()) {
402       return DoubleImm::make(val.toDouble());
403     } else if (val.isInt()) {
404       return LongImm::make(val.toInt());
405     } else if (val.isBool()) {
406       return BoolImm::make(val.toBool());
407     } else if (val.isNone()) {
408       // This is just a placeholder so we don't throw.  None-handling
409       // is operator-specific and should be handled properly in
410       // the operator-specific lowering code.
411       return IntImm::make(0);
412     } else {
413       throw unsupported_dtype();
414     }
415   }
416 
417   if (!scalars_.count(v)) {
418     throw malformed_input("no scalar in Constant");
419   }
420 
421   return scalars_.at(v);
422 }
423 
424 ArgValue TensorExprKernel::toArg(const torch::jit::Value* v) const {
425   auto vi = scalars_.find(v);
426   if (vi != scalars_.end()) {
427     return VarHandle(vi->second);
428   }
429   auto ti = bufs_.find(v);
430   if (ti != bufs_.end()) {
431     return BufHandle(ti->second);
432   }
433   if (v->node()->kind() == prim::ListConstruct) {
434     std::vector<ArgValue> vec;
435     for (auto el : v->node()->inputs()) {
436       vec.push_back(toArg(el));
437     }
438     if (vec.empty()) {
439       return BufList(); // Return arbitrarily typed vector
440     } else if (std::get_if<BufHandle>(&vec[0])) {
441       return convertVecArgValue<BufHandle>(vec);
442     } else if (std::get_if<int64_t>(&vec[0])) {
443       return convertVecArgValue<int64_t>(vec);
444     }
445     throw unsupported_dtype();
446   }
447   if (v->node()->kind() == prim::Constant) {
448     auto val = toIValue(v).value();
449     if (val.isDouble()) {
450       return val.toDouble();
451     } else if (val.isInt()) {
452       return val.toInt();
453     } else if (val.isBool()) {
454       return val.toBool();
455     } else if (val.isNone()) {
456       // This is just a placeholder so we don't throw.  None-handling
457       // is operator-specific and should be handled properly in
458       // the operator-specific lowering code.
459       return ArgNone();
460     } else if (val.isIntList()) {
461       return val.toIntVector();
462     } else if (val.isDoubleList()) {
463       return val.toDoubleVector();
464     } else if (val.isString()) {
465       return val.toStringRef();
466     } else {
467       throw unsupported_dtype(val.type()->str());
468     }
469   }
470 
471   if (!scalars_.count(v)) {
472     throw malformed_input("no scalar in Constant");
473   }
474   return scalars_.at(v);
475 }
476 
477 ExprHandle TensorExprKernel::getVarForShape(const c10::ShapeSymbol& ss) {
478   if (ss.is_static()) {
479     return LongImm::make(ss.static_size());
480   }
481   auto value = ss.value();
482   auto it = shapeSymbolToVar_.find(value);
483   if (it == shapeSymbolToVar_.end()) {
484     VarHandle var("ss" + std::to_string(-value), kLong);
485     shapeSymbolToVar_.emplace(value, var);
486     return std::move(var);
487   }
488   return it->second;
489 }
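// For illustration: a static ShapeSymbol of size 64 lowers to LongImm(64),
// while a symbolic dim (ShapeSymbols use negative values for symbolic dims)
// with value() == -3 maps to a fresh kLong VarHandle named "ss3", which is
// cached in shapeSymbolToVar_ and reused on subsequent lookups.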
490 
491 std::vector<ExprHandle> TensorExprKernel::sizesFromSymbolicShape(
492     const c10::SymbolicShape& shape) {
493   std::vector<ExprHandle> dims;
494   auto maybe_rank = shape.rank();
495   TORCH_INTERNAL_ASSERT(maybe_rank);
496   auto rank = *maybe_rank;
497   for (const auto i : c10::irange(rank)) {
498     dims.push_back(getVarForShape(shape[i]));
499   }
500   return dims;
501 }
502 
503 std::vector<ExprHandle> TensorExprKernel::sizesForValue(
504     const torch::jit::Value* v) {
505   if (known_sizes_.count(v)) {
506     return known_sizes_.at(v);
507   }
508 
509   // If the shape is present in the type info, just extract it from here. No
510   // need to infer it.
511   if (v->type()->kind() == TypeKind::TensorType) {
512     auto tt = v->type()->cast<TensorType>();
513     return sizesFromSymbolicShape(tt->symbolic_sizes());
514   }
515 
516   if (v->type()->isSubtypeOf(*FloatType::get()) ||
517       v->type()->isSubtypeOf(*BoolType::get()) ||
518       v->type()->isSubtypeOf(*IntType::get())) {
519     return {};
520   }
521   if (v->type()->isSubtypeOf(*NoneType::get())) {
522     return {};
523   }
524   GRAPH_DEBUG("Unknown sizes for the node: ", *v->node());
525   GRAPH_DEBUG("Full fusion group graph:\n", *v->node()->owningGraph());
526   std::string msg = std::string("Unhandled node kind (in sizesForValue): ") +
527       v->node()->kind().toQualString();
528   throw malformed_input(msg);
529 }
530 
531 static std::optional<ScalarType> findDtypeForValue(const torch::jit::Value* v) {
532   if (v->type()->kind() == TypeKind::TensorType) {
533     auto tt = v->type()->cast<TensorType>();
534     if (tt->scalarType()) {
535       return static_cast<ScalarType>(*tt->scalarType());
536     }
537   }
538   return tryScalarTypeFromJitType(*v->type());
539 }
540 
541 static bool constZeroDimTensorAsScalarArg(
542     const Value* v,
543     std::vector<ArgValue>& args) {
544   if (v->node()->kind() != prim::Constant || !v->type()->cast<TensorType>()) {
545     return false;
546   }
547 
548   const auto t = toIValue(v)->toTensor();
549   if (!t.sizes().empty()) {
550     return false;
551   }
552 
553   c10::ScalarType dtype = c10::typeMetaToScalarType(t.dtype());
554   switch (dtype) {
555     case ScalarType::Float:
556       args.emplace_back(t.item().toFloat());
557       return true;
558     case ScalarType::Long:
559       args.emplace_back(t.item().toLong());
560       return true;
561     default:
562       std::stringstream ss;
563       ss << "Unsupported tensor dtype:" << dtype
564          << " for converting constant 0-dim Tensor to scalar" << '\n';
565       throw unsupported_dtype(ss.str());
566   }
567 }
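// For illustration: a prim::Constant holding a 0-dim Float tensor with value
// 2.5 is appended to `args` as the scalar 2.5f rather than being bound as a
// buffer; Long works the same way, and any other dtype throws
// unsupported_dtype.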
568 
569 Tensor TensorExprKernel::computeValue(const torch::jit::Value* v) {
570   auto inputs = v->node()->inputs();
571   auto op = v->node()->kind();
572 
573   if (op == aten::rand_like) {
574     hasRandom_ = true;
575   }
576 
577   auto outputType = findDtypeForValue(v);
578   std::vector<ExprHandle> outputShape = sizesForValue(v);
579   std::vector<ExprHandle> outputStrides = {};
580   if (memory_layout_policy_ == MemoryLayoutPolicy::kChannelsLastNdContiguous) {
581     outputStrides =
582         c10::fmap<ExprHandle>(make_channels_last_strides(outputShape));
583   } else {
584     // Default
585     outputStrides = c10::fmap<ExprHandle>(make_contiguous_strides(outputShape));
586   }
587 
588   std::vector<ArgValue> argInputs;
589   if (op == prim::ConstantChunk) {
590     auto const& n = v->node();
591     argInputs.emplace_back(toArg(inputs[0]));
592     argInputs.emplace_back(static_cast<int64_t>(v->offset()));
593     argInputs.emplace_back(n->i(attr::dim));
594     argInputs.emplace_back(n->i(attr::chunks));
595   } else if (op == aten::to) {
596     argInputs.emplace_back(toArg(inputs[0]));
597   } else if (op == aten::quantize_per_tensor) {
598     argInputs.emplace_back(toArg(inputs[0]));
599     if (!constZeroDimTensorAsScalarArg(inputs[1], argInputs)) {
600       argInputs.emplace_back(toArg(inputs[1]));
601     }
602     if (!constZeroDimTensorAsScalarArg(inputs[2], argInputs)) {
603       argInputs.emplace_back(toArg(inputs[2]));
604     }
605     argInputs.emplace_back(toArg(inputs[3]));
606   } else if (op == aten::conv2d) {
607     for (auto inp : inputs) {
608       argInputs.emplace_back(toArg(inp));
609     }
610     // handle optional bias
611     if (std::get_if<ArgNone>(&argInputs[2])) {
612       Dtype dtype = outputType ? Dtype(*outputType) : kFloat;
613       std::vector<ExprHandle> biasShape;
614       biasShape.push_back(outputShape[1]);
615       auto bias_tensor = at::zeros({outputShape[1].AsNode<LongImm>()->value()});
616       unpacked_constant_tensors_.push_back(bias_tensor);
617       BufPtr buf = alloc<Buf>(
618           "conv2d_bias_opt_" + sanitizeName(v->debugName()),
619           ExprHandleVectorToExprVector(biasShape),
620           dtype);
621       constants_.push_back({buf, bias_tensor.data_ptr()});
622       argInputs[2] = BufHandle(buf);
623     }
624   } else {
625     for (auto inp : inputs) {
626       argInputs.emplace_back(toArg(inp));
627     }
628   }
629 
630   if (NNCLoweringFunction custom_lowering = getCustomLoweringFor(op)) {
631     return custom_lowering(
632         argInputs, outputShape, outputStrides, outputType, device_);
633   }
634   if (v->node()->maybeSchema()) {
635     if (NNCLoweringFunction lowering =
636             getStandardLoweringFor(c10::toString(v->node()->schema()))) {
637       return lowering(
638           argInputs, outputShape, outputStrides, outputType, device_);
639     }
640   }
641   std::string msg = std::string("Unhandled node kind (in computeValue): ") +
642       op.toQualString();
643   if (v->node()->maybeSchema()) {
644     msg += std::string("\nSchema: ") + c10::toString(v->node()->schema());
645   }
646   throw malformed_input(msg);
647 }
648 
649 // True if all the loops in this vector have equal bounds.
650 static bool loopBoundsAllEqual(const std::vector<ForPtr>& loops) {
651   if (loops.size() <= 1) {
652     return true;
653   }
654   const auto& start = loops.front()->start();
655   const auto& stop = loops.front()->stop();
656   for (size_t i = 1; i < loops.size(); ++i) {
657     const auto& curr_start = loops[i]->start();
658     const auto& curr_stop = loops[i]->stop();
659     if (!exprEquals(start, curr_start) || !exprEquals(stop, curr_stop)) {
660       return false;
661     }
662   }
663   return true;
664 }
665 
666 // Recursively fuse all the loops with matching bounds in `st`.  Stops fusing
667 // at any level containing non-loops or non-matching bounds.  The restriction
668 // on matching bounds exists to avoid inserting conditionals on the loop
669 // indices where none would be needed, which would significantly complicate
670 // vectorization.
671 static void fuseAllLoops(const StmtPtr& st) {
672   auto block = to<tensorexpr::Block>(st);
673   if (block == nullptr) {
674     return;
675   }
676 
677   std::vector<std::vector<ForPtr>> all_outer_loops;
678   std::vector<ForPtr> outer_loops;
679   for (const auto& stmt : *block) {
680     auto loop = to<For>(stmt);
681     auto hasReduction = !NodeFinder<ReduceOp>::find(stmt).empty();
682     if (!loop || hasReduction) {
683       all_outer_loops.push_back(outer_loops);
684       outer_loops.clear();
685     } else {
686       outer_loops.push_back(loop);
687     }
688   }
689   all_outer_loops.push_back(outer_loops);
690 
691   for (const auto& outer_loops : all_outer_loops) {
692     if (outer_loops.empty()) {
693       continue;
694     }
695 
696     if (!loopBoundsAllEqual(outer_loops)) {
697       continue;
698     }
699 
700     ForPtr fusedLoop;
701     if (!LoopNest::fuseLoops(outer_loops, &fusedLoop)) {
702       continue;
703     }
704 
705     fuseAllLoops(fusedLoop->body());
706   }
707 }
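// For illustration, given two adjacent produced loops with identical bounds
// that write to different buffers, roughly:
//   for (i = 0; i < 100; i++) { A[i] = ...; }
//   for (j = 0; j < 100; j++) { B[j] = ...; }
// fuseAllLoops rewrites them into a single loop,
//   for (i = 0; i < 100; i++) { A[i] = ...; B[i] = ...; }
// and then recurses into the fused body to fuse the next level down.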
708 
709 // Compute the trip count of a loop if it is a constant.
710 static std::optional<int64_t> tripCount(const ForPtr& loop) {
711   auto tc = IRSimplifier::simplify(
712       cast<int64_t>(ExprHandle(loop->stop()) - ExprHandle(loop->start())));
713   if (auto val = to<LongImm>(tc.node())) {
714     return val->value();
715   }
716   return std::nullopt;
717 }
718 
719 // Prune innermost loops until the iteration count satisfies a minimum grain size.
720 static void pruneByGrainSize(std::vector<ForPtr>& loops) {
721   constexpr int64_t minGrainSize = 32768;
722   int64_t grainSize = 1;
723   for (int64_t i = loops.size(); i > 0; i--) {
724     auto tc = tripCount(loops[i - 1]);
725     if (!tc) {
726       break;
727     }
728     grainSize *= *tc;
729     if (grainSize < minGrainSize) {
730       loops.pop_back();
731     }
732   }
733 }
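// For illustration, with constant trip counts {64, 128, 8} (outer to inner)
// and minGrainSize = 32768: the suffix products 8 and 8 * 128 = 1024 are both
// below the threshold, so the two innermost loops are pruned, while
// 8 * 128 * 64 = 65536 >= 32768 keeps the outermost loop available for
// parallelization.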
734 
735 // Retain enough outermost loops to fill the number of threads.
736 static void pruneByThreadCount(std::vector<ForPtr>& loops) {
737   int64_t trips = 1;
738   auto threads = at::get_num_threads();
739   auto it = loops.begin();
740   for (; it != loops.end(); it++) {
741     if (trips >= threads) {
742       break;
743     }
744     auto tc = tripCount(*it);
745     if (!tc) {
746       break;
747     }
748     trips *= *tc;
749   }
750   loops.erase(it, loops.end());
751 }
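// For illustration, with at::get_num_threads() == 8 and constant trip counts
// {2, 16, 100} (outer to inner): after the first loop trips == 2, after the
// second trips == 32 >= 8, so the third loop is erased and only the two
// outermost loops are kept for flattening and parallelization.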
752 
753 // Flatten and parallelize outer loops, subject to a minimum number of elements
754 // in the inner loop, and a maximum level of thread-level parallelism in the
755 // outer loops.
756 template <typename Bufs>
757 static void parallelizeOuterLoops(LoopNest& l, Bufs&& bufs) {
758   for (auto const& buf : bufs) {
759     auto loops = l.getLoopStmtsFor(buf);
760     pruneByGrainSize(loops);
761     pruneByThreadCount(loops);
762 
763     // There are no loops to parallelize; give up.
764     if (loops.size() == 0) {
765       continue;
766     }
767     // The loop nest contains a reduction; give up.
768     auto reductions = NodeFinder<ReduceOp>::find(loops[0]);
769     if (reductions.size() > 0) {
770       continue;
771     }
772     // The loop nest has loop carried dependences; give up.
773     if (LoopNest::hasLoopCarriedDependence(loops[0])) {
774       continue;
775     }
776     // Try to flatten the outer loops and parallelize them if successful.
777     ForPtr flattened = nullptr;
778     if (loops.size() == 1) {
779       flattened = loops[0];
780     } else {
781       LoopNest::flatten(loops, &flattened);
782     }
783     if (flattened) {
784       flattened->set_parallel();
785     }
786   }
787 }
788 
789 StmtPtr TensorExprKernel::transformLoops(BackendType backendType, StmtPtr st) {
790   torch::jit::tensorexpr::LoopNest l(std::move(st), bufOutputs_);
791   LoopNest::sanitizeNames(l.root_stmt());
792   GRAPH_DEBUG("Original Stmt:\n", std::to_string(l.root_stmt()), "\n");
793   int64_t random_tr_seed = randomTransformsRequested();
794   if (random_tr_seed) {
795     if (random_tr_seed == -1)
796       random_tr_seed = std::time(nullptr);
797     loopnestRandomization(random_tr_seed, l);
798     GRAPH_DEBUG(
799         "After random transform:\n", std::to_string(l.root_stmt()), "\n");
800   }
801 
802   bool hasReduction = !NodeFinder<ReduceOp>::find(l.root_stmt()).empty();
803 
804   // For Block codegen we create a map of tensor dims before
805   // inlining. Like GPU codegen we need to inline, but the order
806   // in which this analysis is run matters.
807   auto block_analysis = std::make_unique<CreateBufferMap>();
808   if (backendType == kBlockCodeGen) {
809     // Run Block analysis to get multi dim buffer info
810     auto root_stmt = l.root_stmt();
811     root_stmt->accept(block_analysis.get());
812   }
813   l.simplify();
814   GRAPH_DEBUG("after simplify", *l.root_stmt());
815 
816   // Inlining output & intermediate buffers can duplicate computation.
817   // Duplicating work can slow down the program if it's not ameliorated in some
818   // way, but we've empirically found that:
819   // - On CPU, LLVM's CSE does a good job as long as you horizontally fuse
820   //   output loops.
821   // - On GPU, there's enough compute to hide the extra work, and inlining
822   //   avoids synchronizing between kernels.
823   l.inlineIntermediateBufs(/*allow_duplicated_work=*/true);
824   GRAPH_DEBUG("after inline", *l.root_stmt());
825 
826   // Optimizing conditionals needs to be performed after inlining because
827   // inlining wouldn't work once the loops are split. Also, it has to be
828   // performed before loop fusion because loop fusion introduces cases where
829   // multiple conditionals are in the same loop and this optimization does not
830   // handle such cases yet.
831   if (getOptConditionals()) {
832     l.optimizeConditionals();
833     GRAPH_DEBUG("after optimizing conditionals: ", *l.root_stmt());
834   }
835 
836   // Fuse loops "horizontally".  This pass allows us to combine loops that
837   // write to different output buffers, as long as they have the same bounds.
838   if (backendType == kLLVMCodeGen) {
839     fuseAllLoops(l.root_stmt());
840     GRAPH_DEBUG("after fuse", *l.root_stmt());
841     parallelizeOuterLoops(l, bufsToBeParallelized_);
842     GRAPH_DEBUG("after parallelize", *l.root_stmt());
843   }
844 
845   if (backendType == kCudaCodeGen) {
846     for (const auto& buf : bufOutputs_) {
847       std::vector<ForPtr> loops = l.getLoopStmtsFor(buf);
848       if (loops.empty()) {
849         // This happens when Buf is 0-dim
850         continue;
851       }
852       ForPtr flattened = nullptr;
853       LoopNest::flatten(loops, &flattened);
854       assert(flattened);
855 
856       int loopLevels = getTECudaPointwiseLoopLevels();
857       const int kDefaultLoopLevels = 2;
858       loopLevels = (loopLevels > 0) ? loopLevels : kDefaultLoopLevels;
859       int blockCount = getTECudaPointwiseBlockCount();
860       int blockSize = getTECudaPointwiseBlockSize();
861 
862       if (loopLevels == 2) {
863         ForPtr inner;
864         const int kDefaultBlockSize = 512;
865         if (blockSize < 0) {
866           blockSize = kDefaultBlockSize;
867         }
868         LoopNest::splitWithMask(flattened, blockSize, &inner);
869         flattened->set_gpu_block_index(0);
870         inner->set_gpu_thread_index(0);
871       } else if (loopLevels == 3) {
872         ForPtr inner;
873         ForPtr inner1;
874         // TODO: change the number of microprocessors
875         const int kDefaultBlockCount = 1280;
876         const int kDefaultBlockSize = 256;
877         blockCount = (blockCount > 0) ? blockCount : kDefaultBlockCount;
878         blockSize = (blockSize > 0) ? blockSize : kDefaultBlockSize;
879         LoopNest::splitWithMask(flattened, blockCount * blockSize, &inner);
880         LoopNest::splitWithMask(inner, blockSize, &inner1);
881         inner->set_gpu_block_index(0);
882         inner1->set_gpu_thread_index(0);
883       } else {
884         throw std::runtime_error(
885             "Invalid loop-level: " + std::to_string(loopLevels));
886       }
887     }
888   }
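  // For illustration of the split above with loopLevels == 2 (the default):
  // a flattened pointwise loop over N elements is split by blockSize (512
  // unless overridden), the outer loop is bound to the CUDA block index and
  // the inner loop to the thread index, i.e. roughly
  //   for (block = 0; block < ceil(N / 512); block++)   // blockIdx.x
  //     for (thread = 0; thread < 512; thread++)        // threadIdx.x
  //       ...  // guarded by the mask splitWithMask inserts when needed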
889 
890   if (backendType == kBlockCodeGen) {
891     for (const auto& buf : bufOutputs_) {
892       const int default_fp16_blocksize = 16;
893       const int default_uint8_blocksize = 32;
894       int blockSize = default_fp16_blocksize;
895       // We only handle looplevels == 2 for now
896       if (buf->dtype().scalar_type() == ScalarType::Byte) {
897         blockSize = default_uint8_blocksize;
898       }
899       std::vector<ForPtr> loops = l.getLoopStmtsFor(buf);
900       TORCH_INTERNAL_ASSERT(
901           !loops.empty(),
902           buildErrorMessage(
903               "No loops found for the buffer " + buf->name_hint() +
904               " in the fuser."));
905       ForPtr flattened = nullptr;
906       LoopNest::flatten(loops, &flattened);
907       assert(flattened);
908 
909       ForPtr inner = nullptr;
910       LoopNest::splitWithMask(flattened, blockSize, &inner);
911       flattened->set_gpu_block_index(0);
912       inner->set_gpu_thread_index(0);
913       flattened->set_buffer_map(block_analysis->getBufferMap());
914     }
915   }
916 
917   if (pre_alloc_) {
918     auto interm_bufs = l.getIntermediateBufs();
919     preAllocIntermediateBufs(interm_bufs);
920   }
921 
922   l.prepareForCodegen();
923 
924   GRAPH_DEBUG("after prepareForCodegen", *l.root_stmt());
925   l.simplify();
926   GRAPH_DEBUG("after simplification", *l.root_stmt());
927 
928   if (backendType == kLLVMCodeGen && !hasReduction) {
929     l.vectorizeInnerLoops();
930     GRAPH_DEBUG("after vectorization", *l.root_stmt());
931   }
932 
933   StmtPtr stmt = l.root_stmt();
934   // Arithmetic Simplification.
935   stmt = IRSimplifier::simplify(stmt);
936   GRAPH_DEBUG("Final Stmt:\n", std::to_string(stmt), "\n");
937   return stmt;
938 }
939 
940 std::string TensorExprKernel::getCodeGenName(BackendType backendType) {
941   switch (backendType) {
942     case kCudaCodeGen:
943       return "cuda_codegen";
944     case kLLVMCodeGen:
945       return "llvm_codegen";
946     case kSimpleIREval:
947       return "simple_ir_eval";
948     case kBlockCodeGen:
949       return "block_codegen";
950     default:
951       throw std::runtime_error(
952           "invalid backend type: " +
953           std::to_string(static_cast<int>(backendType)));
954   }
955 }
956 
957 template <typename T>
958 static bool isValidPrimProperty(const std::optional<T>& a, T b) {
959   return !a.has_value() || *a == b;
960 }
961 
962 TensorExprKernel::BackendType TensorExprKernel::inferBackendTypeFromDevice(
963     at::Device device) {
964   BackendType backendType = BackendType::kUninitialized;
965   if (device.type() == at::kCUDA) {
966     backendType = kCudaCodeGen;
967   } else if (device.type() == at::kCPU && getTEGenerateBlockCode()) {
968     backendType = kBlockCodeGen;
969   } else if (device.type() == at::kCPU) {
970 #ifdef TORCH_ENABLE_LLVM
971     backendType = dontUseLLVMFlag() ? kSimpleIREval : kLLVMCodeGen;
972 #else
973     backendType = kSimpleIREval;
974 #endif
975     if (getTEMustUseLLVMOnCPU() && backendType == kSimpleIREval) {
976       throw std::runtime_error("LLVM Backend not found");
977     }
978   } else {
979     throw std::runtime_error("Invalid device type");
980   }
981   return backendType;
982 }
983 
984 // We use the debug names when printing CUDA code; they need to be stripped
985 // of characters that can't be used in a variable identifier.
986 void TensorExprKernel::genInputDebugNames() {
987   std::unordered_map<std::string, const torch::jit::Value*> name_to_value;
988   std::unordered_set<std::string> name_set;
989   std::unordered_map<const torch::jit::Value*, std::string> value_to_name;
990   for (const torch::jit::Value* input : graph_->inputs()) {
991     std::string sanitized_name = sanitizeName(input->debugName());
992     // we could get fancier here, but name conflict is extremely unlikely
993     while (name_set.count(sanitized_name)) {
994       sanitized_name.append("_");
995     }
996     value_to_name[input] = sanitized_name;
997     name_set.insert(sanitized_name);
998   }
999   input_name_map_ = std::move(value_to_name);
1000 }
1001 
1002 template <typename T>
1003 static std::vector<ExprHandle> toExprHandles(const std::vector<T>& sizes) {
1004   std::vector<ExprHandle> dims;
1005   dims.reserve(sizes.size());
1006   for (auto const& size : sizes) {
1007     dims.emplace_back(size);
1008   }
1009   return dims;
1010 }
1011 
1012 ExprHandle TensorExprKernel::getStrideArg(
1013     size_t tensor_input_index,
1014     size_t stride_index) {
1015   auto it = strideArgToVar_.find(
1016       std::pair<size_t, size_t>(tensor_input_index, stride_index));
1017   if (it == strideArgToVar_.end()) {
1018     VarHandle var(
1019         "stride_arg" + std::to_string(tensor_input_index) + "_" +
1020             std::to_string(stride_index),
1021         kLong);
1022     strideArgToVar_[std::pair<size_t, size_t>(
1023         tensor_input_index, stride_index)] = var;
1024     return std::move(var);
1025   }
1026   return it->second;
1027 }
1028 
1029 std::vector<torch::jit::StrideInput>& TensorExprKernel::getSymbolicStrideDesc(
1030     const torch::jit::Value* value) {
1031   TORCH_INTERNAL_ASSERT(symbolic_strides_.count(value));
1032   return symbolic_strides_[value];
1033 }
1034 
1035 std::vector<ExprHandle> TensorExprKernel::getInputStrides(
1036     const torch::jit::Value* input,
1037     const std::vector<ExprHandle>& inputTensorDims) {
1038   std::vector<ExprHandle> inputTensorStrides;
1039   if (input->isCompleteTensor()) {
1040     auto const strides =
1041         input->type()->expect<TensorType>()->strides().concrete_sizes();
1042     std::vector<ExprHandle> inputTensorStrides;
1043     for (size_t stride : *strides) {
1044       inputTensorStrides.push_back(LongImm::make(stride));
1045     }
1046     return inputTensorStrides;
1047   }
1048 
1049   size_t rank = inputTensorDims.size();
1050   std::vector<StrideInput>& stride_input = getSymbolicStrideDesc(input);
1051   if (stride_input.size() == 1 &&
1052       (stride_input[0] == StrideInput::TENSOR_CONT_CHANNELS_LAST ||
1053        stride_input[0] == StrideInput::TENSOR_CONT)) {
1054     auto strides = stride_input[0] == StrideInput::TENSOR_CONT
1055         ? make_contiguous_strides(inputTensorDims)
1056         : make_channels_last_strides(inputTensorDims);
1057     return fmap(
1058         strides, [&](ExprPtr stride) { return ExprHandle(std::move(stride)); });
1059   }
1060 
1061   inputTensorStrides.resize(rank);
1062   std::vector<bool> stride_set;
1063   for (size_t i = 0; i < rank; ++i) {
1064     stride_set.push_back(false);
1065   }
1066   // first, generate non-dependent values
1067   size_t generated_strides = 0;
1068   for (const auto i : c10::irange(rank)) {
1069     if (stride_input[i] == torch::jit::StrideInput::S_ONE) {
1070       inputTensorStrides[i] = LongImm::make(1);
1071       stride_set[i] = true;
1072       generated_strides++;
1073     } else if (stride_input[i] == torch::jit::StrideInput::S_AS_ARG) {
1074       size_t input_index = input->offset();
1075       inputTensorStrides[i] = getStrideArg(input_index, i);
1076       stride_set[i] = true;
1077       generated_strides++;
1078     }
1079   }
1080   // Contiguous and Transposed Contiguous depend on adjacent values
1081   while (generated_strides != rank) {
1082     for (int i = static_cast<int>(rank) - 1; i >= 0; i--) {
1083       if (stride_input[i] == torch::jit::StrideInput::S_CONT &&
1084           stride_set[i + 1]) {
1085         inputTensorStrides[i] =
1086             inputTensorStrides[i + 1] * inputTensorDims[i + 1];
1087 
1088         stride_set[i] = true;
1089         generated_strides++;
1090       }
1091     }
1092     for (int i = 0; i < static_cast<int>(rank); i++) {
1093       if (stride_input[i] == torch::jit::StrideInput::S_TRAN_CONT &&
1094           stride_set[i - 1]) {
1095         inputTensorStrides[i] =
1096             inputTensorStrides[i - 1] * inputTensorDims[i - 1];
1097         stride_set[i] = true;
1098         generated_strides++;
1099       }
1100     }
1101   }
1102   return inputTensorStrides;
1103 }
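// For illustration, for a rank-3 input whose symbolic stride descriptors are
// [S_AS_ARG, S_CONT, S_ONE]: the first pass sets stride[2] = 1 and pulls
// stride[0] in as a kernel argument via getStrideArg; the second pass then
// fills stride[1] = stride[2] * dim[2] from its already-set neighbor.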
1104 
1105 Tensor TensorExprKernel::bindInput(const torch::jit::Value* input) {
1106   auto const& t = input->type();
1107   auto const& outputs = input->owningGraph()->outputs();
1108   std::unordered_set<const Value*> outputs_set(outputs.begin(), outputs.end());
1109 
1110   auto is_concrete_cont = [](const torch::jit::Value* input,
1111                              const MemoryLayoutPolicy& mem_layout_policy) {
1112     if (input->isCompleteTensor()) {
1113       auto mem_layout = (mem_layout_policy == MemoryLayoutPolicy::kContiguous)
1114           ? at::MemoryFormat::Contiguous
1115           : at::MemoryFormat::ChannelsLast;
1116       return isContiguous(input, mem_layout);
1117     } else {
1118       return false;
1119     }
1120   };
1121 
1122   auto is_symbolic_cont = [](std::vector<torch::jit::StrideInput> desc,
1123                              const MemoryLayoutPolicy& mem_layout_policy) {
1124     if (desc.size() == 1) {
1125       auto mem_layout = (mem_layout_policy == MemoryLayoutPolicy::kContiguous)
1126           ? torch::jit::StrideInput::TENSOR_CONT
1127           : torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST;
1128       return desc[0] == mem_layout;
1129     } else {
1130       return false;
1131     }
1132   };
1133 
1134   Tensor result(nullptr, nullptr);
1135   switch (t->kind()) {
1136     case TypeKind::TensorType: {
1137       auto tt = input->type()->cast<TensorType>();
1138       bool contiguous_concrete_tensor =
1139           is_concrete_cont(input, memory_layout_policy_);
1140       bool contiguous_symbolic_tensor = false;
1141       if (has_symbolic_shapes_) {
1142         auto desc = getSymbolicStrideDesc(input);
1143         contiguous_symbolic_tensor =
1144             is_symbolic_cont(desc, memory_layout_policy_);
1145       }
1146 
1147       // Get input size and strides
1148       auto size_handles = sizesFromSymbolicShape(tt->symbolic_sizes());
1149       auto inputTensorStrides = getInputStrides(input, size_handles);
1150 
1151       // We don't need to copy the input if:
1152       //  1) it is not an output AND
1153       //  2) it is contiguous
1154       bool contiguous =
1155           contiguous_concrete_tensor || contiguous_symbolic_tensor;
1156       if (!outputs_set.count(input) && contiguous) {
1157         BufHandle inBuffer(
1158             "t" + input_name_map_[input],
1159             sizesFromSymbolicShape(tt->symbolic_sizes()),
1160             inputTensorStrides,
1161             ToDtype(static_cast<ScalarType>(*tt->scalarType())));
1162         TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
1163             inBuffer.node()->is_contiguous() ||
1164             inBuffer.node()->is_channels_last_1d_contiguous() ||
1165             inBuffer.node()->is_contiguous(at::MemoryFormat::ChannelsLast) ||
1166             inBuffer.node()->is_contiguous(at::MemoryFormat::ChannelsLast3d));
1167         bufs_.emplace(input, inBuffer.node());
1168         bufferArgs_.emplace_back(inBuffer);
1169         break;
1170       }
1171 
1172       // If the input isn't contiguous or is an output,
1173       // write the strided input into a contiguous buffer that is
1174       // then used in all further compute.
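      // For illustration, the flattened allocation below is sized as
      // 1 + sum_i (size_i - 1) * stride_i; e.g. sizes {2, 3} with strides
      // {1, 2} give flat_size = 1 + (2 - 1) * 1 + (3 - 1) * 2 = 6 elements
      // (and any zero-sized dimension collapses flat_size to 0).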
1175       ExprHandle flat_size = 1;
1176       for (size_t i = 0; i < size_handles.size(); ++i) {
1177         auto size = size_handles[i];
1178         if (size.AsNode<LongImm>() && immediateAs<int64_t>(size.node()) == 0) {
1179           flat_size = 0;
1180           break;
1181         }
1182         flat_size = flat_size + (size - 1) * inputTensorStrides[i];
1183       }
1184       flat_size = IRSimplifier::simplify(flat_size);
1185       BufHandle inBuffer(
1186           "t" + input_name_map_[input],
1187           {flat_size},
1188           ToDtype(static_cast<ScalarType>(*tt->scalarType())));
1189 
1190       result = Compute(
1191           "input" + std::to_string(bufs_.size() + 1),
1192           size_handles,
1193           [&](const std::vector<VarHandle>& axes) {
1194             ExprHandle idx = 0;
1195             for (size_t i = 0; i < axes.size(); i++) {
1196               idx = idx + axes[i] * inputTensorStrides[i];
1197             }
1198             return inBuffer.load(idx);
1199           });
1200       bufs_.emplace(input, result.buf());
1201       bufferArgs_.emplace_back(inBuffer);
1202       break;
1203     }
1204     case TypeKind::FloatType: {
1205       VarHandle v("v" + input_name_map_[input], kDouble);
1206       bufferArgs_.emplace_back(v);
1207       scalars_.emplace(input, v);
1208       break;
1209     }
1210     case TypeKind::BoolType: {
1211       VarHandle v("v" + input_name_map_[input], kBool);
1212       bufferArgs_.emplace_back(v);
1213       scalars_.emplace(input, v);
1214       break;
1215     }
1216     case TypeKind::IntType: {
1217       VarHandle v("v" + input_name_map_[input], kLong);
1218       bufferArgs_.emplace_back(v);
1219       scalars_.emplace(input, v);
1220       break;
1221     }
1222     default: {
1223       throw unsupported_dtype(t->repr_str());
1224       break;
1225     }
1226   }
1227   return result;
1228 }
1229 
1230 NNCLoweringFunction TensorExprKernel::getCustomLoweringFor(
1231     c10::Symbol op) const {
1232   if (custom_lowerings_.count(op))
1233     return custom_lowerings_.at(op);
1234   return nullptr;
1235 }
1236 
1237 template <typename T>
1238 std::vector<size_t> reverse_sort_indices(const std::vector<T>& v) {
1239   // initialize original index locations
1240   std::vector<size_t> idx(v.size());
1241   iota(idx.begin(), idx.end(), 0);
1242 
1243   std::sort(idx.begin(), idx.end(), [&v](size_t i1, size_t i2) {
1244     return v[i1] > v[i2];
1245   });
1246   return idx;
1247 }
1248 
1249 static bool denseAndNonOverlapping(
1250     at::ArrayRef<int64_t> sizes,
1251     at::ArrayRef<int64_t> strides) {
1252   return (strides == at::infer_dense_strides(sizes, strides));
1253 }
1254 
1255 Tensor TensorExprKernel::convertSymbolicOutputToCorrectStrides(
1256     const std::vector<ExprHandle>& sizes,
1257     const std::vector<size_t>& sorted_stride_indices_descending,
1258     const std::vector<ExprPtr>& strides,
1259     BufPtr& buf) {
1260   // We need to convert the output tensor so that its values are laid
1261   // out such that, when viewed with the output strides, the values are
1262   // correct. A contiguous Tensor of size (2, 3) with values 0-5 is laid
1263   // out as:
1264   // [0] [1] [2] [3] [4] [5]
1265   // The same-valued tensor with strides (1, 2) would be laid out as:
1266   // [0] [3] [1] [4] [2] [5]
1267   // When we re-order values into the output tensor, we iterate
1268   // per-element of the input, and we are fixed to indexing into the
1269   // output tensor at [i, j] = val. The `val` we want is the one whose
1270   // indices under the output strides give the same linear position.
1271   // The position is equal to the sum of stride[i] * index[i],
1272   // and we can calculate the equivalent indices under the output
1273   // tensor strides by iteratively computing the index of
1274   // the biggest stride:
1275   // absolute = ...
1276   // for stride in strides_from_largest_to_smallest:
1277   //     cur_idx = absolute // stride
1278   //     absolute = absolute % stride
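  // For illustration, with sizes (2, 3) the default contiguous strides are
  // (3, 1), so output position (i, j) has absolute = 3 * i + j; the loop
  // below then recovers the index along each output dimension by visiting
  // the output strides from largest to smallest, taking
  // cur_idx = absolute / stride and absolute = absolute % stride.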
1279   std::vector<ExprPtr> default_strides = make_contiguous_strides(sizes);
1280   auto zero = LongImm::make(0);
1281   return Compute(
1282       "output_1", sizes, [&](const std::vector<VarHandle>& axes_input) {
1283         std::vector<ExprHandle> axes(axes_input.begin(), axes_input.end());
1284         auto absolute_position = ExprHandle(immLike(axes[0], 0));
1285         for (size_t i = 0; i < axes.size(); ++i) {
1286           ExprHandle stride(default_strides[i]);
1287           ExprHandle axis = axes[i];
1288           absolute_position = absolute_position + (stride * axis);
1289         }
1290         std::vector<ExprHandle> new_axes(
1291             sorted_stride_indices_descending.size());
1292         for (size_t stride_index : sorted_stride_indices_descending) {
1293           const auto& stride = strides[stride_index];
1294           auto index = absolute_position / ExprHandle(stride);
1295           // XXX, in symbolic output ordering, we do not the arbitrary
1296           // ordering of strides as in usual output ordering, just
1297           // channels last, so even in the presence of size == 1
1298           // we produce correct output here
1299           absolute_position = absolute_position % ExprHandle(stride);
1300           new_axes[stride_index] = index;
1301         }
1302         return BufHandle(buf).load(new_axes);
1303       });
1304 }
1305 
1306 Tensor TensorExprKernel::convertSymbolicOutputToCorrectStrides(
1307     torch::jit::Value* v) {
1308   const TensorTypePtr& tt = v->type()->expect<TensorType>();
1309   TORCH_INTERNAL_ASSERT(
1310       bufs_.count(v),
1311       buildErrorMessage(
1312           "Output tensor has no corresponding bufs in the fuser."));
1313   BufPtr buf = bufs_.at(v);
1314   TORCH_INTERNAL_ASSERT(buf != nullptr);
1315   TORCH_INTERNAL_ASSERT(tt != nullptr);
1316   TORCH_INTERNAL_ASSERT(tt->symbolic_sizes().rank() != std::nullopt);
1317 
1318   auto stride_desc = getSymbolicStrideDesc(v);
1319   TORCH_INTERNAL_ASSERT(stride_desc.size() == 1);
1320   auto memory_format = (stride_desc[0] == torch::jit::StrideInput::TENSOR_CONT)
1321       ? at::MemoryFormat::Contiguous
1322       : at::MemoryFormat::ChannelsLast;
1323   // output is contiguous with specified memory format, no work to do
1324   if (buf->is_contiguous(memory_format)) {
1325     return Tensor(buf, nullptr);
1326   }
1327 
1328   TORCH_INTERNAL_ASSERT(
1329       stride_desc[0] == torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST);
1330   auto sizes = sizesFromSymbolicShape(tt->symbolic_sizes());
1331   auto strides = make_channels_last_strides(sizes);
1332   // For a tensor with dimensions N C H W, the channels-last
1333   // format is N H W C,
1334   // so the order from largest to smallest stride will be N, H, W, C.
1335   std::vector<size_t> sorted_stride_indices = {0, 2, 3, 1};
1336   auto zero = LongImm::make(0);
1337   std::vector<ExprPtr> default_strides = make_contiguous_strides(sizes);
1338   // See explanation in convertOutputToCorrectStrides
1339   return convertSymbolicOutputToCorrectStrides(
1340       sizes, sorted_stride_indices, strides, buf);
1341 }
1342 
1343 Tensor TensorExprKernel::convertStaticShapeOutputToCorrectStrides(
1344     torch::jit::Value* v) {
1345   const TensorTypePtr& tt = v->type()->expect<TensorType>();
1346   TORCH_INTERNAL_ASSERT(
1347       bufs_.count(v),
1348       buildErrorMessage(
1349           "Output tensor has no corresponding bufs in the fuser."));
1350   BufPtr buf = bufs_.at(v);
1351 
1352   // No shape info is present in the graph
1353   if (!tt->sizes().concrete_sizes()) {
1354     std::string msg =
1355         std::string("Shapes for output '%") + v->debugName() + "' are unknown";
1356     throw malformed_input(msg);
1357   }
1358 
1359   TORCH_INTERNAL_ASSERT(
1360       tt->sizes().concrete_sizes(),
1361       buildErrorMessage("Output shapes are unknown."));
1362   auto sizes = *tt->sizes().concrete_sizes();
1363   at::MemoryFormat memory_format =
1364       (memory_layout_policy_ == MemoryLayoutPolicy::kContiguous)
1365       ? c10::MemoryFormat::Contiguous
1366       : c10::MemoryFormat::ChannelsLast;
1367   std::vector<int64_t> default_strides =
1368       TensorType::contiguousStridesOf(sizes, memory_format);
1369   if (!tt->strides().concrete_sizes()) {
1370     return Tensor(buf, nullptr);
1371   }
1372   TORCH_INTERNAL_ASSERT(
1373       tt->strides().concrete_sizes(),
1374       buildErrorMessage("Output strides are unknown."));
1375   const std::vector<int64_t> strides = *tt->strides().concrete_sizes();
1376   // All Tensors in NNC are laid out in the default, contiguous layout.
1377   // If the output is also default contiguous we don't need to do anything
1378   if (strides == default_strides) {
1379     return Tensor(buf, nullptr);
1380   }
1381   // If the tensor is not dense or overlaps, we have
1382   // no way of matching the profiled striding
1383   if (!denseAndNonOverlapping(sizes, strides)) {
1384     return Tensor(buf, nullptr);
1385   }
1386 
1387   auto dims = sizesForValue(v);
1388   auto zero = LongImm::make(0);
1389   std::vector<size_t> sorted_stride_indices = reverse_sort_indices(strides);
1390 
1391   // TODO: call into `convertOutputToCorrectStrides`. Currently this
1392   // triggers a bug in IRSimplifier. See explanation in
1393   // `convertOutputToCorrectStrides`
1394   return Compute(
1395       "output_1", dims, [&](const std::vector<VarHandle>& axes_input) {
1396         std::vector<ExprHandle> axes(axes_input.begin(), axes_input.end());
1397         auto absolute_position = ExprHandle(immLike(axes[0], 0));
1398         for (size_t i = 0; i < axes.size(); ++i) {
1399           absolute_position = absolute_position +
1400               (ExprHandle(immLike(axes[i], default_strides[i])) * axes[i]);
1401         }
1402 
1403         std::vector<ExprHandle> new_axes(sorted_stride_indices.size());
1404         for (size_t stride_index : sorted_stride_indices) {
1405           auto size = sizes[stride_index];
1406           auto index = zero;
1407           if (size != 1) {
1408             auto stride = strides[stride_index];
1409             index = absolute_position /
1410                 ExprHandle(immLike(absolute_position, stride));
1411             absolute_position = absolute_position %
1412                 ExprHandle(immLike(absolute_position, stride));
1413           }
1414           new_axes[stride_index] = index;
1415         }
1416         return BufHandle(buf).load(new_axes);
1417       });
1418 }
1419 
1420 void TensorExprKernel::bindConstant(const torch::jit::Value* v) {
1421   auto val = toIValue(v).value();
1422   if (torch::isCustomClass(val)) {
1423     auto name_hint = "const_" + sanitizeName(v->debugName());
1424     auto dtype = Dtype(ScalarType::Float);
1425     std::vector<ExprPtr> dims;
1426     BufPtr buf = alloc<Buf>(name_hint, dims, dtype);
1427     auto dataPtr = val.toObjectRef().getSlot(0).toCapsule().get();
1428     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
1429     constants_.push_back({buf, dataPtr, const_cast<Node*>(v->node())});
1430     bufs_[v] = buf;
1431     return;
1432   }
1433   if (!v->type()->cast<TensorType>()) {
1434     // Only Tensor constants need to be bound, scalar constants will be turned
1435     // into immediates in TE IR
1436     return;
1437   }
1438   auto const_tensor = toIValue(v)->toTensor();
1439   auto scalar_type = c10::typeMetaToScalarType(const_tensor.options().dtype());
1440   auto sizes = const_tensor.sizes();
1441   std::vector<ExprHandle> te_sizes;
1442   te_sizes.reserve(sizes.size());
1443   for (auto s : sizes) {
1444     te_sizes.emplace_back(s);
1445   }
1446   BufPtr buf = alloc<Buf>(
1447       "const_" + sanitizeName(v->debugName()),
1448       ExprHandleVectorToExprVector(te_sizes),
1449       ToDtype(scalar_type));
1450 
1451   if (!const_tensor.is_contiguous()) {
1452     const_tensor = const_tensor.clone().contiguous();
1453     unpacked_constant_tensors_.push_back(const_tensor);
1454   }
1455 
1456   constants_.push_back({buf, const_tensor.data_ptr()});
1457   bufs_[v] = buf;
1458 }
1459 
1460 std::vector<BufPtr> TensorExprKernel::preAllocIntermediateBufs(
1461     const std::vector<BufPtr>& interm_bufs) {
1462   std::vector<BufPtr> remaining_interm_bufs;
1463   for (const auto& buf : interm_bufs) {
1464     // Check whether the buf shape is static and, if so, compute its size in bytes.
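    // (Illustrative: a Float buf with dims {8, 16} needs
    // 4 bytes * 1 lane * 8 * 16 = 512 bytes.)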
1465     bool is_static = true;
1466     size_t size =
1467         elementSize(buf->dtype().scalar_type()) * buf->dtype().lanes();
1468     for (auto& d : buf->dims()) {
1469       if (!d->isConstant()) {
1470         is_static = false;
1471         break;
1472       }
1473       size = size * (*intValue(d));
1474     }
1475     // Only allocate memory for static bufs.
1476     if (!is_static) {
1477       remaining_interm_bufs.push_back(buf);
1478       continue;
1479     }
1480     auto bp = (void*)malloc(size);
1481     if (!bp) {
1482       remaining_interm_bufs.push_back(buf);
1483       continue;
1484     }
1485     constants_.push_back({buf, bp});
1486   }
1487   return remaining_interm_bufs;
1488 }
1489 
1490 BlockPtr TensorExprKernel::bindAllInputs() {
1491   std::vector<CodeGen::BufferArg> symbolic_shape_args;
1492   std::vector<CodeGen::BufferArg> symbolic_stride_args;
1493 
1494   auto symbolic_shape_inputs_start_pos =
1495       nInputs_ - symbolic_shape_inputs_.size();
1496   if (has_symbolic_shapes_) {
1497     // The graph is supposed to have input params that represent the symbolic
1498     // dims at the end of the list of inputs. The number of such symbolic input
1499     // params is defined by the size of the `symbolic_shape_inputs_` vector.
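    // For example (illustrative names): a fused subgraph with inputs
    //   (%x : Tensor, %y : Tensor, %SS_2 : int, %SS_3 : int)
    // and two entries in `symbolic_shape_inputs_` has its symbolic dim params
    // at input positions 2 and 3, so symbolic_shape_inputs_start_pos is 2.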
1500     //
1501     // TODO: Check if the tensors with symbolic shapes are contiguous.
1502     TORCH_CHECK(
1503         nInputs_ > static_cast<int64_t>(symbolic_shape_inputs_.size()),
1504         "Symbolic dims not provided as inputs to the graph");
1505 
1506     // First, process the symbolic input params and create a new variable for
1507     // each of them.
1508     // NOTE: This has to be done before processing the tensor inputs, because
1509     // their symbolic sizes need to be associated with the variables we
1510     // create for the symbolic input params.
1511     symbolic_shape_args.reserve(symbolic_shape_inputs_.size());
1512 
1513     for (size_t i = symbolic_shape_inputs_start_pos;
1514          i < static_cast<size_t>(nInputs_);
1515          ++i) {
1516       auto input = graph_->inputs()[i];
1517       if (input->type()->kind() != TypeKind::IntType) {
1518         throw std::runtime_error(
1519             "Expected integer type input to graph for symbolic dims.");
1520       }
1521       VarHandle v("v" + input_name_map_[input], kLong);
1522       symbolic_shape_args.emplace_back(v);
1523       scalars_.emplace(input, v);
1524       shapeSymbolInputPos_[scalars_[input].node()] = i;
1525     }
1526     // For every shape symbol, store a map to the corresponding var.
1527     for (size_t i = 0; i < symbolic_shape_inputs_.size(); ++i) {
1528       shapeSymbolToVar_[symbolic_shape_inputs_[i]] =
1529           scalars_[graph_->inputs()[symbolic_shape_inputs_start_pos + i]];
1530     }
1531 
1532     // Next, process the tensor inputs and create an argument for each stride marked as S_AS_ARG.
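    // For example (illustrative): if the stride descriptor of tensor input 0
    // marks dim 1 as S_AS_ARG, we create a kLong var for it, record it in
    // strideArgToVar_[{0, 1}], and append (0, 1) to input_stride_args_ so the
    // actual stride value can be supplied at run time.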
1533     for (size_t i = 0; i < symbolic_shape_inputs_start_pos; ++i) {
1534       auto input = graph_->inputs()[i];
1535       auto tt = input->type()->cast<TensorType>();
1536       if (!tt) {
1537         continue;
1538       }
1539       auto symbolic_stride = getSymbolicStrideDesc(input);
1540       for (size_t j = 0; j < symbolic_stride.size(); ++j) {
1541         if (symbolic_stride[j] == torch::jit::StrideInput::S_AS_ARG) {
1542           VarHandle v("v" + input_name_map_[input], kLong);
1543           symbolic_stride_args.emplace_back(v);
1544           strideArgToVar_[{i, j}] = v;
1545           input_stride_args_.emplace_back(i, j);
1546         }
1547       }
1548     }
1549   }
1550 
1551   // Block to collect the Stmts corresponding to all tensors.
1552   auto block = alloc<Block>(std::vector<StmtPtr>({}));
1553 
1554   // Process the inputs before the symbolic input params.
1555   for (const auto i : c10::irange(symbolic_shape_inputs_start_pos)) {
1556     auto input = graph_->inputs()[i];
1557     Tensor t = bindInput(input);
1558     if (t.stmt()) {
1559       block->append_stmt(t.stmt());
1560     }
1561   }
1562   // Now, add all the variables corresponding to the symbolic input params.
1563   bufferArgs_.insert(
1564       bufferArgs_.end(),
1565       symbolic_shape_args.begin(),
1566       symbolic_shape_args.end());
1567 
1568   // Now, add all the variables corresponding to the symbolic stride inputs.
1569   bufferArgs_.insert(
1570       bufferArgs_.end(),
1571       symbolic_stride_args.begin(),
1572       symbolic_stride_args.end());
1573 
1574   return block;
1575 }
1576 
1577 void TensorExprKernel::deduceMemoryLayoutPolicy() {
1578   // If the tensor is channels-last contiguous, the preferred memory layout
1579   // propagation policy is to use channels-last. Otherwise, the preferred policy
1580   // is to use contiguous.
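  // Illustrative example: a tensor with sizes [2, 3, 4, 4] and strides
  // [48, 1, 12, 3] is channels-last contiguous and votes for
  // kChannelsLastNdContiguous, whereas the same sizes with contiguous strides
  // [48, 16, 4, 1] vote for kContiguous.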
1581   auto _prefer_symbolic_mem =
1582       [](const torch::jit::Value* val,
1583          const std::vector<torch::jit::StrideInput>& stride_desc_vec) {
1584         TORCH_INTERNAL_ASSERT(!stride_desc_vec.empty());
1585         // Has symbolic stride information
1586         auto cur_stride_desc = stride_desc_vec[0];
1587         return (cur_stride_desc ==
1588                 torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST)
1589             ? MemoryLayoutPolicy::kChannelsLastNdContiguous
1590             : MemoryLayoutPolicy::kContiguous;
1591       };
1592 
1593   auto _prefer_static_mem = [](const torch::jit::Value* val) {
1594     // Complete (static) shape and stride info is present in the graph.
1595     TORCH_INTERNAL_ASSERT(
1596         val->isCompleteTensor(),
1597         buildErrorMessage(val->debugName() + " is not a complete tensor."));
1598     const auto& tt = val->type()->expect<TensorType>();
1599     const auto sizes = *tt->sizes().concrete_sizes();
1600     const auto strides = *tt->strides().concrete_sizes();
1601     return (c10::is_channels_last_strides_2d(sizes, strides))
1602         ? MemoryLayoutPolicy::kChannelsLastNdContiguous
1603         : MemoryLayoutPolicy::kContiguous;
1604   };
1605 
1606   // Collect the tensors among the graph inputs and outputs to
1607   // deduce the memory layout propagation policy.
1608   auto _is_tensor = [](const jit::Value* el) {
1609     return el->type()->kind() == TypeKind::TensorType;
1610   };
1611   std::vector<torch::jit::Value*> graph_io_tensors;
1612   std::copy_if(
1613       graph_->inputs().begin(),
1614       graph_->inputs().end(),
1615       std::back_inserter(graph_io_tensors),
1616       _is_tensor);
1617   std::copy_if(
1618       graph_->outputs().begin(),
1619       graph_->outputs().end(),
1620       std::back_inserter(graph_io_tensors),
1621       _is_tensor);
1622   // An empty range would trivially pass the "all channels-last" check (as
1623   // std::all_of would on an empty range), but in that case we prefer to keep
1624   // the default memory layout policy, so we require the range to be non-empty.
1625   auto prefer_channels_last = (!graph_io_tensors.empty());
1626   for (auto el : graph_io_tensors) {
1627     auto is_complete = el->isCompleteTensor();
1628     auto is_symbolic = symbolic_strides_.count(el);
1629 
1630     auto preferred_mem_layout = is_complete
1631         ? _prefer_static_mem(el)
1632         : (is_symbolic ? _prefer_symbolic_mem(el, symbolic_strides_[el])
1633                        : MemoryLayoutPolicy::kContiguous);
1634     if (preferred_mem_layout != MemoryLayoutPolicy::kChannelsLastNdContiguous) {
1635       prefer_channels_last = false;
1636       break;
1637     }
1638   }
1639 
1640   // If the memory layout of all the input and outputs is channels-last
1641   // If the memory layout of all the inputs and outputs is channels-last
1642   // contiguous, the propagated memory layout should be channels-last.
1643   // Otherwise, the propagated memory layout is contiguous, which is the
1644   // same as the current behavior.
1645       ? MemoryLayoutPolicy::kChannelsLastNdContiguous
1646       : MemoryLayoutPolicy::kContiguous;
1647 }
1648 
1649 void TensorExprKernel::optimizeOwningGraph() {
1650   GRAPH_DUMP("TensorExprKernel graph (Before graph optimization):", graph_);
1651 
1652   // Graph optimization may replace the output Values, so we store the
1653   // original outputs to synchronize the symbolic stride information afterwards.
1654   auto _original_graph_outputs = graph_->outputs().vec();
1655 
1656   // Get the graph device information first. The graph optimization
1657   // might be device specific.
1658   device_ = *pickDeviceType(graph_);
1659 
1660   // Determine the propagated memory layout
1661   deduceMemoryLayoutPolicy();
1662 
1663   // Fuse Conv with Eltwise Op
1664   graph_rewrite_helper::replaceConvolutionWithAtenConv(graph_);
1665   FuseConvWithEltwise(graph_);
1666 
1667   // Optimize the concatenation
1668   OptimizeCat(graph_);
1669 
1670   // Synchronize the symbolic strides information
1671   auto graph_outputs = graph_->outputs();
1672   TORCH_INTERNAL_ASSERT(graph_outputs.size() == _original_graph_outputs.size());
1673   for (auto i : c10::irange(graph_outputs.size())) {
1674     auto el_orig = _original_graph_outputs.at(i);
1675     auto el_new = graph_outputs.at(i);
1676     if (symbolic_strides_.count(el_orig) && (el_orig != el_new)) {
1677       symbolic_strides_[el_new] = symbolic_strides_[el_orig];
1678       symbolic_strides_.erase(el_orig);
1679     }
1680   }
1681 
1682   GRAPH_DUMP("TensorExprKernel graph (After graph optimization):", graph_);
1683 }
1684 
1685 void TensorExprKernel::compile() {
1686   GRAPH_DUMP("TensorExprKernel graph:", graph_);
1687 
1688   has_symbolic_shapes_ = !symbolic_shape_inputs_.empty();
1689   nInputs_ = graph_->inputs().size();
1690   nOutputs_ = graph_->outputs().size();
1691   genInputDebugNames();
1692 
1693   // Bind inputs to buffers.
1694   auto block = bindAllInputs();
1695 
1696   // Bind nodes to tensor compute expressions.
1697   for (auto const& n : graph_->nodes()) {
1698     if (n->kind() == prim::ListConstruct) {
1699       continue;
1700     } else if (n->kind() == prim::Constant) {
1701       bindConstant(n->output());
1702       continue;
1703     } else {
1704       for (auto const& output : n->outputs()) {
1705         if (output->hasUses()) {
1706           Tensor t = computeValue(output);
1707 
1708           // If there are for-loops before an ExternalCall as follows,
1709           //   stmt1: for:
1710           //   stmt2:   for:
1711           //   stmt3: ExternalCall
1712           // the for-loops would not be parallelized. So we mark the
1713           // buf args of the ExternalCall as to-be-parallelized to make
1714           // sure the preceding loops can still be parallelized.
1715           if (to<ExternalCall>(t.stmt())) {
1716             auto _external_call = to<ExternalCall>(t.stmt());
1717             for (const auto& _buf : _external_call->buf_args()) {
1718               bufsToBeParallelized_.insert(_buf);
1719             }
1720           }
1721 
1722           if (output->type()->cast<TensorType>()) {
1723             // Value is tensor
1724             if (t.buf()) {
1725               bufs_.emplace(output, t.buf());
1726             }
1727             block->append_stmt(t.stmt());
1728           } else {
1729             // Value is scalar
1730             //
1731             // We represent scalar computations in TE with a pair of statements:
1732             //   Let val = <compute_expression>
1733             //   Store(buf_for_scalar[0], val)
1734             //
1735             // Subsequent computations will use val when they refer to the
1736             // given value, and the buffer will be used if we need to return
1737             // the computed value as an output of the kernel. If this is not an
1738             // output, the store will be removed later by DCE.
1739             //
1740             // NB: NNC's lowering functions return Tensor, which is a pair
1741             // <Buf, Stmt>, but here we also need Var. How can we obtain all of
1742             // Var, Buf, and Stmt?
1743             // We use the following trick: the lowering function creates the
1744             // Let-stmt and a "fake" buffer, whose only purpose is to hold the
1745             // Var. Then outside the lowering function (namely, right here) we
1746             // generate the store and the actual buffer.
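            // For example (illustrative names): for a scalar value %z
            // computed as a + b, the lowering emits `Let v_z = a + b` plus a
            // fake buf whose base handle is v_z; below we create the real buf
            // `scalar_z` and append the store `scalar_z[] = v_z`.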
1747             VarPtr v = t.buf()->base_handle();
1748             scalars_[output] = VarHandle(v);
1749             block->append_stmt(t.stmt());
1750             std::vector<ExprPtr> dims;
1751             BufHandle buf(
1752                 "scalar_" + sanitizeName(output->debugName()), {}, v->dtype());
1753             StmtPtr store = Store::make(buf, {}, ExprHandle(v));
1754             block->append_stmt(store);
1755             bufs_.emplace(output, buf.node());
1756           }
1757         }
1758       }
1759     }
1760     if (hasRandom_ && hasBroadcast_) {
1761       throw std::runtime_error(
1762           "Cannot support broadcast and random within one kernel");
1763     }
1764   }
1765 
1766   // Move output operands from `bufs_` to `bufOutputs_`
1767   for (auto i : c10::irange(graph_->outputs().size())) {
1768     auto& output = graph_->outputs().at(i);
1769     if (!bufs_.count(output)) {
1770       throw malformed_input("cannot find output Tensor");
1771     }
1772     if (!output->type()->cast<TensorType>()) {
1773       // Scalar outputs are represented as 0-dim buffers.
1774       bufOutputs_.insert(bufs_.at(output));
1775       bufsToBeParallelized_.insert(bufs_.at(output));
1776       bufferArgs_.emplace_back(BufHandle(bufs_.at(output)));
1777       tensorOutputTensorOptions_.emplace_back(
1778           c10::TensorOptions(tensorType(bufs_.at(output))).device(device_));
1779       tensorOutputSizes_.emplace_back();
1780       tensorOutputStrides_.emplace_back();
1781       isOutputScalar_.push_back(true);
1782       bufs_.erase(output);
1783       continue;
1784     }
1785 
1786     const auto& tt = output->type()->expect<TensorType>();
1787     if (has_symbolic_shapes_) {
1788       auto sizes = sizesFromSymbolicShape(tt->symbolic_sizes());
1789       tensorOutputSymbolicSizes_.push_back(sizes);
1790       TORCH_INTERNAL_ASSERT(symbolic_strides_.count(output));
1791       auto stride_desc_vec = symbolic_strides_[output];
1792       TORCH_INTERNAL_ASSERT(stride_desc_vec.size() == 1);
1793       auto stride_desc = stride_desc_vec[0];
1794       tensorOutputStrideDesc_.push_back(stride_desc);
1795       Tensor properly_strided_output =
1796           convertSymbolicOutputToCorrectStrides(output);
1797       if (properly_strided_output.stmt()) {
1798         block->append_stmt(properly_strided_output.stmt());
1799       }
1800       bufs_[output] = properly_strided_output.buf();
1801     } else {
1802       // The "strided" tensor will be incorrect if used in NNC,
1803       // since NNC views it as contiguous. Only convert it to the right
1804       // strides at the end of the kernel (if it is already contiguous, this is a no-op).
1805       Tensor properly_strided_output =
1806           convertStaticShapeOutputToCorrectStrides(output);
1807       if (properly_strided_output.stmt()) {
1808         block->append_stmt(properly_strided_output.stmt());
1809       }
1810       bufs_[output] = properly_strided_output.buf();
1811       auto sizes = *tt->sizes().concrete_sizes();
1812       tensorOutputSizes_.push_back(sizes);
1813       auto strides = tt->strides().concrete_sizes();
1814 
1815       // If the tensor is not dense or overlaps, we have
1816       // no way of matching the profiled striding
1817       if (strides && denseAndNonOverlapping(sizes, *strides)) {
1818         tensorOutputStrides_.push_back(*strides);
1819       } else {
1820         tensorOutputStrides_.push_back(TensorType::contiguousStridesOf(sizes));
1821       }
1822     }
1823 
1824     bufOutputs_.insert(bufs_.at(output));
1825     bufsToBeParallelized_.insert(bufs_.at(output));
1826     bufferArgs_.emplace_back(BufHandle(bufs_.at(output)));
1827     tensorOutputTensorOptions_.emplace_back(
1828         c10::TensorOptions(tensorType(bufs_.at(output))).device(device_));
1829     isOutputScalar_.push_back(false);
1830     bufs_.erase(output);
1831   }
1832 
1833   BackendType backendType = inferBackendTypeFromDevice(device_);
1834   stmt_ = transformLoops(backendType, block);
1835 
1836   for (const auto& c : constants_) {
1837     bufferArgs_.emplace_back(BufHandle(c.buf));
1838   }
1839 
1840   if (has_symbolic_shapes_) {
1841     tensorOutputSizes_.resize(bufOutputs_.size());
1842     tensorOutputStrides_.resize(bufOutputs_.size());
1843   }
1844 
1845   // Generate code.
1846   codegen_ = CreateCodeGen(
1847       getCodeGenName(backendType),
1848       stmt_,
1849       bufferArgs_,
1850       device_,
1851       kernel_func_name_);
1852 }
1853 
1854 void TensorExprKernel::recompile() {
1855   codegen_ = CreateCodeGen(
1856       "llvm_codegen", stmt_, bufferArgs_, device_, kernel_func_name_);
1857 }
1858 
1859 TensorExprKernel::TensorExprKernel(
1860     const std::shared_ptr<Graph>& subgraph,
1861     std::string kernel_func_name,
1862     std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings,
1863     std::vector<int64_t> symbolic_shape_inputs,
1864     bool pre_alloc /*= false*/,
1865     std::unordered_map<
1866         const torch::jit::Value*,
1867         std::vector<torch::jit::StrideInput>> symbolic_strides)
1868     : graph_(subgraph),
1869       code_(subgraph, ""),
1870       symbolic_shape_inputs_(std::move(symbolic_shape_inputs)),
1871       custom_lowerings_(std::move(custom_lowerings)),
1872       pre_alloc_(pre_alloc),
1873       kernel_func_name_(std::move(kernel_func_name)),
1874       symbolic_strides_(std::move(symbolic_strides)) {
1875   optimizeOwningGraph();
1876   // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
1877   allow_fallback_ = fallbackAllowed();
1878 
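  // Fallback policy: if fallback is not allowed, compile eagerly and let any
  // failure propagate; if fallback is enforced, skip compilation entirely and
  // always run the interpreter; otherwise try to compile and fall back to the
  // interpreter if compilation throws.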
1879   if (!allow_fallback_) {
1880     compile();
1881     return;
1882   }
1883 
1884   use_fallback_ = fallbackEnforced();
1885   if (use_fallback_) {
1886     return;
1887   }
1888 
1889   try {
1890     compile();
1891   } catch (...) {
1892     use_fallback_ = true;
1893   }
1894 }
1895 
1896 void TensorExprKernel::run(Stack& stack) const {
1897   if (!use_fallback_ && !allow_fallback_) {
1898     runKernel(stack);
1899   } else if (!use_fallback_ && allow_fallback_) {
1900     try {
1901       runKernel(stack);
1902     } catch (...) {
1903       fallback(stack);
1904     }
1905   } else {
1906     fallback(stack);
1907   }
1908 }
1909 
1910 void TensorExprKernel::getStaticOutputSizesAndStrides(
1911     const at::ArrayRef<IValue>& inputs,
1912     std::vector<std::vector<int64_t>>* sizes,
1913     std::vector<std::vector<int64_t>>* strides) const {
1914   TORCH_INTERNAL_ASSERT(has_symbolic_shapes_);
1915   // If there are symbolic shapes, the output tensor sizes could not have
1916   // been computed at compile time. They have to be computed here using the
1917   // symbolic shape input params passed to this call.
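  // Illustrative example: for symbolic output sizes [SS_2, 3], where the
  // input param bound to SS_2 holds the value 5, the static sizes become
  // [5, 3]; a TENSOR_CONT stride descriptor then yields strides [3, 1], and
  // TENSOR_CONT_CHANNELS_LAST would use at::get_channels_last_strides_2d.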
1918   TORCH_INTERNAL_ASSERT(
1919       tensorOutputSymbolicSizes_.size() == bufOutputs_.size());
1920 
1921   TORCH_INTERNAL_ASSERT(sizes);
1922   TORCH_INTERNAL_ASSERT(strides);
1923   *sizes = tensorOutputSizes_;
1924   *strides = tensorOutputStrides_;
1925   auto& static_sizes = *sizes;
1926   auto& static_strides = *strides;
1927   for (size_t i = 0, e = bufOutputs_.size(); i < e; ++i) {
1928     static_sizes[i].clear();
1929     for (auto t : tensorOutputSymbolicSizes_[i]) {
1930       if (t.AsNode<LongImm>()) {
1931         static_sizes[i].emplace_back(immediateAs<int64_t>(t.node()));
1932       } else {
1933         auto input_pos = shapeSymbolInputPos_.at(t.node());
1934         TORCH_INTERNAL_ASSERT(input_pos < inputs.size());
1935         TORCH_INTERNAL_ASSERT(inputs[input_pos].isInt());
1936         static_sizes[i].emplace_back(inputs[input_pos].toInt());
1937       }
1938     }
1939 
1940     if (tensorOutputStrideDesc_[i] == torch::jit::StrideInput::TENSOR_CONT) {
1941       static_strides[i] = TensorType::contiguousStridesOf(static_sizes[i]);
1942 
1943     } else if (
1944         tensorOutputStrideDesc_[i] ==
1945         torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST) {
1946       static_strides[i] = at::get_channels_last_strides_2d(static_sizes[i]);
1947 
1948     } else {
1949       std::string output_desc = toString(tensorOutputStrideDesc_[i]);
1950       TORCH_INTERNAL_ASSERT(
1951           false, "Expected contiguous or channels last, got ", output_desc);
1952     }
1953   }
1954 }
1955 
1956 std::vector<CodeGen::CallArg> TensorExprKernel::prepareRunArgs(
1957     const at::ArrayRef<IValue>& inputs,
1958     std::vector<at::Tensor>& outputs) const {
1959   // TODO: preallocate `runArgs` during compilation and fill in values where
1960   // possible (e.g. for constant tensors)
1961   std::vector<CodeGen::CallArg> runArgs;
1962   runArgs.reserve(
1963       inputs.size() + input_stride_args_.size() + bufOutputs_.size());
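  // Note: the order of the args assembled below is expected to mirror
  // bufferArgs_ as built in compile(): graph inputs (including the trailing
  // symbolic dim ints), then symbolic stride args, then outputs, then
  // constants.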
1964 
1965   for (auto& input : inputs) {
1966     if (input.isInt()) {
1967       runArgs.emplace_back(input.toInt());
1968     } else if (input.isBool()) {
1969       runArgs.emplace_back(input.toBool());
1970     } else if (input.isDouble()) {
1971       runArgs.emplace_back(input.toDouble());
1972     } else if (input.isTensor()) {
1973       runArgs.emplace_back(input.toTensor().data_ptr());
1974     }
1975   }
1976 
1977   if (has_symbolic_shapes_) {
1978     std::vector<std::vector<int64_t>> static_sizes;
1979     std::vector<std::vector<int64_t>> static_strides;
1980     getStaticOutputSizesAndStrides(inputs, &static_sizes, &static_strides);
1981 
1982     // add stride args
1983     for (const auto& input_stride_arg : input_stride_args_) {
1984       runArgs.emplace_back(
1985           inputs[input_stride_arg.first].toTensor().strides().at(
1986               input_stride_arg.second));
1987     }
1988 
1989     for (size_t i = 0, e = bufOutputs_.size(); i < e; ++i) {
1990       auto const& opts = tensorOutputTensorOptions_[i];
1991       outputs.emplace_back(codegen_->empty_strided(
1992           static_sizes[i],
1993           static_strides[i],
1994           opts.dtype,
1995           opts.layout,
1996           opts.device,
1997           opts.pinned_memory));
1998       runArgs.emplace_back(outputs.back().data_ptr());
1999     }
2000   } else {
2001     for (size_t i = 0, e = bufOutputs_.size(); i < e; ++i) {
2002       auto const& opts = tensorOutputTensorOptions_[i];
2003       outputs.emplace_back(codegen_->empty_strided(
2004           tensorOutputSizes_[i],
2005           tensorOutputStrides_[i],
2006           opts.dtype,
2007           opts.layout,
2008           opts.device,
2009           opts.pinned_memory));
2010       runArgs.emplace_back(outputs.back().data_ptr());
2011     }
2012   }
2013 
2014   for (const auto& c : constants_) {
2015     runArgs.emplace_back(c.ptr);
2016   }
2017 
2018   return runArgs;
2019 }
2020 
2021 StmtPtr TensorExprKernel::getCodeGenStmt() {
2022   return codegen_->stmt();
2023 }
2024 
2025 void TensorExprKernel::runKernel(Stack& stack) const {
2026   // Set up arguments (inputs, then outputs) for kernel call.
2027   auto inputs = last(stack, nInputs_);
2028   std::vector<at::Tensor> outputs;
2029 
2030   std::vector<CodeGen::CallArg> runArgs = prepareRunArgs(inputs, outputs);
2031 
2032   // Call the kernel.
2033   codegen_->call(runArgs);
2034 
2035   // Update the stack.
2036   drop(stack, nInputs_);
2037 
2038   int64_t idx = 0;
2039   for (auto& o : outputs) {
2040     if (isOutputScalar_[idx++]) {
2041       // Scalar outputs are returned as 0-dim tensors; we need to extract the
2042       // scalar value from them.
2043       push_one(stack, o.item());
2044     } else {
2045       push_one(stack, std::move(o));
2046     }
2047   }
2048 }
2049 
2050 void TensorExprKernel::runFast(
2051     const std::vector<void*>& inputs,
2052     const std::vector<void*>& outputs) const {
2053   std::vector<void*> args(inputs);
2054   args.reserve(inputs.size() + outputs.size() + constants_.size());
2055   args.insert(args.end(), outputs.begin(), outputs.end());
2056 
2057   // TODO: we can consider preallocating and pre-filling the args vector.
2058   for (const auto& c : constants_) {
2059     args.push_back(c.ptr);
2060   }
2061 
2062   // Call the kernel.
2063   codegen_->call_raw(args);
2064 }
2065 
2066 void TensorExprKernel::runWithAllocatedOutputs(Stack& stack) const {
2067   TORCH_INTERNAL_ASSERT(
2068       device_ == at::kCPU,
2069       "Pre-allocated output tensors are supported only on CPUs.");
2070   std::vector<void*> args;
2071   args.reserve(nInputs_ + nOutputs_ + constants_.size());
2072 
2073   // The stack has the inputs on top and the outputs right below them.
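  // For example (illustrative): with nOutputs_ = 1 and nInputs_ = 2, the top
  // of the stack is [out0, in0, in1]; slice(0, 1) yields the outputs and
  // slice(1) yields the inputs.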
2074   auto stack_ivals = last(stack, nOutputs_ + nInputs_);
2075   auto stack_outputs = stack_ivals.slice(0, nOutputs_);
2076   auto stack_inputs = stack_ivals.slice(nOutputs_);
2077 
2078   std::vector<int64_t> int_inputs(nInputs_);
2079   for (auto i : c10::irange(nInputs_)) {
2080     auto inp = stack_inputs[i];
2081     if (inp.isInt()) {
2082       int_inputs[i] = inp.toInt();
2083       args.emplace_back(&int_inputs[i]);
2084     } else if (inp.isTensor()) {
2085       args.emplace_back(inp.toTensor().data_ptr());
2086     } else {
2087       TORCH_INTERNAL_ASSERT(
2088           false, "Unhandled input type while calling TensorExprKernel");
2089     }
2090   }
2091 
2092   std::vector<int64_t> stride_values(input_stride_args_.size());
2093   if (has_symbolic_shapes_) {
2094     std::vector<std::vector<int64_t>> static_sizes;
2095     std::vector<std::vector<int64_t>> static_strides;
2096     getStaticOutputSizesAndStrides(
2097         stack_inputs, &static_sizes, &static_strides);
2098 
2099     // add stride args
2100     for (auto idx : c10::irange(input_stride_args_.size())) {
2101       const auto& input_stride_arg = input_stride_args_[idx];
2102       stride_values[idx] =
2103           stack_inputs[input_stride_arg.first].toTensor().strides().at(
2104               input_stride_arg.second);
2105       args.emplace_back(&stride_values[idx]);
2106     }
2107 
2108     TORCH_INTERNAL_ASSERT(
2109         nOutputs_ == static_cast<int64_t>(bufOutputs_.size()));
2110     for (size_t i = 0, e = bufOutputs_.size(); i < e; ++i) {
2111       auto& out = stack_outputs[i].toTensor();
2112       // This has only been tested on CPUs.
2113       // TODO: Test on GPUs.
2114       out.resize_(static_sizes[i]);
2115       args.emplace_back(out.data_ptr());
2116     }
2117   } else {
2118     for (auto i : c10::irange(nOutputs_)) {
2119       args.emplace_back(stack_outputs[i].toTensor().data_ptr());
2120     }
2121   }
2122 
2123   for (const auto& c : constants_) {
2124     args.emplace_back(c.ptr);
2125   }
2126 
2127   // Call the kernel.
2128   codegen_->call_raw(args);
2129 
2130   // Remove the inputs from the stack. The outputs are already below the inputs
2131   // in the stack.
2132   drop(stack, nInputs_);
2133 }
2134