#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/tensorexpr/operators/misc.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

namespace torch::jit::tensorexpr {

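// Normalizes a (possibly negative) index into [0, list_size) and throws on
// out-of-range values. Illustrative sketch of the expected behavior (values
// chosen here for exposition only):
//   normalizeAndCheckIndex(-1, 4) -> 3
//   normalizeAndCheckIndex( 3, 4) -> 3
//   normalizeAndCheckIndex( 4, 4) -> error (index out of range)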
int64_t normalizeAndCheckIndex(int64_t idx, int64_t list_size) {
  if (idx < 0) {
    // Handle negative indexing
    idx = list_size + idx;
  }

  if (idx < 0 || idx >= list_size) {
    AT_ERROR("Invalid index ", idx, " for list_size ", list_size);
  }
  return idx;
}

// Convert boolean to integer, if needed.
ExprHandle boolToInteger(const ExprHandle& x) {
  return x.dtype().scalar_type() == ScalarType::Bool ? cast<int>(x) : x;
}

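// Casts `e` to the requested scalar type, covering the standard scalar types
// plus the quantized int8 variants. A minimal sketch of the intent (assuming
// an integer immediate as input):
//   promoteToDtype(IntImm::make(3), ScalarType::Float) -> float-typed cast
//   promoteToDtype(IntImm::make(3), ScalarType::Int)   -> returned unchanged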
ExprHandle promoteToDtype(ExprHandle e, ScalarType dt) {
  if (e.dtype().scalar_type() == dt) {
    return e;
  }

  switch (dt) {
#define TYPE_CASE(Type, Name) \
  case ScalarType::Name:      \
    e = cast<Type>(e);        \
    break;
    AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
    case ScalarType::QUInt8:
      e = cast<c10::quint8>(e);
      break;
    case ScalarType::QInt8:
      e = cast<c10::qint8>(e);
      break;
    default:
      throw unsupported_dtype();
  }
  return e;
}

static bool checkTypes(const ScalarType highType, const int typeConstraints) {
  if (typeConstraints == kAllTypes) {
    return true;
  }

  if (c10::isIntegralType(highType, false)) {
    return (typeConstraints & kIntegralTypes) != 0;
  } else if (c10::isFloatingType(highType)) {
    return (typeConstraints & kFloatingPointTypes) != 0;
  } else if (highType == ScalarType::Bool) {
    return (typeConstraints & kBoolType) != 0;
  }

  // The JIT fuser does not support complex or quantized types yet.
  TORCH_INTERNAL_ASSERT(
      (typeConstraints & (kQintTypes | kComplexTypes)) == 0,
      buildErrorMessage(
          "Qint and Complex types are not supported in the fuser."));
  return false;
}

static bool isScalar(const ExprHandle& e) {
  auto n = e.node();
  return n->isConstant() || to<Var>(n);
}

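// Widens reduced-precision floating-point expressions for computation. A
// rough sketch of the mapping (dtype names listed for illustration):
//   Half / BFloat16 expression -> cast to Float (same number of lanes)
//   Float / Double / integral  -> returned unchanged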
ExprHandle promoteHalfToFloat(const ExprHandle& e) {
  auto scalarType = static_cast<c10::ScalarType>(e.dtype().scalar_type());
  auto floatType = static_cast<c10::ScalarType>(tensorexpr::ScalarType::Float);
  if (c10::isFloatingType(scalarType) &&
      (c10::elementSize(scalarType) < c10::elementSize(floatType))) {
    return Cast::make(
        Dtype(tensorexpr::ScalarType::Float, e.dtype().lanes()), e);
  } else {
    return e;
  }
}

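// Promotes all inputs to a common dtype, following aten-style promotion where
// scalar operands do not widen tensor operands beyond the default dtype. An
// illustrative scenario (types assumed for exposition):
//   inputs = {Int tensor load, Double scalar} -> both are cast to the default
//   floating-point dtype (typically Float), not to Double.
// Throws unsupported_dtype() if the promoted type violates typeConstraints.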
void promoteInputs(std::vector<ExprHandle>& inputs, const int typeConstraints) {
  if (inputs.empty()) {
    return;
  }

  // Find the highest type among the inputs.
  ScalarType highType = inputs[0].dtype().scalar_type();
  for (const auto& input : inputs) {
    auto inputType = input.dtype().scalar_type();
    if (isScalar(input)) {
      if (isIntegralType(highType, false) && isFloatingType(inputType)) {
        highType = c10::get_default_dtype_as_scalartype();
      } else if (highType == c10::kBool) {
        highType = inputType;
      }
    } else {
      highType = promoteTypes(highType, inputType);
    }
  }

  if (!checkTypes(highType, typeConstraints)) {
    throw unsupported_dtype();
  }

  for (ExprHandle& e : inputs) {
    e = promoteToDtype(e, highType);
  }
}

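// Promotes integral (and boolean) expressions to the default floating-point
// dtype; floating-point expressions pass through untouched. Sketch of the
// intended behavior (assuming the default dtype is Float):
//   Long expression  -> cast to Float
//   Bool expression  -> cast to Float
//   Float expression -> returned unchanged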
ExprHandle promoteIntegerToDefaultType(const ExprHandle& e) {
  auto scalarType = static_cast<c10::ScalarType>(e.dtype().scalar_type());
  if (!c10::isIntegralType(scalarType, /*includeBool*/ true)) {
    return e;
  }

  auto defaultType = c10::typeMetaToScalarType(c10::get_default_dtype());

  // We intend to promote Integers to floating-point types
  TORCH_INTERNAL_ASSERT(
      !c10::isIntegralType(defaultType, /*includeBool*/ true));

  return Cast::make(
      Dtype(
          static_cast<tensorexpr::ScalarType>(defaultType), e.dtype().lanes()),
      e);
}

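// Casts an expression back to the requested output dtype, if one is given.
// A minimal sketch (inputs assumed for exposition):
//   demoteOutput(float_expr, std::nullopt)      -> returned unchanged
//   demoteOutput(float_expr, ScalarType::Half)  -> cast to Half
//   demoteOutput(float_expr, ScalarType::Bool)  -> cast to bool
// Quantized and complex target types throw unsupported_dtype().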
ExprHandle demoteOutput(
    const ExprHandle& e,
    const std::optional<ScalarType> type) {
  if (!type.has_value()) {
    return e;
  }
  if (*type == e.dtype().scalar_type()) {
    return e;
  }

  switch (*type) {
#define TYPE_CASE(Type, Name) \
  case ScalarType::Name:      \
    return cast<Type>(e);
    AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
    case ScalarType::Bool:
      return cast<bool>(e);
    default:
      throw unsupported_dtype();
  }

  return e;
}

std::optional<TensorInfo> getTensorInfo(const BufHandle& b) {
  std::vector<int64_t> dims;
  auto b_dims = b.dims();
  dims.reserve(b_dims.size());
  for (auto dim : b_dims) {
    auto val = intValue(dim.node());
    if (!val) {
      return std::nullopt;
    }
    dims.push_back(*val);
  }
  return TensorInfo{dims, static_cast<at::ScalarType>(b.dtype().scalar_type())};
}

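// clamp(cmin, cmax, x) is built from two CompareSelects rather than min/max
// intrinsics. A rough expansion, assuming scalar operands:
//   mm  = (x < cmin) ? cmin : x     // apply lower bound
//   out = (mm > cmax) ? cmax : mm   // apply upper bound
// e.g. clamp(0, 6, x) clamps x into the range [0, 6].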
ExprHandle clamp(
    const ExprHandle& cmin,
    const ExprHandle& cmax,
    const ExprHandle& input) {
  auto mm = CompareSelect::make(input, cmin, cmin, input, kLT);
  return CompareSelect::make(mm, cmax, cmax, mm, kGT);
}

static bool isOne(const ExprHandle& e) {
  auto const& n = intValue(e);
  if (!n) {
    return false;
  }
  return *n == 1;
}

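// Shape broadcasting follows the usual right-aligned rules: dimensions are
// matched from the innermost axis outward, and size-1 dimensions stretch to
// the other operand's size. Illustrative example (concrete sizes assumed):
//   broadcastShapes({8, 1, 6, 1}, {7, 1, 5}) -> {8, 7, 6, 5}
// Note: when neither dimension is 1, the shapes are currently assumed to be
// identical expressions (see the TODO below).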
static std::pair<std::vector<ExprHandle>, bool> broadcastShapesImpl(
    const std::vector<ExprHandle>& a,
    const std::vector<ExprHandle>& b) {
  auto at = a.rbegin();
  auto bt = b.rbegin();
  std::vector<ExprHandle> ret;
  bool hasBroadcast = false;
  while (at != a.rend() || bt != b.rend()) {
    if (at == a.rend()) {
      hasBroadcast = true;
      ret.push_back(*bt++);
      continue;
    }
    if (bt == b.rend()) {
      hasBroadcast = true;
      ret.push_back(*at++);
      continue;
    }
    // TODO: if neither *at nor *bt is 1, ensure they are identical
    // expressions. Nb: `==` doesn't work since that simply produces a new
    // ExprHandle.
    ExprHandle dim = *at;
    if (isOne(*at)) {
      if (!isOne(*bt)) {
        dim = *bt;
        hasBroadcast = true;
      }
    }
    ret.push_back(dim);
    at++;
    bt++;
  }
  std::reverse(ret.begin(), ret.end());
  return {ret, hasBroadcast};
}

static std::pair<std::vector<ExprHandle>, bool> broadcastShapesImpl(
    std::vector<std::vector<ExprHandle>> shapes) {
  size_t n = shapes.size();
  if (n == 1) {
    return {shapes[0], false};
  }
  auto res1 = broadcastShapesImpl(shapes[n - 2], shapes[n - 1]);
  shapes[n - 2] = res1.first;
  shapes.pop_back();
  auto res2 = broadcastShapesImpl(shapes);
  return {res2.first, (res1.second || res2.second)};
}

std::vector<ExprHandle> broadcastShapes(
    std::vector<std::vector<ExprHandle>> shapes) {
  return broadcastShapesImpl(std::move(shapes)).first;
}

std::vector<ExprHandle> broadcastShapes(
    const std::vector<ExprHandle>& a,
    const std::vector<ExprHandle>& b) {
  return broadcastShapesImpl(a, b).first;
}

std::vector<ExprHandle> valueShape(const ArgValue& v) {
  if (auto b = std::get_if<tensorexpr::BufHandle>(&v)) {
    return b->dims();
  }
  return {};
}

ExprHandle tensorOrConstant(
    const ArgValue& v,
    const std::vector<ExprHandle>& axes) {
  if (auto b = std::get_if<BufHandle>(&v)) {
    return broadcast(*b, axes);
  }
  return constant(v);
}

ExprHandle scalarOrConstant(const ArgValue& v) {
  if (auto vh = std::get_if<VarHandle>(&v)) {
    return *vh;
  }
  return constant(v);
}

ExprHandle broadcast(const BufHandle& b, const std::vector<ExprHandle>& axes) {
  return b.load(computeIndicesToBroadcast(axes, b.dims()));
}

ExprHandle constant(const ArgValue& v) {
  if (auto s = std::get_if<tensorexpr::VarHandle>(&v)) {
    return *s;
  } else if (auto d = std::get_if<double>(&v)) {
    return DoubleImm::make(*d);
  } else if (auto i = std::get_if<int64_t>(&v)) {
    return LongImm::make(*i);
  } else if (auto b = std::get_if<bool>(&v)) {
    return BoolImm::make(*b);
  } else if (std::get_if<ArgNone>(&v)) {
    // This is just a placeholder so we don't throw. None-handling
    // is operator-specific and should be handled properly in
    // the operator-specific lowering code.
    return IntImm::make(0);
  } else {
    throw unsupported_dtype("Trying to convert unsupported dtype to constant");
  }
}

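// Maps output loop axes to load indices for a (possibly lower-rank) input,
// substituting 0 for broadcasted (size-1) input dimensions. Sketch, with
// symbolic axes {i, j, k} assumed for exposition:
//   output axes {i, j, k}, input sizes {1, K} -> load indices {0, k}
//   output axes {i, j, k}, input sizes {J, K} -> load indices {j, k}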
std::vector<ExprHandle> computeIndicesToBroadcast(
    const std::vector<ExprHandle>& outputAxes,
    const std::vector<ExprHandle>& inputSizes) {
  if (outputAxes.size() < inputSizes.size()) {
    throw malformed_input("Cannot broadcast to a lower rank tensor");
  }
  std::vector<ExprHandle> bcast;
  auto axisIt = outputAxes.rbegin();
  auto sizeIt = inputSizes.rbegin();
  while (sizeIt != inputSizes.rend()) {
    auto const& size = intValue(*sizeIt);
    if (size && *size == 1) {
      bcast.emplace_back(LongImm::make(0));
    } else {
      bcast.emplace_back(*axisIt);
    }
    ++axisIt;
    ++sizeIt;
  }
  std::reverse(bcast.begin(), bcast.end());
  return bcast;
}

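// Lowers prim::ConstantChunk by loading from the input at an offset along the
// chunk dimension. The per-chunk step is computed as dim_size / chunks, so
// chunk `chunkIdx` reads input[..., i + chunkIdx * step, ...]. Worked example
// (sizes assumed for exposition): dim_size = 6, chunks = 3 -> step = 2, and
// chunk 1 loads input[..., i + 2, ...]. Dimension sizes are obtained from
// getTensorInfo and so must be statically known here.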
Tensor computeChunk(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  return Compute(
      "prim_constantchunk",
      outputShape,
      [inputs](const std::vector<VarHandle>& axes) {
        const auto& b = std::get<BufHandle>(inputs[0]);
        int64_t chunkIdx = std::get<int64_t>(inputs[1]);
        int64_t dim = std::get<int64_t>(inputs[2]);
        int64_t chunks = std::get<int64_t>(inputs[3]);
        std::vector<ExprHandle> indices(axes.begin(), axes.end());

        auto norm_dim = normalizeAndCheckIndex(dim, indices.size());
        auto buf_info = getTensorInfo(b);
        size_t step = buf_info->dims[norm_dim] / chunks;

        std::vector<ExprHandle> new_indices;
        for (int64_t i = 0; i < static_cast<int64_t>(indices.size()); ++i) {
          if (i == norm_dim) {
            new_indices.push_back(
                indices[i] + ExprHandle(immLike(indices[i], chunkIdx * step)));
          } else {
            new_indices.push_back(indices[i]);
          }
        }

        return b.load(new_indices);
      });
}

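// Lowers aten::transpose by swapping the two loop axes when indexing into the
// input. A minimal sketch for a 2-D input (names assumed for exposition):
//   computeTranspose({A, 0, 1}, ...) produces out[i, j] = A[j, i]
// 0-dim and 1-dim tensors degenerate to a plain copy.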
Tensor computeTranspose(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  auto A = std::get<BufHandle>(inputs[0]);
  // Trivial case of 0-dim and 1-dim tensors: transpose is just a copy
  if (A.ndim() <= 1) {
    return Compute(
        "aten_transpose", outputShape, [&](const std::vector<VarHandle>& axes) {
          TORCH_INTERNAL_ASSERT(
              axes.size() <= 1,
              buildErrorMessage("Invalid axes size in transpose"));
          return A.load(axes);
        });
  }
  // Usual case where transpose actually swaps dimensions
  auto start_dim = at::maybe_wrap_dim(std::get<int64_t>(inputs[1]), A.ndim());
  auto to_dim = at::maybe_wrap_dim(std::get<int64_t>(inputs[2]), A.ndim());
  return Compute(
      "aten_transpose", outputShape, [&](std::vector<VarHandle> axes) {
        std::swap(axes[start_dim], axes[to_dim]);
        return A.load(axes);
      });
}

Tensor computeExpand(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  auto A = std::get<BufHandle>(inputs[0]);
  return Compute(
      "aten_expand", outputShape, [&](const std::vector<VarHandle>& axes) {
        std::vector<ExprHandle> indices(axes.begin(), axes.end());
        return broadcast(A, indices);
      });
}

Tensor computeReshape(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  auto A = std::get<BufHandle>(inputs[0]);
  if (A.ndim() == 0) {
    return Compute(
        "aten_view", outputShape, [&](const std::vector<VarHandle>& axes) {
          std::vector<ExprHandle> empty_indices;
          return A.load(empty_indices);
        });
  }
  return Compute(
      "aten_reshape", outputShape, [&](const std::vector<VarHandle>& axes) {
        std::vector<VarHandle> new_axes;
        assert(outputShape.size() == axes.size());
        /*
        Example for the index transformation. Assume we have a tensor A and
        its view B:
          A.size() = [6,2,3]
          B = A.view(2,1,9,1,2)

        In TE IR we would want to represent B as the following loopnest:
          for (i1 in 0..2)
            for (i2 in 0..1)
              for (i3 in 0..9)
                for (i4 in 0..1)
                  for (i5 in 0..2)
                    idx = i5 + i4*2 + i3*2 + i2*18 + i1*18
                    B[i1,i2,i3,i4,i5] = A[idx/(3*2), (idx/3)%2, idx%3]
        */
        std::vector<ExprPtr> dims, indices;
        for (size_t idx = 0; idx < outputShape.size(); idx++) {
          dims.push_back(outputShape[idx].node());
          indices.push_back(axes[idx].node());
        }

        auto ndim = dims.size();
        std::vector<ExprPtr> strides(ndim);
        strides[ndim - 1] = immLike(dims[ndim - 1], 1);
        for (size_t i = 1; i < ndim; i++) {
          strides[ndim - 1 - i] = alloc<Mul>(strides[ndim - i], dims[ndim - i]);
        }

        ExprHandle flat_idx = ExprHandle(flatten_index(dims, indices, strides));
        std::vector<ExprHandle> orig_buf_indexes(A.ndim(), ExprHandle(0));
        ExprHandle stride = ExprHandle(immLike(flat_idx, 1));
        for (size_t idx = 0; idx < A.ndim(); idx++) {
          size_t dim_idx = A.ndim() - idx - 1;
          // We don't need to generate mod-div for the first dimension -
          // ideally IRSimplifier would get rid of that for us, but for now
          // let's just avoid generating it in the first place.
          if (dim_idx > 0) {
            orig_buf_indexes[dim_idx] = flat_idx / stride % A.dim(dim_idx);
          } else {
            orig_buf_indexes[dim_idx] = flat_idx / stride;
          }
          // In the example above the stride is initially 1 for dim_idx = 2,
          // then it's 3 for dim_idx = 1, and then it's 3*2 for dim_idx = 0.
          stride = stride * A.dim(dim_idx);
        }
        return A.load(orig_buf_indexes);
      });
}

Tensor computeFlatten(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  std::vector<int64_t> outputShapeVec;
  for (const auto dim : c10::irange(outputShape.size())) {
    outputShapeVec.push_back(outputShape[dim].AsNode<LongImm>()->value());
  }
  std::vector<ArgValue> reshapeInputs;
  reshapeInputs.push_back(inputs[0]);
  reshapeInputs.emplace_back(outputShapeVec);
  return computeReshape(
      reshapeInputs, outputShape, outputStrides, outputType, device);
}

static std::pair<ScalarType, std::vector<BufHandle>> processCatList(
    const std::vector<BufHandle>& bufList) {
  if (bufList.empty()) {
    throw std::runtime_error("Empty input list is passed to aten::cat");
  }
  std::vector<BufHandle> bufInputs;
  std::vector<BufHandle> nonEmptyInputs;
  for (auto buf : bufList) {
    bufInputs.push_back(buf);
    TORCH_INTERNAL_ASSERT(
        !buf.node()->dims().empty(), buildErrorMessage("Invalid buf rank"));
    // Ignore buffers that are 0-sized on any dimension.
    bool hasEmptyDims = false;
    for (const auto& dim : buf.dims()) {
      if (dim.AsNode<LongImm>() && immediateAs<int64_t>(dim) == 0ll) {
        hasEmptyDims = true;
        break;
      }
    }
    if (!hasEmptyDims) {
      nonEmptyInputs.push_back(buf);
    }
  }
  ScalarType highType = bufInputs[0].dtype().scalar_type();
  for (const auto& input : bufInputs) {
    auto maybe_dtype = input.dtype().scalar_type();
    highType = promoteTypes(highType, maybe_dtype);
  }
  return {highType, nonEmptyInputs};
}

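// Conditional-free lowering of aten::cat for CPU: one store loop nest per
// non-empty input, with the store index along the concat dimension shifted by
// the running offset of the inputs already written. Worked example (sizes
// assumed for exposition): concatenating inputs of sizes 2, 3, and 4 along
// dim 0 stores them at offsets 0, 2, and 5 respectively.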
static Tensor computeCatWoConditionals(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides) {
  auto const& input_list = std::get<BufList>(inputs[0]);
  auto arg_dim = inputs[1];
  auto cat_info = processCatList(input_list);
  ScalarType high_type = cat_info.first;
  std::vector<BufHandle> non_empty_inputs = cat_info.second;

  // Now we build one loop per input:
  //
  // for i
  //   for j
  //     for k
  //       output[i,j,k] = inp1[i,j,k]
  // for i
  //   for j
  //     for k
  //       output[i,j+l1,k] = inp2[i,j,k]
  // for i
  //   for j
  //     for k
  //       output[i,j+l1+l2,k] = inp3[i,j,k]

  auto output_sizes_expr = ExprHandleVectorToExprVector(outputShape);
  auto output_strides_expr = ExprHandleVectorToExprVector(outputStrides);
  auto output_buf = alloc<Buf>(
      "aten_cat",
      output_sizes_expr,
      ToDtype(high_type),
      nullptr,
      output_strides_expr);
  if (non_empty_inputs.empty()) {
    return Tensor(
        output_buf, alloc<tensorexpr::Block>(std::vector<StmtPtr>({})));
  }

  int64_t concat_dim = std::get<int64_t>(arg_dim);
  auto norm_concat_dim = normalizeAndCheckIndex(concat_dim, outputShape.size());

  auto loop_order_fn = [&](const BufPtr& buf_) {
    std::vector<int32_t> loop_order;
    if (buf_->is_contiguous()) {
      for (int32_t i = buf_->ndim() - 1; i >= 0; i--) {
        loop_order.push_back(i);
      }
    } else if (buf_->is_contiguous(c10::MemoryFormat::ChannelsLast)) {
      loop_order = {1, 3, 2, 0};
    } else if (buf_->is_contiguous(c10::MemoryFormat::ChannelsLast3d)) {
      loop_order = {1, 4, 3, 2, 0};
    } else {
      loop_order = {1, 2, 0};
    }

    return loop_order;
  };

  auto gen_code_for_input = [&](const BufHandle& inp,
                                size_t inp_pos,
                                const ExprPtr& concat_dim_size,
                                const std::vector<ExprHandle>& dims) {
    std::vector<VarPtr> for_vars(dims.size());
    std::vector<ExprPtr> load_indices(dims.size());
    std::vector<ExprPtr> store_indices(dims.size());
    for (int64_t i = 0; i < static_cast<int64_t>(dims.size()); ++i) {
      for_vars[i] = alloc<Var>(
          "i" + std::to_string(inp_pos) + "_" + std::to_string(i),
          dims[i].dtype());
      load_indices[i] = for_vars[i];
      if (i == norm_concat_dim) {
        store_indices[i] = alloc<Add>(for_vars[i], concat_dim_size);
      } else {
        store_indices[i] = for_vars[i];
      }
    }
    auto inp_buf = inp.node();
    auto load_expr = alloc<Load>(inp_buf, load_indices);
    auto load_promoted = promoteToDtype(ExprHandle(load_expr), high_type);
    StmtPtr st = alloc<Store>(output_buf, store_indices, load_promoted.node());

    auto loop_order = loop_order_fn(inp.node());
    for (auto dim_index : loop_order) {
      st = alloc<For>(
          for_vars[dim_index],
          immLike(dims[dim_index], 0),
          dims[dim_index].node(),
          st);
    }

    return st;
  };

  ExprPtr concat_dim_size = nullptr;
  auto block = alloc<tensorexpr::Block>(std::vector<StmtPtr>({}));
  for (size_t i = 0; i < non_empty_inputs.size(); ++i) {
    auto input_dims =
        ExprVectorToExprHandleVector(non_empty_inputs[i].node()->dims());
    if (concat_dim_size == nullptr) {
      concat_dim_size = immLike(input_dims[norm_concat_dim], 0);
    }
    block->append_stmt(gen_code_for_input(
        non_empty_inputs[i], i, concat_dim_size, input_dims));
    concat_dim_size =
        alloc<Add>(concat_dim_size, input_dims[norm_concat_dim].node());
  }
  return Tensor(output_buf, IRSimplifier::simplify(block));
}

Tensor computeCat(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  if (device == at::kCPU && getCatWoConditionals()) {
    return computeCatWoConditionals(inputs, outputShape, outputStrides);
  }
  auto const& inputList = std::get<BufList>(inputs[0]);
  auto argDim = inputs[1];
  auto catInfo = processCatList(inputList);
  ScalarType highType = catInfo.first;
  std::vector<BufHandle> nonEmptyInputs = catInfo.second;
  return Compute(
      "aten_cat",
      outputShape,
      outputStrides,
      [&](const std::vector<VarHandle>& axes) {
        if (nonEmptyInputs.empty()) {
          return ExprHandle(0);
        }

        int64_t dim_ = std::get<int64_t>(argDim);
        auto dim = normalizeAndCheckIndex(dim_, axes.size());
        // Promote input types.
        // Note that we need to consider all inputs, including empty - they
        // also affect the resultant dtype.

        // Now we know the final dtype, we know what inputs are non-empty,
        // and we know that there is at least one such input. With all
        // that we construct a tensor expression performing the
        // concatenation.
        // The expression we build here is a cascading if-then-else that
        // essentially represents:
        //
        //                inp1[i, j, k]        if 0 <= j < l1,
        //   out[i,j,k] = inp2[i, j-l1, k]     if l1 <= j < l1 + l2,
        //                ...
        //                inpN[i, j-l_N_1, k]  if l1+l2+...+l_N_1 <= j
        // where l_i is the size of the i-th input along the concat dimension
        // (here taken to be the `j` axis).
        std::vector<ExprHandle> newAxes(axes.begin(), axes.end());
        ExprHandle load = promoteToDtype(
            tensorOrConstant(nonEmptyInputs[0], newAxes), highType);
        auto offset = ExprHandle(nonEmptyInputs[0].node()->dim(dim));
        newAxes[dim] = newAxes[dim] - offset;

        for (size_t ii = 1; ii < nonEmptyInputs.size(); ++ii) {
          auto input = nonEmptyInputs[ii];
          load = ifThenElse(
              CompareSelect::make(axes[dim], offset, kLT),
              load,
              promoteToDtype(tensorOrConstant(input, newAxes), highType));

          offset = offset + ExprHandle(input.node()->dim(dim));
          newAxes[dim] = axes[dim] - offset;
        }

        return load;
      });
}

Tensor computeEmbedding(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  Dtype dtype = kFloat;
  if (outputType) {
    dtype = Dtype(*outputType);
  }

  BufHandle ResultBuf("emb", outputShape, dtype);
  const BufHandle& w = std::get<BufHandle>(inputs[0]);
  const BufHandle& indices = std::get<BufHandle>(inputs[1]);

  StmtPtr s =
      ExternalCall::make(ResultBuf, "nnc_aten_embedding", {w, indices}, {});
  return Tensor(ResultBuf.node(), s);
}

} // namespace torch::jit::tensorexpr