#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/tensorexpr/operators/misc.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

namespace torch::jit::tensorexpr {

int64_t normalizeAndCheckIndex(int64_t idx, int64_t list_size) {
  if (idx < 0) {
    // Handle negative indexing
    idx = list_size + idx;
  }

  if (idx < 0 || idx >= list_size) {
    AT_ERROR("Invalid index ", idx, " for list_size ", list_size);
  }
  return idx;
}
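// For illustration: normalizeAndCheckIndex(-1, 5) returns 4, while
// normalizeAndCheckIndex(5, 5) throws, since only indices -5..4 are valid.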

// Convert boolean to integer, if needed.
ExprHandle boolToInteger(const ExprHandle& x) {
  return x.dtype().scalar_type() == ScalarType::Bool ? cast<int>(x) : x;
}

ExprHandle promoteToDtype(ExprHandle e, ScalarType dt) {
  if (e.dtype().scalar_type() == dt) {
    return e;
  }

  switch (dt) {
#define TYPE_CASE(Type, Name) \
  case ScalarType::Name:      \
    e = cast<Type>(e);        \
    break;
    AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
    case ScalarType::QUInt8:
      e = cast<c10::quint8>(e);
      break;
    case ScalarType::QInt8:
      e = cast<c10::qint8>(e);
      break;
    default:
      throw unsupported_dtype();
  }
  return e;
}
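// For illustration: promoteToDtype(e, ScalarType::Float) is a no-op when e
// is already Float and otherwise wraps e in a Cast node; quantized types go
// through the explicit QUInt8/QInt8 cases, and anything else throws.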

static bool checkTypes(const ScalarType highType, const int typeConstraints) {
  if (typeConstraints == kAllTypes) {
    return true;
  }

  if (c10::isIntegralType(highType, false)) {
    return (typeConstraints & kIntegralTypes) != 0;
  } else if (c10::isFloatingType(highType)) {
    return (typeConstraints & kFloatingPointTypes) != 0;
  } else if (highType == ScalarType::Bool) {
    return (typeConstraints & kBoolType) != 0;
  }

  // The JIT fuser does not support complex or quantized types yet.
  TORCH_INTERNAL_ASSERT(
      (typeConstraints & (kQintTypes | kComplexTypes)) == 0,
      buildErrorMessage(
          "Qint and Complex types are not supported in the fuser."));
  return false;
}

static bool isScalar(const ExprHandle& e) {
  auto n = e.node();
  return n->isConstant() || to<Var>(n);
}

ExprHandle promoteHalfToFloat(const ExprHandle& e) {
  auto scalarType = static_cast<c10::ScalarType>(e.dtype().scalar_type());
  auto floatType = static_cast<c10::ScalarType>(tensorexpr::ScalarType::Float);
  if (c10::isFloatingType(scalarType) &&
      (c10::elementSize(scalarType) < c10::elementSize(floatType))) {
    return Cast::make(
        Dtype(tensorexpr::ScalarType::Float, e.dtype().lanes()), e);
  } else {
    return e;
  }
}
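// For illustration: Half and BFloat16 (element size 2 < 4) are widened to
// Float here; Float, Double, and all integral types pass through unchanged.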

void promoteInputs(std::vector<ExprHandle>& inputs, const int typeConstraints) {
  if (inputs.empty()) {
    return;
  }

  // Find the highest type among the inputs.
  ScalarType highType = inputs[0].dtype().scalar_type();
  for (const auto& input : inputs) {
    auto inputType = input.dtype().scalar_type();
    if (isScalar(input)) {
      if (isIntegralType(highType, false) && isFloatingType(inputType)) {
        highType = c10::get_default_dtype_as_scalartype();
      } else if (highType == c10::kBool) {
        highType = inputType;
      }
    } else {
      highType = promoteTypes(highType, inputType);
    }
  }

  if (!checkTypes(highType, typeConstraints)) {
    throw unsupported_dtype();
  }

  for (ExprHandle& e : inputs) {
    e = promoteToDtype(e, highType);
  }
}
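// For illustration (a hypothetical mix of inputs): {int tensor, 2.0 scalar}
// promotes to the default dtype (typically Float), while {int tensor, long
// tensor} follows the usual promotion lattice and yields Long.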

ExprHandle promoteIntegerToDefaultType(const ExprHandle& e) {
  auto scalarType = static_cast<c10::ScalarType>(e.dtype().scalar_type());
  if (!c10::isIntegralType(scalarType, /*includeBool*/ true)) {
    return e;
  }

  auto defaultType = c10::typeMetaToScalarType(c10::get_default_dtype());

  // We intend to promote integers to floating-point types, so the default
  // dtype must not itself be integral.
  TORCH_INTERNAL_ASSERT(
      !c10::isIntegralType(defaultType, /*includeBool*/ true));

  return Cast::make(
      Dtype(
          static_cast<tensorexpr::ScalarType>(defaultType), e.dtype().lanes()),
      e);
}
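// For illustration: with the default dtype set to Float, an int or bool
// expression is cast to Float here; floating-point expressions are returned
// unchanged.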

ExprHandle demoteOutput(
    const ExprHandle& e,
    const std::optional<ScalarType> type) {
  if (!type.has_value()) {
    return e;
  }
  if (*type == e.dtype().scalar_type()) {
    return e;
  }

  switch (*type) {
#define TYPE_CASE(Type, Name) \
  case ScalarType::Name:      \
    return cast<Type>(e);
    AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
    case ScalarType::Bool:
      return cast<bool>(e);
    default:
      throw unsupported_dtype();
  }

  return e;
}

std::optional<TensorInfo> getTensorInfo(const BufHandle& b) {
  std::vector<int64_t> dims;
  auto b_dims = b.dims();
  dims.reserve(b_dims.size());
  for (auto dim : b_dims) {
    auto val = intValue(dim.node());
    if (!val) {
      return std::nullopt;
    }
    dims.push_back(*val);
  }
  return TensorInfo{dims, static_cast<at::ScalarType>(b.dtype().scalar_type())};
}
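// Note that this returns std::nullopt as soon as any dimension is not a
// compile-time constant, e.g. for symbolic shapes.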

ExprHandle clamp(
    const ExprHandle& cmin,
    const ExprHandle& cmax,
    const ExprHandle& input) {
  auto mm = CompareSelect::make(input, cmin, cmin, input, kLT);
  return CompareSelect::make(mm, cmax, cmax, mm, kGT);
}
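// For illustration: clamp(0, 6, x) lowers to two CompareSelects,
//   t = (x < 0) ? 0 : x;  result = (t > 6) ? 6 : t;
// i.e. a ReLU6-style clamp.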

static bool isOne(const ExprHandle& e) {
  auto const& n = intValue(e);
  if (!n) {
    return false;
  }
  return *n == 1;
}

static std::pair<std::vector<ExprHandle>, bool> broadcastShapesImpl(
    const std::vector<ExprHandle>& a,
    const std::vector<ExprHandle>& b) {
  auto at = a.rbegin();
  auto bt = b.rbegin();
  std::vector<ExprHandle> ret;
  bool hasBroadcast = false;
  while (at != a.rend() || bt != b.rend()) {
    if (at == a.rend()) {
      hasBroadcast = true;
      ret.push_back(*bt++);
      continue;
    }
    if (bt == b.rend()) {
      hasBroadcast = true;
      ret.push_back(*at++);
      continue;
    }
    // TODO: if neither *at nor *bt is 1, ensure they are identical
    // expressions.  Nb: `==` doesn't work since that simply produces a new
    // ExprHandle.
    ExprHandle dim = *at;
    if (isOne(*at)) {
      if (!isOne(*bt)) {
        dim = *bt;
        hasBroadcast = true;
      }
    }
    ret.push_back(dim);
    at++;
    bt++;
  }
  std::reverse(ret.begin(), ret.end());
  return {ret, hasBroadcast};
}

static std::pair<std::vector<ExprHandle>, bool> broadcastShapesImpl(
    std::vector<std::vector<ExprHandle>> shapes) {
  size_t n = shapes.size();
  if (n == 1) {
    return {shapes[0], false};
  }
  auto res1 = broadcastShapesImpl(shapes[n - 2], shapes[n - 1]);
  shapes[n - 2] = res1.first;
  shapes.pop_back();
  auto res2 = broadcastShapesImpl(shapes);
  return {res2.first, (res1.second || res2.second)};
}

std::vector<ExprHandle> broadcastShapes(
    std::vector<std::vector<ExprHandle>> shapes) {
  return broadcastShapesImpl(std::move(shapes)).first;
}

std::vector<ExprHandle> broadcastShapes(
    const std::vector<ExprHandle>& a,
    const std::vector<ExprHandle>& b) {
  return broadcastShapesImpl(a, b).first;
}
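// For illustration: broadcastShapes({2, 1, 3}, {4, 3}) aligns the shapes
// from the right and yields {2, 4, 3}; size-1 dims are broadcast, and a
// missing leading dim is taken from the longer shape.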

std::vector<ExprHandle> valueShape(const ArgValue& v) {
  if (auto b = std::get_if<tensorexpr::BufHandle>(&v)) {
    return b->dims();
  }
  return {};
}

ExprHandle tensorOrConstant(
    const ArgValue& v,
    const std::vector<ExprHandle>& axes) {
  if (auto b = std::get_if<BufHandle>(&v)) {
    return broadcast(*b, axes);
  }
  return constant(v);
}

ExprHandle scalarOrConstant(const ArgValue& v) {
  if (auto vh = std::get_if<VarHandle>(&v)) {
    return *vh;
  }
  return constant(v);
}

ExprHandle broadcast(const BufHandle& b, const std::vector<ExprHandle>& axes) {
  return b.load(computeIndicesToBroadcast(axes, b.dims()));
}

ExprHandle constant(const ArgValue& v) {
  if (auto s = std::get_if<tensorexpr::VarHandle>(&v)) {
    return *s;
  } else if (auto d = std::get_if<double>(&v)) {
    return DoubleImm::make(*d);
  } else if (auto i = std::get_if<int64_t>(&v)) {
    return LongImm::make(*i);
  } else if (auto b = std::get_if<bool>(&v)) {
    return BoolImm::make(*b);
  } else if (std::get_if<ArgNone>(&v)) {
    // This is just a placeholder so we don't throw.  None-handling
    // is operator-specific and should be handled properly in
    // the operator-specific lowering code.
    return IntImm::make(0);
  } else {
    throw unsupported_dtype("Trying to convert unsupported dtype to constant");
  }
}

std::vector<ExprHandle> computeIndicesToBroadcast(
    const std::vector<ExprHandle>& outputAxes,
    const std::vector<ExprHandle>& inputSizes) {
  if (outputAxes.size() < inputSizes.size()) {
    throw malformed_input("Cannot broadcast to a lower rank tensor");
  }
  std::vector<ExprHandle> bcast;
  auto axisIt = outputAxes.rbegin();
  auto sizeIt = inputSizes.rbegin();
  while (sizeIt != inputSizes.rend()) {
    auto const& size = intValue(*sizeIt);
    if (size && *size == 1) {
      bcast.emplace_back(LongImm::make(0));
    } else {
      bcast.emplace_back(*axisIt);
    }
    ++axisIt;
    ++sizeIt;
  }
  std::reverse(bcast.begin(), bcast.end());
  return bcast;
}
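// For illustration: with output axes [i, j, k] and input sizes [1, 3], the
// input is indexed as [0, k]; statically-known size-1 dims always read
// index 0.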

Tensor computeChunk(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  return Compute(
      "prim_constantchunk",
      outputShape,
      [inputs](const std::vector<VarHandle>& axes) {
        const auto& b = std::get<BufHandle>(inputs[0]);
        int64_t chunkIdx = std::get<int64_t>(inputs[1]);
        int64_t dim = std::get<int64_t>(inputs[2]);
        int64_t chunks = std::get<int64_t>(inputs[3]);
        std::vector<ExprHandle> indices(axes.begin(), axes.end());

        auto norm_dim = normalizeAndCheckIndex(dim, indices.size());
        auto buf_info = getTensorInfo(b);
        size_t step = buf_info->dims[norm_dim] / chunks;

        std::vector<ExprHandle> new_indices;
        for (int64_t i = 0; i < static_cast<int64_t>(indices.size()); ++i) {
          if (i == norm_dim) {
            new_indices.push_back(
                indices[i] + ExprHandle(immLike(indices[i], chunkIdx * step)));
          } else {
            new_indices.push_back(indices[i]);
          }
        }

        return b.load(new_indices);
      });
}
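// For illustration (hypothetical shapes): chunking a [6, 4] buffer into 2
// chunks along dim 0 gives step = 3, so chunk 1 loads b[i + 3, j].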

Tensor computeTranspose(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  auto A = std::get<BufHandle>(inputs[0]);
  // Trivial case of 0-dim and 1-dim tensors: transpose is just a copy
  if (A.ndim() <= 1) {
    return Compute(
        "aten_transpose", outputShape, [&](const std::vector<VarHandle>& axes) {
          TORCH_INTERNAL_ASSERT(
              axes.size() <= 1,
              buildErrorMessage("Invalid axes size in transpose"));
          return A.load(axes);
        });
  }
  // Usual case where transpose actually swaps dimensions
  auto start_dim = at::maybe_wrap_dim(std::get<int64_t>(inputs[1]), A.ndim());
  auto to_dim = at::maybe_wrap_dim(std::get<int64_t>(inputs[2]), A.ndim());
  return Compute(
      "aten_transpose", outputShape, [&](std::vector<VarHandle> axes) {
        std::swap(axes[start_dim], axes[to_dim]);
        return A.load(axes);
      });
}

Tensor computeExpand(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  auto A = std::get<BufHandle>(inputs[0]);
  return Compute(
      "aten_expand", outputShape, [&](const std::vector<VarHandle>& axes) {
        std::vector<ExprHandle> indices(axes.begin(), axes.end());
        return broadcast(A, indices);
      });
}

Tensor computeReshape(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  auto A = std::get<BufHandle>(inputs[0]);
  if (A.ndim() == 0) {
    return Compute(
        "aten_view", outputShape, [&](const std::vector<VarHandle>& axes) {
          std::vector<ExprHandle> empty_indices;
          return A.load(empty_indices);
        });
  }
  return Compute(
      "aten_reshape", outputShape, [&](const std::vector<VarHandle>& axes) {
        std::vector<VarHandle> new_axes;
        assert(outputShape.size() == axes.size());
        /*
        Example for the index transformation. Assume we have a tensor A and
        its view B:
          A.size() = [6,2,3]
          B = A.view(2,1,9,1,2)

        In TE IR we would want to represent B as the following loopnest:
          for (i1 in 0..2)
            for (i2 in 0..1)
              for (i3 in 0..9)
                for (i4 in 0..1)
                  for (i5 in 0..2)
                    idx = i5 + i4*2 + i3*2 + i2*18 + i1*18
                    B[i1,i2,i3,i4,i5] = A[idx/(3*2), (idx/3)%2, idx%3]
        */
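        // Continuing the example: B's sizes [2,1,9,1,2] give the strides
        // [18,18,2,2,1] computed just below, which is exactly how idx is
        // formed in the pseudocode above.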
        std::vector<ExprPtr> dims, indices;
        for (size_t idx = 0; idx < outputShape.size(); idx++) {
          dims.push_back(outputShape[idx].node());
          indices.push_back(axes[idx].node());
        }

        auto ndim = dims.size();
        std::vector<ExprPtr> strides(ndim);
        strides[ndim - 1] = immLike(dims[ndim - 1], 1);
        for (size_t i = 1; i < ndim; i++) {
          strides[ndim - 1 - i] = alloc<Mul>(strides[ndim - i], dims[ndim - i]);
        }

        ExprHandle flat_idx = ExprHandle(flatten_index(dims, indices, strides));
        std::vector<ExprHandle> orig_buf_indexes(A.ndim(), ExprHandle(0));
        ExprHandle stride = ExprHandle(immLike(flat_idx, 1));
        for (size_t idx = 0; idx < A.ndim(); idx++) {
          size_t dim_idx = A.ndim() - idx - 1;
          // We don't need to generate mod-div for the first dimension -
          // ideally IRSimplifier would get rid of that for us, but for now
          // let's just avoid generating it in the first place.
          if (dim_idx > 0) {
            orig_buf_indexes[dim_idx] = flat_idx / stride % A.dim(dim_idx);
          } else {
            orig_buf_indexes[dim_idx] = flat_idx / stride;
          }
          // In the example above the stride is initially 1 for dim_idx = 2,
          // then it's 3 for dim_idx = 1, and then it's 3*2 for dim_idx = 0.
          stride = stride * A.dim(dim_idx);
        }
        return A.load(orig_buf_indexes);
      });
}

Tensor computeFlatten(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  std::vector<int64_t> outputShapeVec;
  for (const auto dim : c10::irange(outputShape.size())) {
    outputShapeVec.push_back(outputShape[dim].AsNode<LongImm>()->value());
  }
  std::vector<ArgValue> reshapeInputs;
  reshapeInputs.push_back(inputs[0]);
  reshapeInputs.emplace_back(outputShapeVec);
  return computeReshape(
      reshapeInputs, outputShape, outputStrides, outputType, device);
}

static std::pair<ScalarType, std::vector<BufHandle>> processCatList(
    const std::vector<BufHandle>& bufList) {
  if (bufList.empty()) {
    throw std::runtime_error("Empty input list is passed to aten::cat");
  }
  std::vector<BufHandle> bufInputs;
  std::vector<BufHandle> nonEmptyInputs;
  for (auto buf : bufList) {
    bufInputs.push_back(buf);
    TORCH_INTERNAL_ASSERT(
        !buf.node()->dims().empty(), buildErrorMessage("Invalid buf rank"));
    // Ignore buffers that are 0-sized on any dimension.
    bool hasEmptyDims = false;
    for (const auto& dim : buf.dims()) {
      if (dim.AsNode<LongImm>() && immediateAs<int64_t>(dim) == 0ll) {
        hasEmptyDims = true;
        break;
      }
    }
    if (!hasEmptyDims) {
      nonEmptyInputs.push_back(buf);
    }
  }
  ScalarType highType = bufInputs[0].dtype().scalar_type();
  for (const auto& input : bufInputs) {
    auto maybe_dtype = input.dtype().scalar_type();
    highType = promoteTypes(highType, maybe_dtype);
  }
  return {highType, nonEmptyInputs};
}

static Tensor computeCatWoConditionals(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides) {
  auto const& input_list = std::get<BufList>(inputs[0]);
  auto arg_dim = inputs[1];
  auto cat_info = processCatList(input_list);
  ScalarType high_type = cat_info.first;
  std::vector<BufHandle> non_empty_inputs = cat_info.second;

  // Now we build one loop per input:
  //
  // for i
  //   for j
  //     for k
  //       output[i,j,k] = inp1[i,j,k]
  // for i
  //   for j
  //     for k
  //       output[i,j+l1,k] = inp2[i,j,k]
  // for i
  //   for j
  //     for k
  //       output[i,j+l2,k] = inp3[i,j,k]
  //
  // Here j is the concat dim, l1 is inp1's size along it, and l2 = l1 +
  // inp2's size along it, i.e. the offsets are cumulative input sizes.

  auto output_sizes_expr = ExprHandleVectorToExprVector(outputShape);
  auto output_strides_expr = ExprHandleVectorToExprVector(outputStrides);
  auto output_buf = alloc<Buf>(
      "aten_cat",
      output_sizes_expr,
      ToDtype(high_type),
      nullptr,
      output_strides_expr);
  if (non_empty_inputs.empty()) {
    return Tensor(
        output_buf, alloc<tensorexpr::Block>(std::vector<StmtPtr>({})));
  }

  int64_t concat_dim = std::get<int64_t>(arg_dim);
  auto norm_concat_dim = normalizeAndCheckIndex(concat_dim, outputShape.size());

  auto loop_order_fn = [&](const BufPtr& buf_) {
    std::vector<int32_t> loop_order;
    if (buf_->is_contiguous()) {
      for (int32_t i = buf_->ndim() - 1; i >= 0; i--) {
        loop_order.push_back(i);
      }
    } else if (buf_->is_contiguous(c10::MemoryFormat::ChannelsLast)) {
      loop_order = {1, 3, 2, 0};
    } else if (buf_->is_contiguous(c10::MemoryFormat::ChannelsLast3d)) {
      loop_order = {1, 4, 3, 2, 0};
    } else {
      loop_order = {1, 2, 0};
    }

    return loop_order;
  };

  auto gen_code_for_input = [&](const BufHandle& inp,
                                size_t inp_pos,
                                const ExprPtr& concat_dim_size,
                                const std::vector<ExprHandle>& dims) {
    std::vector<VarPtr> for_vars(dims.size());
    std::vector<ExprPtr> load_indices(dims.size());
    std::vector<ExprPtr> store_indices(dims.size());
    for (int64_t i = 0; i < static_cast<int64_t>(dims.size()); ++i) {
      for_vars[i] = alloc<Var>(
          "i" + std::to_string(inp_pos) + "_" + std::to_string(i),
          dims[i].dtype());
      load_indices[i] = for_vars[i];
      if (i == norm_concat_dim) {
        store_indices[i] = alloc<Add>(for_vars[i], concat_dim_size);
      } else {
        store_indices[i] = for_vars[i];
      }
    }
    auto inp_buf = inp.node();
    auto load_expr = alloc<Load>(inp_buf, load_indices);
    auto load_promoted = promoteToDtype(ExprHandle(load_expr), high_type);
    StmtPtr st = alloc<Store>(output_buf, store_indices, load_promoted.node());

    auto loop_order = loop_order_fn(inp.node());
    for (auto dim_index : loop_order) {
      st = alloc<For>(
          for_vars[dim_index],
          immLike(dims[dim_index], 0),
          dims[dim_index].node(),
          st);
    }

    return st;
  };

  ExprPtr concat_dim_size = nullptr;
  auto block = alloc<tensorexpr::Block>(std::vector<StmtPtr>({}));
  for (size_t i = 0; i < non_empty_inputs.size(); ++i) {
    auto input_dims =
        ExprVectorToExprHandleVector(non_empty_inputs[i].node()->dims());
    if (concat_dim_size == nullptr) {
      concat_dim_size = immLike(input_dims[norm_concat_dim], 0);
    }
    block->append_stmt(gen_code_for_input(
        non_empty_inputs[i], i, concat_dim_size, input_dims));
    concat_dim_size =
        alloc<Add>(concat_dim_size, input_dims[norm_concat_dim].node());
  }
  return Tensor(output_buf, IRSimplifier::simplify(block));
}

Tensor computeCat(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  if (device == at::kCPU && getCatWoConditionals()) {
    return computeCatWoConditionals(inputs, outputShape, outputStrides);
  }
  auto const& inputList = std::get<BufList>(inputs[0]);
  auto argDim = inputs[1];
  auto catInfo = processCatList(inputList);
  ScalarType highType = catInfo.first;
  std::vector<BufHandle> nonEmptyInputs = catInfo.second;
  return Compute(
      "aten_cat",
      outputShape,
      outputStrides,
      [&](const std::vector<VarHandle>& axes) {
        if (nonEmptyInputs.empty()) {
          return ExprHandle(0);
        }

        int64_t dim_ = std::get<int64_t>(argDim);
        auto dim = normalizeAndCheckIndex(dim_, axes.size());
        // Promote input types.
        // Note that we need to consider all inputs, including empty - they
        // also affect the resultant dtype.

        // Now we know the final dtype, we know which inputs are non-empty,
        // and we know that there is at least one such input. With all
        // that we construct a tensor expression performing the
        // concatenation.
        // The expression we build here is a cascading if-then-else that
        // essentially represents:
        //
        //              inp1[i, j, k]         if       j < l1,
        // out[i,j,k] = inp2[i, j-l1, k]      if l1 <= j < l1 + l2,
        //              ...
        //              inpN[i, j-(l1+...+l_{N-1}), k]   otherwise,
        // where j is the concat dim and l_i is the size of the i-th input
        // along it.
        std::vector<ExprHandle> newAxes(axes.begin(), axes.end());
        ExprHandle load = promoteToDtype(
            tensorOrConstant(nonEmptyInputs[0], newAxes), highType);
        auto offset = ExprHandle(nonEmptyInputs[0].node()->dim(dim));
        newAxes[dim] = newAxes[dim] - offset;

        for (size_t ii = 1; ii < nonEmptyInputs.size(); ++ii) {
          auto input = nonEmptyInputs[ii];
          load = ifThenElse(
              CompareSelect::make(axes[dim], offset, kLT),
              load,
              promoteToDtype(tensorOrConstant(input, newAxes), highType));

          offset = offset + ExprHandle(input.node()->dim(dim));
          newAxes[dim] = axes[dim] - offset;
        }

        return load;
      });
}

Tensor computeEmbedding(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  Dtype dtype = kFloat;
  if (outputType) {
    dtype = Dtype(*outputType);
  }

  BufHandle ResultBuf("emb", outputShape, dtype);
  const BufHandle& w = std::get<BufHandle>(inputs[0]);
  const BufHandle& indices = std::get<BufHandle>(inputs[1]);

  StmtPtr s =
      ExternalCall::make(ResultBuf, "nnc_aten_embedding", {w, indices}, {});
  return Tensor(ResultBuf.node(), s);
}

} // namespace torch::jit::tensorexpr