#include <torch/csrc/jit/codegen/fuser/codegen.h>

#include <ATen/ATen.h>
#include <ATen/code_template.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/codegen/fuser/compiler.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/codegen/fuser/tensor_info.h>
#include <torch/csrc/jit/ir/ir.h>

#include <torch/csrc/jit/codegen/fuser/cpu/resource_strings.h>
#include <torch/csrc/jit/codegen/fuser/cuda/resource_strings.h>

#include <cmath>
#include <cstdint>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <vector>

namespace torch::jit::fuser {

// Template for computing the offset into the tensor to access a value
static auto dim_calc = at::jit::CodeTemplate(R"(
//printf("tensor ${tensor} sizes[${d}] = %d, strides[${d}] = %d\n", ${tensor}.sizes[${d}],${tensor}.strides[${d}]);
size_t ${tensor}_dimIndex${d} = ${tensor}_linearIndex ${mod_sizes};
${tensor}_offset += ${tensor}_dimIndex${d} ${times_stride};
)");
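
// As an illustration (a sketch of the substitution, not emitted verbatim
// here): with tensor = "t0", d = 1, mod_sizes = "% t0.sizes[1]" and
// times_stride = "* t0.strides[1]", the template above expands to
//   size_t t0_dimIndex1 = t0_linearIndex % t0.sizes[1];
//   t0_offset += t0_dimIndex1 * t0.strides[1];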

static std::string valueName(const Value* n) {
  return "n" + std::to_string(n->unique());
}

static std::string scalarValue(const int64_t v) {
  return std::to_string(v);
}

static std::string scalarValue(const bool v) {
  return std::to_string(v);
}

// Note: The NAN, NEG_INFINITY and POS_INFINITY strings map to device-specific
// implementations of these special values. These macros are found in the
// resource strings for each device.
static std::string scalarValue(const double v) {
  std::ostringstream out;
  if (std::isnan(v)) {
    out << "NAN";
  } else if (std::isinf(v)) {
    if (v < 0) {
      out << "NEG_INFINITY";
    } else {
      out << "POS_INFINITY";
    }
  } else {
    out << std::setprecision(16) << v;
  }
  return out.str();
}
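
// For example, given the code above, scalarValue(0.5) produces "0.5",
// scalarValue(std::numeric_limits<double>::quiet_NaN()) produces "NAN", and
// scalarValue(-std::numeric_limits<double>::infinity()) produces
// "NEG_INFINITY"; the macros themselves come from the device resource strings.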

// Note: Half is special-cased to avoid returning at::Half
static const char* scalarTypeName(const at::ScalarType type) {
  if (type == at::ScalarType::Half) {
    return "half";
  }
  if (type == at::ScalarType::BFloat16) {
    return cuda::bfloat16_type_string;
  }

  switch (type) {
#define DEFINE_CASE(ctype, name) \
  case at::ScalarType::name:     \
    return #ctype;
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
#undef DEFINE_CASE
    default:
      throw std::runtime_error("unknown scalar type");
  }
}
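
// Returns the type the fused kernel computes in for a given scalar type:
// Half and BFloat16 values are loaded/stored in their storage type, but the
// arithmetic is done in float, so both map to "float" here.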
static const char* calcScalarTypeName(const at::ScalarType type) {
  if (type == at::ScalarType::Half) {
    return "float";
  }
  if (type == at::ScalarType::BFloat16) {
    return "float";
  }
  return scalarTypeName(type);
}

static std::string variableType(const c10::Type& t) {
  if (t.kind() == TypeKind::IntType) {
    return "int64_t";
  } else if (t.kind() == TypeKind::FloatType) {
    return "double";
  } else if (t.kind() == TypeKind::BoolType) {
    return "bool";
  } else if (auto scalar_type = t.expectRef<TensorType>().scalarType()) {
    return calcScalarTypeName(*scalar_type);
  }
  // something went wrong with the type analysis during shape propagation
  throw std::runtime_error(
      "unknown scalar type during JIT fusion code generation");
}

static std::string typeCastedValueName(
    const c10::Type& t,
    const at::ScalarType outtype,
    const std::string& vn) {
  if (t.kind() == TypeKind::IntType || t.kind() == TypeKind::BoolType) {
    if (!isIntegralType(outtype, /*includeBool=*/false)) {
      return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
    }
    return vn;
  } else if (t.kind() == TypeKind::FloatType) {
    // We don't guard this on anything because in our type system for scalars,
    // there is not a distinction between `float` and `double`, however there
    // *is* a distinction in tensor scalar types. We conservatively insert a
    // cast here, which may end up being a no-op if the tensor's scalar type
    // is `double`.
    return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
  } else if (t.kind() == TypeKind::NoneType) {
    // Support None value for optional arguments like memory format
    return vn;
  } else if (auto scalar_type = t.expectRef<TensorType>().scalarType()) {
    if (*scalar_type != outtype) {
      return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
    }
    return vn;
  }
  // something went wrong with the type analysis during shape propagation
  throw std::runtime_error(
      "unknown scalar type during JIT fusion code generation");
}

// Writes RHS of special handling "simple mappable" ops
static std::string encodeSpecialRHS(const Node* n, at::jit::TemplateEnv& env) {
  // special case for clamp fusion on missing min/max inputs
  // Note: It may seem unusual to have the bounds as the first case below;
  // this is so that if min or max is NaN they are "ignored", and when the
  // input is NaN, the output is NaN, too.
  if (n->kind() == aten::clamp) {
    const auto min = n->input(1);
    const auto max = n->input(2);
    env.s("0", valueName(n->input(0)));

    if (!min->node()->mustBeNone() && !max->node()->mustBeNone()) {
      env.s("1", valueName(min));
      env.s("2", valueName(max));
      return format("(${0} < ${1} ? ${1} : (${0} > ${2}? ${2} : ${0}))", env);
    } else if (min->node()->mustBeNone()) {
      env.s("1", valueName(max));
      return format("(${0} > ${1} ? ${1} : ${0})", env);
    } else if (max->node()->mustBeNone()) {
      env.s("1", valueName(min));
      return format("(${0} < ${1} ? ${1} : ${0})", env);
    } else {
      throw std::runtime_error(
          "At least one of 'min' or 'max' must not be None");
    }
  } else {
    throw std::runtime_error("Cannot encode RHS of the node, op not supported");
  }
}
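
// For example (value names depend on the IR's unique ids): for
// aten::clamp(%3, %4, %5) with both bounds present, the emitted RHS is
//   (n3 < n4 ? n4 : (n3 > n5? n5 : n3))
// and when min is None it degenerates to (n3 > n5 ? n5 : n3).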

// This struct specifies a template for dispatching specific aten:: operators.
// The current variants of RHS code selection we support are for double and
// float output values. For example, an aten::log operator which is assigned
// to a float value would emit logf(), whereas an aten::log operator which is
// assigned to a double would emit log().
struct RHSTemplate {
  // Common case: float and double dispatch are identical
  RHSTemplate(const char* for_float)
      : for_float(for_float), for_double(for_float) {}

  RHSTemplate(const char* for_float, const char* for_double)
      : for_float(for_float), for_double(for_double) {}

  const char* for_float;
  const char* for_double;
};

// Writes "simple mappable" ops
static std::string encodeRHS(const Node* n) {
  static std::unordered_map<NodeKind, RHSTemplate> simple_map_ops = {
      // unary
      {aten::_cast_Float, "static_cast<float>(${0})"},
      {aten::abs, "fabs(${0})"},
      {aten::sigmoid, {"1.f / (1.f + expf(-${0}))", "1. / (1. + exp(-${0}))"}},
      {aten::relu, "${0} < 0 ? 0.f : ${0} "},
      {aten::threshold,
       "${0} <= ${1} ? static_cast<decltype(${0})>(${2}) : ${0} "},
      {aten::log, {"logf(${0})", "log(${0})"}},
      {aten::log10, {"log10f(${0})", "log10(${0})"}},
      {aten::log1p, {"log1pf(${0})", "log1p(${0})"}},
      {aten::log2, {"log2f(${0})", "log2(${0})"}},
      {aten::lgamma, {"lgammaf(${0})", "lgamma(${0})"}},
      {aten::exp, {"expf(${0})", "exp(${0})"}},
      {aten::expm1, {"expm1f(${0})", "expm1(${0})"}},
      {aten::erf, {"erff(${0})", "erf(${0})"}},
      {aten::erfc, {"erfcf(${0})", "erfc(${0})"}},
      {aten::cos, {"cosf(${0})", "cos(${0})"}},
      {aten::acos, {"acosf(${0})", "acos(${0})"}},
      {aten::cosh, {"coshf(${0})", "cosh(${0})"}},
      {aten::sin, {"sinf(${0})", "sin(${0})"}},
      {aten::asin, {"asinf(${0})", "asin(${0})"}},
      {aten::sinh, {"sinhf(${0})", "sinh(${0})"}},
      {aten::tan, {"tanf(${0})", "tan(${0})"}},
      {aten::atan, {"atanf(${0})", "atan(${0})"}},
      {aten::tanh, {"tanhf(${0})", "tanh(${0})"}},
      {aten::sqrt, {"sqrtf(${0})", "sqrt(${0})"}},
      {aten::rsqrt, {"rsqrtf(${0})", "rsqrt(${0})"}},
      {aten::ceil, {"ceilf(${0})", "ceil(${0})"}},
      {aten::floor, {"floorf(${0})", "floor(${0})"}},
      {aten::round, {"roundf(${0})", "round(${0})"}},
      {aten::trunc, {"truncf(${0})", "trunc(${0})"}},
      {aten::frac, {"${0} - truncf(${0})", "${0} - trunc(${0})"}},
      {aten::reciprocal, {"1.f/(${0})", "1./(${0})"}},
      {aten::neg, "-${0}"},
      // simple binary
      {aten::atan2, "atan2(${0}, ${1})"},
      {aten::min,
       "isnan(${0}) ? ${0} : (isnan(${1}) ? ${1} : (${0} < ${1} ? ${0} : ${1}))"},
      {aten::max,
       "isnan(${0}) ? ${0} : (isnan(${1}) ? ${1} : (${0} < ${1} ? ${1} : ${0}))"},

      // binary with other
      // TODO: some of these ops will not get generated because
      // we only work on float inputs/outputs, but they are here to record
      // that they are valid mappable ops once we handle more types

      {aten::__and__, "${0} && ${1}"},
      {aten::__lshift__, "${0} << ${1}"},
      {aten::__or__, "${0} || ${1}"},
      {aten::__rshift__, "${0} >> ${1}"},
      {aten::__xor__, "${0} ^ ${1}"},
      {aten::addcmul, "${0} + ${3} * ${1} * ${2}"},
      {aten::div, "${0} / ${1}"},
      {aten::eq, "${0_nocast} == ${1_nocast}"},
      {aten::fmod, "fmodf(${0}, ${1})"},
      {aten::ge, "(${0_nocast} >= ${1_nocast})"},
      {aten::gt, "${0_nocast} > ${1_nocast}"},
      {aten::le, "(${0_nocast} <= ${1_nocast})"},
      {aten::lt, "${0_nocast} < ${1_nocast}"},
      {aten::lerp, "${0} + ${2} * (${1} - ${0})"},
      {aten::type_as, "(${0})"},
      {aten::mul, "${0} * ${1}"},
      {aten::ne, "${0_nocast} != ${1_nocast}"},
      {aten::remainder, "fmod((${1} + fmod(${0}, ${1})), ${1})"},
      {aten::pow, {"powf(${0}, ${1})", "pow(${0}, ${1})"}},

      // alpha
      {aten::add, "${0} + ${2}*${1}"},
      {aten::sub, "(${0} - ${2}*${1})"},
      {aten::rand_like, "uniform(rnd())"},

      // where
      {aten::where, "(${0} ? ${1} : ${2})"},
  };

  at::jit::TemplateEnv env;

  if (simple_map_ops.find(n->kind()) == simple_map_ops.end()) {
    return encodeSpecialRHS(n, env);
  } else {
    size_t i = 0;

    auto outtype = n->output()->type()->expectRef<TensorType>().scalarType();
    TORCH_INTERNAL_ASSERT(outtype);

    for (auto in : n->inputs()) {
      // PyTorch converts (scalar) argument types to the result type before
      // applying the operator, e.g. 1.4 - torch.tensor(3) = -2
      env.s(
          std::to_string(i),
          typeCastedValueName(*in->type(), *outtype, valueName(in)));
      // Uncasted operands are only used for comparison operators
      env.s(std::to_string(i) + "_nocast", valueName(in));
      i++;
    }

    const auto& templ = simple_map_ops.at(n->kind());
    const char* str = nullptr;
    if (*outtype == at::kFloat) {
      str = templ.for_float;
    } else {
      str = templ.for_double;
    }
    AT_ASSERT(str);
    return format(str, env);
  }
}
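
// For example (value names depend on the IR's unique ids): for a float
// aten::add(%0, %1, %2) node, each input is cast to the output scalar type
// via typeCastedValueName, so the "${0} + ${2}*${1}" template yields
//   n0 + ((float) n2)*n1
// assuming %0 and %1 are float tensors and %2 (alpha) is an integer scalar.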

static void emitIndexingFor(
    std::ostream& out,
    const std::string& tensor,
    const int ndim,
    const bool last_is_cont) {
  at::jit::TemplateEnv env;
  env.s("tensor", tensor);
  out << format("IndexType ${tensor}_offset = 0;\n", env);
  out << format("IndexType ${tensor}_linearIndex = linearIndex;\n", env);
  for (int d = ndim - 1; d >= 0; --d) {
    env.d("d", d);
    env.s("mod_sizes", d > 0 ? format("% ${tensor}.sizes[${d}]", env) : "");
    env.s(
        "times_stride",
        (d < ndim - 1 || !last_is_cont)
            ? format("* ${tensor}.strides[${d}]", env)
            : "");
    out << dim_calc.format(env);
    if (d > 0) {
      out << format("${tensor}_linearIndex /= ${tensor}.sizes[${d}];\n", env);
    }
  }
}
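
// As an illustration, for a 2-D tensor formal "t0" whose last dimension is
// contiguous, the code emitted above is (modulo the commented-out printf in
// dim_calc):
//   IndexType t0_offset = 0;
//   IndexType t0_linearIndex = linearIndex;
//   size_t t0_dimIndex1 = t0_linearIndex % t0.sizes[1];
//   t0_offset += t0_dimIndex1 ;
//   t0_linearIndex /= t0.sizes[1];
//   size_t t0_dimIndex0 = t0_linearIndex ;
//   t0_offset += t0_dimIndex0 * t0.strides[0];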

static void emitCheckFor(
    std::ostream& out,
    const std::string& tensor,
    const int ndim,
    const TensorDesc& desc) {
  at::jit::TemplateEnv env;
  env.s("tensor", tensor);
  env.s("scalar_type", scalarTypeName(desc.scalar_type));

  // allocate buffer to load 4
  out << format("${scalar_type} ${tensor}_buf[4];\n", env);

  // check if last dim is contiguous
  if (!desc.lastIsContiguous()) {
    out << "flag_vec4 = false;\n";
    return;
  }

  // disable on dtype > 4 bytes for performance
  if (at::elementSize(desc.scalar_type) > 4) {
    out << "flag_vec4 = false;\n";
    return;
  }

  // last dim size multiple of 4, other dim stride multiple of 4
  for (int d = ndim - 1; d >= 0; --d) {
    env.d("d", d);
    if (d == ndim - 1) {
      // last dim stride already checked above at compile time
      out << format(
          "if(${tensor}.sizes[${d}] % 4 != 0) flag_vec4 = false;\n", env);
    } else {
      out << format(
          "if(${tensor}.strides[${d}] % 4 != 0) flag_vec4 = false;\n", env);
    }
  }

  // pointer aligned
  out << format(
      "if(((uint64_t) ${tensor}.data) % (4 * sizeof(${scalar_type})) != 0) flag_vec4 = false;\n",
      env);
}
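
// As an illustration, for a 2-D float tensor "t0" that is contiguous in its
// last dimension, the runtime checks emitted above look like
//   float t0_buf[4];
//   if(t0.sizes[1] % 4 != 0) flag_vec4 = false;
//   if(t0.strides[0] % 4 != 0) flag_vec4 = false;
//   if(((uint64_t) t0.data) % (4 * sizeof(float)) != 0) flag_vec4 = false;
// The surrounding kernel template (from the resource strings) is expected to
// take the vectorized path only when flag_vec4 remains true for every tensor.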

// TODO: handle cases where we need to generate > 2^32 element tensors
std::string generateKernel(
    const std::string& name,
    const Graph& graph,
    const std::vector<std::pair<const Value*, const std::optional<TensorDesc>>>&
        inputs,
    const std::vector<std::pair<const Value*, const TensorDesc>>& outputs,
    const bool use_cuda) {
  at::jit::TemplateEnv env;
  env.s("kernelName", name);
  env.s(
      "IndexType",
      "unsigned int"); // Note: not uint32_t to avoid including cstdint

  std::stringstream tensorChecks;
  std::stringstream body;
  std::stringstream body_vec4;
  std::stringstream load;
  std::stringstream store;
  std::stringstream tensorOffsets;
  std::vector<std::string> formals;
  std::vector<std::string> argument_loads;

  // Lambda for writing arguments
  auto emitFormal = [&](const Value* n, const TensorDesc& desc) {
    env.d(
        "formal_index",
        formals.size() +
            1); // + 1 because the first argument is the linearIndex
    std::string tensor =
        "t" +
        std::to_string(
            formals.size()); // can't be unique() because Param may be an output
    const auto nDim = desc.nDim();
    emitCheckFor(tensorChecks, tensor, nDim, desc);
    emitIndexingFor(tensorOffsets, tensor, nDim, desc.lastIsContiguous());
    env.s("tensor", tensor);
    env.d("nDim", nDim);
    env.s("scalar_type", scalarTypeName(desc.scalar_type));
    formals.push_back(
        format("const TensorInfo<${scalar_type},${nDim}> ${tensor}", env));
    argument_loads.push_back(format(
        "*static_cast<TensorInfo<${scalar_type},${nDim}>*>(args[${formal_index}])",
        env));
  };
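
  // For example, the first 2-D float tensor argument becomes the formal
  //   const TensorInfo<float,2> t0
  // and is unpacked from the kernel's argument array as
  //   *static_cast<TensorInfo<float,2>*>(args[1])
  // (args[0] holds the linearIndex, hence the + 1 above).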

  auto emitScalarFormal = [&](const Value* n) {
    env.d(
        "formal_index",
        formals.size() +
            1); // + 1 because the first argument is the linearIndex
    std::string scalar =
        "s" +
        std::to_string(
            formals.size()); // can't be unique() because Param may be an output
    env.s("scalar", scalar);
    env.s("scalar_type", variableType(*n->type()));
    formals.push_back(format("${scalar_type} ${scalar}", env));
    argument_loads.push_back(
        format("*static_cast<${scalar_type}*>(args[${formal_index}])", env));
  };

  // Writes input parameters
  for (const auto& input : inputs) {
    if (input.second.has_value()) {
      emitFormal(input.first, *input.second);
    } else {
      emitScalarFormal(input.first);
    }
  }

  // Writes output parameters
  for (const auto& output : outputs) {
    emitFormal(output.first, output.second);
  }

  // Acquires input values
  bool has_half_tensor = false;
  bool has_bfloat_tensor = false;
  size_t formal_count = 0;
  for (const auto& input : inputs) {
    auto p = input.first;
    env.s("node", valueName(p));
    env.d("formal", formal_count++);

    // Acquires and converts (if needed) inputs
    // Note: conversion from half is only supported for CUDA kernels.
    // The conversion immediately converts fp16 inputs to float.
    // Access for other types is common to CUDA and CPU kernels.
    if (input.second.has_value()) {
      const auto is_half = input.second.has_value() &&
          ((*input.second).scalar_type == at::ScalarType::Half);
      const auto is_bfloat = input.second.has_value() &&
          ((*input.second).scalar_type == at::ScalarType::BFloat16);
      const auto is_bool = input.second.has_value() &&
          ((*input.second).scalar_type == at::ScalarType::Bool);
      if (is_half) {
        AT_ASSERT(use_cuda);
        env.s(
            "access",
            format("__half2float(t${formal}.data[t${formal}_offset])", env));
        env.s("access_vec4", format("__half2float(t${formal}_buf[i])", env));
        has_half_tensor = true;
      } else if (is_bfloat) {
        AT_ASSERT(use_cuda);
        env.s(
            "access",
            format(
                "__bfloat162float(t${formal}.data[t${formal}_offset])", env));
        env.s(
            "access_vec4", format("__bfloat162float(t${formal}_buf[i])", env));
        has_bfloat_tensor = true;
      } else if (use_cuda) {
        // No __ldg overload for bool
        if (is_bool) {
          env.s("access", format("t${formal}.data[t${formal}_offset]", env));
        } else {
          env.s(
              "access",
              format("__ldg(&t${formal}.data[t${formal}_offset])", env));
        }
        env.s("access_vec4", format("t${formal}_buf[i]", env));
      } else {
        env.s("access", format("t${formal}.data[t${formal}_offset]", env));
        env.s("access_vec4", format("t${formal}_buf[i]", env));
      }
      env.s("lhs_type", calcScalarTypeName(input.second->scalar_type));

      // load input in vectorized code path
      auto ele_size = at::elementSize((*input.second).scalar_type);
      if (ele_size == 1) {
        env.s(
            "load4",
            format(
                "*(reinterpret_cast<float*>(t${formal}_buf)) = *(reinterpret_cast<float*>(t${formal}.data + t${formal}_offset))",
                env));
      } else if (ele_size == 2) {
        env.s(
            "load4",
            format(
                "*(reinterpret_cast<float2*>(t${formal}_buf)) = *(reinterpret_cast<float2*>(t${formal}.data + t${formal}_offset))",
                env));
      } else if (ele_size == 4) {
        env.s(
            "load4",
            format(
                "*(reinterpret_cast<float4*>(t${formal}_buf)) = *(reinterpret_cast<float4*>(t${formal}.data + t${formal}_offset))",
                env));
      } else {
        env.s(
            "load4",
            format(
                "for(int i = 0; i<4; i++) t${formal}_buf[i] = t${formal}.data[t${formal}_offset + i]",
                env));
      }
      load << format("${load4};\n", env);

    } else {
      env.s("access", format("s${formal}", env));
      env.s("access_vec4", format("s${formal}", env));
      env.s("lhs_type", variableType(*input.first->type()));
    }
    body << format("${lhs_type} ${node} = ${access};\n", env);
    body_vec4 << format("${lhs_type} ${node} = ${access_vec4};\n", env);
  }
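
  // As an illustration, for a CUDA kernel whose first input is a half tensor
  // bound to IR value %0, the scalar code path above emits
  //   float n0 = __half2float(t0.data[t0_offset]);
  // while the vectorized path emits
  //   float n0 = __half2float(t0_buf[i]);
  // together with a float2 load of four half elements into t0_buf.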

  bool has_random = false;
  // Generates code for intermediate nodes
  // Note: Concat and Chunk are implicitly generated
  // Note: Random number generation is only supported for CUDA kernels.
  // Note: Constant None node is ignored and we will handle it in the
  // places where the constant None node is used
  // Note: No need to iterate over reference as n is a pointer
  for (const auto n : graph.nodes()) {
    static_assert(std::is_pointer_v<decltype(n)>, "n must be a pointer");
    // Note: FusedConcat nodes work by narrowing the output Tensors before the
    // kernel runs
    if (n->kind() == prim::FusedConcat)
      continue;
    if (n->kind() == prim::ConstantChunk)
      continue;
    if (n->mustBeNone())
      continue;
    if (n->kind() == aten::rand_like) {
      AT_ASSERT(use_cuda);
      has_random = true;
    }
    // Always emit double for prim::Constant. This will be narrowed later based
    // on either:
    //  - Tensor-Scalar operator type rules
    //  - Math function rules
    if (n->kind() == prim::Constant) {
      const auto val = toIValue(n->output()).value();
      std::string rhs;
      if (val.isDouble()) {
        rhs = scalarValue(val.toDouble());
      } else if (val.isBool()) {
        rhs = scalarValue(val.toBool());
      } else {
        AT_ASSERT(val.isInt());
        rhs = scalarValue(val.toInt());
      }
      env.s("node", valueName(n->output()));
      env.s("rhs", rhs);
      env.s("lhs_type", variableType(*n->output()->type()));
    } else {
      env.s("node", valueName(n->output()));
      env.s("rhs", encodeRHS(n));
      env.s("lhs_type", variableType(*n->output()->type()));
    }

    body << format("${lhs_type} ${node} = ${rhs};\n", env);
    body_vec4 << format("${lhs_type} ${node} = ${rhs};\n", env);
  }

  // Generates writes to output tensors
  for (const auto& output : outputs) {
    env.d("formal", formal_count++);
    env.s("access", format("t${formal}.data[t${formal}_offset]", env));
    env.s("access_vec4", format("t${formal}_buf[i]", env));
    env.s("node", valueName(output.first));

    // Acquires and converts (if needed) outputs
    // Note: conversion to half is only supported for CUDA kernels.
    const auto is_half = (output.second.scalar_type == at::ScalarType::Half);
    const auto is_bfloat =
        (output.second.scalar_type == at::ScalarType::BFloat16);
    if (is_half) {
      AT_ASSERT(use_cuda);
      body << format("${access} = __float2half(${node});\n", env);
      body_vec4 << format("${access_vec4} = __float2half(${node});\n", env);
      has_half_tensor = true;
    } else if (is_bfloat) {
      AT_ASSERT(use_cuda);
      body << format("${access} = __float2bfloat16(${node});\n", env);
      body_vec4 << format("${access_vec4} = __float2bfloat16(${node});\n", env);
      has_bfloat_tensor = true;
    } else {
      body << format("${access} = ${node};\n", env);
      body_vec4 << format("${access_vec4} = ${node};\n", env);
    }

    // store output in vectorized code path
    auto ele_size = at::elementSize(output.second.scalar_type);
    if (ele_size == 1) {
      env.s(
          "store4",
          format(
              "*(reinterpret_cast<float*>(t${formal}.data + t${formal}_offset)) = *(reinterpret_cast<float*>(t${formal}_buf))",
              env));
    } else if (ele_size == 2) {
      env.s(
          "store4",
          format(
              "*(reinterpret_cast<float2*>(t${formal}.data + t${formal}_offset)) = *(reinterpret_cast<float2*>(t${formal}_buf))",
              env));
    } else if (ele_size == 4) {
      env.s(
          "store4",
          format(
              "*(reinterpret_cast<float4*>(t${formal}.data + t${formal}_offset)) = *(reinterpret_cast<float4*>(t${formal}_buf))",
              env));
    } else {
      env.s(
          "store4",
          format(
              "for(int i = 0; i<4; i++) t${formal}.data[t${formal}_offset + i] = t${formal}_buf[i]",
              env));
    }
    store << format("${store4};\n", env);
  }
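
  // As an illustration, a float output bound to IR value %7 and occupying
  // formal slot 2 produces
  //   t2.data[t2_offset] = n7;
  // in the scalar body, t2_buf[i] = n7; in the vectorized body, and a float4
  // store that flushes t2_buf back to t2.data + t2_offset.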

  // Includes headers
  // Note: CUDA kernels support halfs and random generation, CPU kernels do not
  if (has_half_tensor) {
    env.s("HalfHeader", cuda::half_support_literal);
  } else {
    env.s("HalfHeader", "");
  }
  if (has_bfloat_tensor) {
    env.s("BFloat16Header", cuda::bfloat16_support_literal);
  } else {
    env.s("BFloat16Header", "");
  }

  if (has_random) {
    env.s("RandHeader", cuda::rand_support_literal);
    env.s("RandParam", cuda::rand_param);
    env.s("RandInit", cuda::rand_init);
  } else {
    env.s("RandHeader", "");
    env.s("RandParam", "");
    env.s("RandInit", "");
  }

  // clang-format on

  // Instantiates the CUDA or CPU-specific templates
  env.s("tensorOffsets", tensorOffsets.str());
  env.s("tensorChecks", tensorChecks.str());
  env.s("kernelBody", body.str());
  env.s("kernelBody_vec4", body_vec4.str());
  env.s("kernelLoad", load.str());
  env.s("kernelStore", store.str());
  env.v("formals", formals);
  env.v("argument_loads", argument_loads);
  std::string code_string;
  if (use_cuda) {
    env.s("type_declarations", cuda::type_declarations_template.format(env));
    code_string = cuda::cuda_compilation_unit_template.format(env);
  } else {
    env.s("type_declarations", cpu::type_declarations_template.format(env));
    code_string = cpu::cpu_compilation_unit_template.format(env);
  }

  if (debugFuser()) {
    std::cerr << "fusion code:" << code_string << '\n';
  }
  return code_string;
}

} // namespace torch::jit::fuser