#include <torch/csrc/jit/codegen/fuser/codegen.h>

#include <ATen/ATen.h>
#include <ATen/code_template.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/codegen/fuser/compiler.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/codegen/fuser/tensor_info.h>
#include <torch/csrc/jit/ir/ir.h>

#include <torch/csrc/jit/codegen/fuser/cpu/resource_strings.h>
#include <torch/csrc/jit/codegen/fuser/cuda/resource_strings.h>

#include <cmath>
#include <cstdint>
#include <iomanip> // std::setprecision in scalarValue(double)
#include <iostream>
#include <sstream>
#include <tuple>
#include <vector>

namespace torch::jit::fuser {

// Template for computing the offset into the tensor to access a value
static auto dim_calc = at::jit::CodeTemplate(R"(
//printf("tensor ${tensor} sizes[${d}] = %d, strides[${d}] = %d\n", ${tensor}.sizes[${d}],${tensor}.strides[${d}]);
size_t ${tensor}_dimIndex${d} = ${tensor}_linearIndex ${mod_sizes};
${tensor}_offset += ${tensor}_dimIndex${d} ${times_stride};
)");

static std::string valueName(const Value* n) {
  return "n" + std::to_string(n->unique());
}

static std::string scalarValue(const int64_t v) {
  return std::to_string(v);
}

static std::string scalarValue(const bool v) {
  return std::to_string(v);
}

// Note: The NAN, NEG_INFINITY and POS_INFINITY strings map to device-specific
// implementations of these special values. These macros are found in the
// resource strings for each device.
static std::string scalarValue(const double v) {
  std::ostringstream out;
  if (std::isnan(v)) {
    out << "NAN";
  } else if (std::isinf(v)) {
    if (v < 0) {
      out << "NEG_INFINITY";
    } else {
      out << "POS_INFINITY";
    }
  } else {
    out << std::setprecision(16) << v;
  }
  return out.str();
}

// Note: Half is special-cased to avoid returning at::Half
static const char* scalarTypeName(const at::ScalarType type) {
  if (type == at::ScalarType::Half) {
    return "half";
  }
  if (type == at::ScalarType::BFloat16) {
    return cuda::bfloat16_type_string;
  }

  switch (type) {
#define DEFINE_CASE(ctype, name) \
  case at::ScalarType::name:     \
    return #ctype;
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
#undef DEFINE_CASE
    default:
      throw std::runtime_error("unknown scalar type");
  }
}

static const char* calcScalarTypeName(const at::ScalarType type) {
  if (type == at::ScalarType::Half) {
    return "float";
  }
  if (type == at::ScalarType::BFloat16) {
    return "float";
  }
  return scalarTypeName(type);
}

static std::string variableType(const c10::Type& t) {
  if (t.kind() == TypeKind::IntType) {
    return "int64_t";
  } else if (t.kind() == TypeKind::FloatType) {
    return "double";
  } else if (t.kind() == TypeKind::BoolType) {
    return "bool";
  } else if (auto scalar_type = t.expectRef<TensorType>().scalarType()) {
    return calcScalarTypeName(*scalar_type);
  }
  // something went wrong with the type analysis during shape propagation
  throw std::runtime_error(
      "unknown scalar type during JIT fusion code generation");
}

static std::string typeCastedValueName(
    const c10::Type& t,
    const at::ScalarType outtype,
    const std::string& vn) {
  if (t.kind() == TypeKind::IntType || t.kind() == TypeKind::BoolType) {
    if (!isIntegralType(outtype, /*includeBool=*/false)) {
      return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
    }
    return vn;
  } else if (t.kind() == TypeKind::FloatType) {
    // We don't guard this on anything because in our type system for scalars,
    // there is not a distinction between `float` and `double`, however there
    // *is* a distinction in tensor scalar types. We conservatively insert a
    // cast here, which may end up being a no-op if the tensor's scalar type
    // is `double`.
    return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
  } else if (t.kind() == TypeKind::NoneType) {
    // Support None value for optional arguments like memory format
    return vn;
  } else if (auto scalar_type = t.expectRef<TensorType>().scalarType()) {
    if (*scalar_type != outtype) {
      return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
    }
    return vn;
  }
  // something went wrong with the type analysis during shape propagation
  throw std::runtime_error(
      "unknown scalar type during JIT fusion code generation");
}
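
// Illustrative sketch (the value name "n3" is hypothetical): an IntType scalar
// feeding an op whose output tensor is float comes back from
// typeCastedValueName() as "((float) n3)", while an operand already matching
// the output scalar type is returned unchanged.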

// Writes the RHS of specially handled "simple mappable" ops
static std::string encodeSpecialRHS(const Node* n, at::jit::TemplateEnv& env) {
  // special case for clamp fusion on missing min/max inputs
  // Note: It may seem unusual to check the bounds first in the cases below;
  // this is so that if min or max is NaN it is "ignored", and when the input
  // is NaN the output is NaN, too.
  if (n->kind() == aten::clamp) {
    const auto min = n->input(1);
    const auto max = n->input(2);
    env.s("0", valueName(n->input(0)));

    if (!min->node()->mustBeNone() && !max->node()->mustBeNone()) {
      env.s("1", valueName(min));
      env.s("2", valueName(max));
      return format("(${0} < ${1} ? ${1} : (${0} > ${2}? ${2} : ${0}))", env);
    } else if (min->node()->mustBeNone()) {
      env.s("1", valueName(max));
      return format("(${0} > ${1} ? ${1} : ${0})", env);
    } else if (max->node()->mustBeNone()) {
      env.s("1", valueName(min));
      return format("(${0} < ${1} ? ${1} : ${0})", env);
    } else {
      throw std::runtime_error(
          "At least one of 'min' or 'max' must not be None");
    }
  } else {
    throw std::runtime_error("Cannot encode RHS of the node, op not supported");
  }
}
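
// Sketch of the text generated above (value names "n0", "n1", "n2" are
// hypothetical): a clamp with both bounds present expands to
//   (n0 < n1 ? n1 : (n0 > n2? n2 : n0))
// while a clamp with only a max bound expands to
//   (n0 > n1 ? n1 : n0)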

// This struct specifies a template for dispatching specific aten:: operators.
// The current variants of RHS code selection we support are for double and
// float output values. For example, an aten::log operator which is assigned
// to a float value would emit logf(), whereas an aten::log operator which is
// assigned to a double would emit log().
struct RHSTemplate {
  // Common case: float and double dispatch are identical
  RHSTemplate(const char* for_float)
      : for_float(for_float), for_double(for_float) {}

  RHSTemplate(const char* for_float, const char* for_double)
      : for_float(for_float), for_double(for_double) {}

  const char* for_float;
  const char* for_double;
};

// Writes "simple mappable" ops
static std::string encodeRHS(const Node* n) {
  static std::unordered_map<NodeKind, RHSTemplate> simple_map_ops = {
      // unary
      {aten::_cast_Float, "static_cast<float>(${0})"},
      {aten::abs, "fabs(${0})"},
      {aten::sigmoid, {"1.f / (1.f + expf(-${0}))", "1. / (1. + exp(-${0}))"}},
      {aten::relu, "${0} < 0 ? 0.f : ${0} "},
      {aten::threshold,
       "${0} <= ${1} ? static_cast<decltype(${0})>(${2}) : ${0} "},
      {aten::log, {"logf(${0})", "log(${0})"}},
      {aten::log10, {"log10f(${0})", "log10(${0})"}},
      {aten::log1p, {"log1pf(${0})", "log1p(${0})"}},
      {aten::log2, {"log2f(${0})", "log2(${0})"}},
      {aten::lgamma, {"lgammaf(${0})", "lgamma(${0})"}},
      {aten::exp, {"expf(${0})", "exp(${0})"}},
      {aten::expm1, {"expm1f(${0})", "expm1(${0})"}},
      {aten::erf, {"erff(${0})", "erf(${0})"}},
      {aten::erfc, {"erfcf(${0})", "erfc(${0})"}},
      {aten::cos, {"cosf(${0})", "cos(${0})"}},
      {aten::acos, {"acosf(${0})", "acos(${0})"}},
      {aten::cosh, {"coshf(${0})", "cosh(${0})"}},
      {aten::sin, {"sinf(${0})", "sin(${0})"}},
      {aten::asin, {"asinf(${0})", "asin(${0})"}},
      {aten::sinh, {"sinhf(${0})", "sinh(${0})"}},
      {aten::tan, {"tanf(${0})", "tan(${0})"}},
      {aten::atan, {"atanf(${0})", "atan(${0})"}},
      {aten::tanh, {"tanhf(${0})", "tanh(${0})"}},
      {aten::sqrt, {"sqrtf(${0})", "sqrt(${0})"}},
      {aten::rsqrt, {"rsqrtf(${0})", "rsqrt(${0})"}},
      {aten::ceil, {"ceilf(${0})", "ceil(${0})"}},
      {aten::floor, {"floorf(${0})", "floor(${0})"}},
      {aten::round, {"roundf(${0})", "round(${0})"}},
      {aten::trunc, {"truncf(${0})", "trunc(${0})"}},
      {aten::frac, {"${0} - truncf(${0})", "${0} - trunc(${0})"}},
      {aten::reciprocal, {"1.f/(${0})", "1./(${0})"}},
      {aten::neg, "-${0}"},
      // simple binary
      {aten::atan2, "atan2(${0}, ${1})"},
      {aten::min,
       "isnan(${0}) ? ${0} : (isnan(${1}) ? ${1} : (${0} < ${1} ? ${0} : ${1}))"},
      {aten::max,
       "isnan(${0}) ? ${0} : (isnan(${1}) ? ${1} : (${0} < ${1} ? ${1} : ${0}))"},

      // binary with other
      // TODO: some of these ops will not get generated because
      // we only work on float inputs/outputs, but they are here to record
      // that they are valid mappable ops once we handle more types

      {aten::__and__, "${0} && ${1}"},
      {aten::__lshift__, "${0} << ${1}"},
      {aten::__or__, "${0} || ${1}"},
      {aten::__rshift__, "${0} >> ${1}"},
      {aten::__xor__, "${0} ^ ${1}"},
      {aten::addcmul, "${0} + ${3} * ${1} * ${2}"},
      {aten::div, "${0} / ${1}"},
      {aten::eq, "${0_nocast} == ${1_nocast}"},
      {aten::fmod, "fmodf(${0}, ${1})"},
      {aten::ge, "(${0_nocast} >= ${1_nocast})"},
      {aten::gt, "${0_nocast} > ${1_nocast}"},
      {aten::le, "(${0_nocast} <= ${1_nocast})"},
      {aten::lt, "${0_nocast} < ${1_nocast}"},
      {aten::lerp, "${0} + ${2} * (${1} - ${0})"},
      {aten::type_as, "(${0})"},
      {aten::mul, "${0} * ${1}"},
      {aten::ne, "${0_nocast} != ${1_nocast}"},
      {aten::remainder, "fmod((${1} + fmod(${0}, ${1})), ${1})"},
      {aten::pow, {"powf(${0}, ${1})", "pow(${0}, ${1})"}},

      // alpha
      {aten::add, "${0} + ${2}*${1}"},
      {aten::sub, "(${0} - ${2}*${1})"},
      {aten::rand_like, "uniform(rnd())"},

      // where
      {aten::where, "(${0} ? ${1} : ${2})"},
  };

  at::jit::TemplateEnv env;

  if (simple_map_ops.find(n->kind()) == simple_map_ops.end()) {
    return encodeSpecialRHS(n, env);
  } else {
    size_t i = 0;

    auto outtype = n->output()->type()->expectRef<TensorType>().scalarType();
    TORCH_INTERNAL_ASSERT(outtype);

    for (auto in : n->inputs()) {
      // PyTorch converts (scalar) argument types to the result type before
      // applying the operator, e.g. 1.4 - torch.tensor(3) = -2
      env.s(
          std::to_string(i),
          typeCastedValueName(*in->type(), *outtype, valueName(in)));
      // Uncasted operands only used for comparison operators
      env.s(std::to_string(i) + "_nocast", valueName(in));
      i++;
    }

    const auto& templ = simple_map_ops.at(n->kind());
    const char* str = nullptr;
    if (*outtype == at::kFloat) {
      str = templ.for_float;
    } else {
      str = templ.for_double;
    }
    AT_ASSERT(str);
    return format(str, env);
  }
}
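
// Illustrative sketch (value names are hypothetical): for an aten::add node
// whose inputs are float tensors n0, n1 with an int scalar alpha n2 writing a
// float output, the template "${0} + ${2}*${1}" expands to roughly
//   n0 + ((float) n2)*n1
// where the cast around n2 is inserted by typeCastedValueName().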

static void emitIndexingFor(
    std::ostream& out,
    const std::string& tensor,
    const int ndim,
    const bool last_is_cont) {
  at::jit::TemplateEnv env;
  env.s("tensor", tensor);
  out << format("IndexType ${tensor}_offset = 0;\n", env);
  out << format("IndexType ${tensor}_linearIndex = linearIndex;\n", env);
  for (int d = ndim - 1; d >= 0; --d) {
    env.d("d", d);
    env.s("mod_sizes", d > 0 ? format("% ${tensor}.sizes[${d}]", env) : "");
    env.s(
        "times_stride",
        (d < ndim - 1 || !last_is_cont)
            ? format("* ${tensor}.strides[${d}]", env)
            : "");
    out << dim_calc.format(env);
    if (d > 0) {
      out << format("${tensor}_linearIndex /= ${tensor}.sizes[${d}];\n", env);
    }
  }
}
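
// For reference, a sketch of what the routine above emits for a hypothetical
// 2-D tensor formal named "t0" whose last dimension is not contiguous
// (whitespace left by empty substitutions elided):
//
//   IndexType t0_offset = 0;
//   IndexType t0_linearIndex = linearIndex;
//   size_t t0_dimIndex1 = t0_linearIndex % t0.sizes[1];
//   t0_offset += t0_dimIndex1 * t0.strides[1];
//   t0_linearIndex /= t0.sizes[1];
//   size_t t0_dimIndex0 = t0_linearIndex;
//   t0_offset += t0_dimIndex0 * t0.strides[0];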

static void emitCheckFor(
    std::ostream& out,
    const std::string& tensor,
    const int ndim,
    const TensorDesc& desc) {
  at::jit::TemplateEnv env;
  env.s("tensor", tensor);
  env.s("scalar_type", scalarTypeName(desc.scalar_type));

  // allocate buffer to load 4
  out << format("${scalar_type} ${tensor}_buf[4];\n", env);

  // check if last dim is contiguous
  if (!desc.lastIsContiguous()) {
    out << "flag_vec4 = false;\n";
    return;
  }

  // disable on dtype > 4 bytes for performance
  if (at::elementSize(desc.scalar_type) > 4) {
    out << "flag_vec4 = false;\n";
    return;
  }

  // last dim size multiple of 4, other dim stride multiple of 4
  for (int d = ndim - 1; d >= 0; --d) {
    env.d("d", d);
    if (d == ndim - 1) {
      // last dim stride already checked above at compile time
      out << format(
          "if(${tensor}.sizes[${d}] % 4 != 0) flag_vec4 = false;\n", env);
    } else {
      out << format(
          "if(${tensor}.strides[${d}] % 4 != 0) flag_vec4 = false;\n", env);
    }
  }

  // pointer aligned
  out << format(
      "if(((uint64_t) ${tensor}.data) % (4 * sizeof(${scalar_type})) != 0) flag_vec4 = false;\n",
      env);
}
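
// Sketch of the runtime guards emitted above for a hypothetical 2-D float
// tensor formal "t0" (the early compile-time checks may instead emit an
// unconditional "flag_vec4 = false;"):
//
//   float t0_buf[4];
//   if(t0.sizes[1] % 4 != 0) flag_vec4 = false;
//   if(t0.strides[0] % 4 != 0) flag_vec4 = false;
//   if(((uint64_t) t0.data) % (4 * sizeof(float)) != 0) flag_vec4 = false;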

// TODO: handle cases where we need to generate > 2^32 element tensors
std::string generateKernel(
    const std::string& name,
    const Graph& graph,
    const std::vector<std::pair<const Value*, const std::optional<TensorDesc>>>&
        inputs,
    const std::vector<std::pair<const Value*, const TensorDesc>>& outputs,
    const bool use_cuda) {
  at::jit::TemplateEnv env;
  env.s("kernelName", name);
  env.s(
      "IndexType",
      "unsigned int"); // Note: not uint32_t to avoid including cstdint

  std::stringstream tensorChecks;
  std::stringstream body;
  std::stringstream body_vec4;
  std::stringstream load;
  std::stringstream store;
  std::stringstream tensorOffsets;
  std::vector<std::string> formals;
  std::vector<std::string> argument_loads;

  // Lambda for writing arguments
  auto emitFormal = [&](const Value* n, const TensorDesc& desc) {
    env.d(
        "formal_index",
        formals.size() +
            1); // + 1 because the first argument is the linearIndex
    std::string tensor =
        "t" +
        std::to_string(
            formals.size()); // can't be unique() because Param may be an output
    const auto nDim = desc.nDim();
    emitCheckFor(tensorChecks, tensor, nDim, desc);
    emitIndexingFor(tensorOffsets, tensor, nDim, desc.lastIsContiguous());
    env.s("tensor", tensor);
    env.d("nDim", nDim);
    env.s("scalar_type", scalarTypeName(desc.scalar_type));
    formals.push_back(
        format("const TensorInfo<${scalar_type},${nDim}> ${tensor}", env));
    argument_loads.push_back(format(
        "*static_cast<TensorInfo<${scalar_type},${nDim}>*>(args[${formal_index}])",
        env));
  };

  auto emitScalarFormal = [&](const Value* n) {
    env.d(
        "formal_index",
        formals.size() +
            1); // + 1 because the first argument is the linearIndex
    std::string scalar =
        "s" +
        std::to_string(
            formals.size()); // can't be unique() because Param may be an output
    env.s("scalar", scalar);
    env.s("scalar_type", variableType(*n->type()));
    formals.push_back(format("${scalar_type} ${scalar}", env));
    argument_loads.push_back(
        format("*static_cast<${scalar_type}*>(args[${formal_index}])", env));
  };
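
  // Illustrative sketch of the strings these lambdas collect for a
  // hypothetical 2-D float tensor formal "t0" followed by a double scalar
  // formal "s1":
  //   formals:        "const TensorInfo<float,2> t0", "double s1"
  //   argument_loads: "*static_cast<TensorInfo<float,2>*>(args[1])",
  //                   "*static_cast<double*>(args[2])"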

  // Writes input parameters
  for (const auto& input : inputs) {
    if (input.second.has_value()) {
      emitFormal(input.first, *input.second);
    } else {
      emitScalarFormal(input.first);
    }
  }

  // Writes output parameters
  for (const auto& output : outputs) {
    emitFormal(output.first, output.second);
  }

  // Acquires input values
  bool has_half_tensor = false;
  bool has_bfloat_tensor = false;
  size_t formal_count = 0;
  for (const auto& input : inputs) {
    auto p = input.first;
    env.s("node", valueName(p));
    env.d("formal", formal_count++);

    // Acquires and converts (if needed) inputs
    // Note: conversion from half is only supported for CUDA kernels.
    //  The conversion immediately converts fp16 inputs to float.
    //  Access for other types is common to CUDA and CPU kernels.
    if (input.second.has_value()) {
      const auto is_half = input.second.has_value() &&
          ((*input.second).scalar_type == at::ScalarType::Half);
      const auto is_bfloat = input.second.has_value() &&
          ((*input.second).scalar_type == at::ScalarType::BFloat16);
      const auto is_bool = input.second.has_value() &&
          ((*input.second).scalar_type == at::ScalarType::Bool);
      if (is_half) {
        AT_ASSERT(use_cuda);
        env.s(
            "access",
            format("__half2float(t${formal}.data[t${formal}_offset])", env));
        env.s("access_vec4", format("__half2float(t${formal}_buf[i])", env));
        has_half_tensor = true;
      } else if (is_bfloat) {
        AT_ASSERT(use_cuda);
        env.s(
            "access",
            format(
                "__bfloat162float(t${formal}.data[t${formal}_offset])", env));
        env.s(
            "access_vec4", format("__bfloat162float(t${formal}_buf[i])", env));
        has_bfloat_tensor = true;
      } else if (use_cuda) {
        // No __ldg overload for bool
        if (is_bool) {
          env.s("access", format("t${formal}.data[t${formal}_offset]", env));
        } else {
          env.s(
              "access",
              format("__ldg(&t${formal}.data[t${formal}_offset])", env));
        }
        env.s("access_vec4", format("t${formal}_buf[i]", env));
      } else {
        env.s("access", format("t${formal}.data[t${formal}_offset]", env));
        env.s("access_vec4", format("t${formal}_buf[i]", env));
      }
      env.s("lhs_type", calcScalarTypeName(input.second->scalar_type));

      // load input in vectorized code path
      auto ele_size = at::elementSize((*input.second).scalar_type);
      if (ele_size == 1) {
        env.s(
            "load4",
            format(
                "*(reinterpret_cast<float*>(t${formal}_buf)) = *(reinterpret_cast<float*>(t${formal}.data + t${formal}_offset))",
                env));
      } else if (ele_size == 2) {
        env.s(
            "load4",
            format(
                "*(reinterpret_cast<float2*>(t${formal}_buf)) = *(reinterpret_cast<float2*>(t${formal}.data + t${formal}_offset))",
                env));
      } else if (ele_size == 4) {
        env.s(
            "load4",
            format(
                "*(reinterpret_cast<float4*>(t${formal}_buf)) = *(reinterpret_cast<float4*>(t${formal}.data + t${formal}_offset))",
                env));
      } else {
        env.s(
            "load4",
            format(
                "for(int i = 0; i<4; i++) t${formal}_buf[i] = t${formal}.data[t${formal}_offset + i]",
                env));
      }
      load << format("${load4};\n", env);
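
      // The casts above move 4 elements at once by reinterpreting the packed
      // bytes as a single float/float2/float4; e.g. for a hypothetical float
      // input formal "t0" (ele_size == 4) the emitted load is
      //   *(reinterpret_cast<float4*>(t0_buf)) =
      //       *(reinterpret_cast<float4*>(t0.data + t0_offset));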

    } else {
      env.s("access", format("s${formal}", env));
      env.s("access_vec4", format("s${formal}", env));
      env.s("lhs_type", variableType(*input.first->type()));
    }
    body << format("${lhs_type} ${node} = ${access};\n", env);
    body_vec4 << format("${lhs_type} ${node} = ${access_vec4};\n", env);
  }

  bool has_random = false;
  // Generates code for intermediate nodes
  // Note: Concat and Chunk are implicitly generated
  // Note: Random number generation is only supported for CUDA kernels.
  // Note: Constant None node is ignored and we will handle it in the
  //       places where the constant None node is used
  // Note: No need to iterate over reference as n is a pointer
  for (const auto n : graph.nodes()) {
    static_assert(std::is_pointer_v<decltype(n)>, "n must be a pointer");
    // Note: FusedConcat nodes work by narrowing the output Tensors before the
    // kernel runs
    if (n->kind() == prim::FusedConcat)
      continue;
    if (n->kind() == prim::ConstantChunk)
      continue;
    if (n->mustBeNone())
      continue;
    if (n->kind() == aten::rand_like) {
      AT_ASSERT(use_cuda);
      has_random = true;
    }
    // Always emit double for prim::Constant. This will be narrowed later based
    // on either:
    //  - Tensor-Scalar operator type rules
    //  - Math function rules
    if (n->kind() == prim::Constant) {
      const auto val = toIValue(n->output()).value();
      std::string rhs;
      if (val.isDouble()) {
        rhs = scalarValue(val.toDouble());
      } else if (val.isBool()) {
        rhs = scalarValue(val.toBool());
      } else {
        AT_ASSERT(val.isInt());
        rhs = scalarValue(val.toInt());
      }
      env.s("node", valueName(n->output()));
      env.s("rhs", rhs);
      env.s("lhs_type", variableType(*n->output()->type()));
    } else {
      env.s("node", valueName(n->output()));
      env.s("rhs", encodeRHS(n));
      env.s("lhs_type", variableType(*n->output()->type()));
    }

    body << format("${lhs_type} ${node} = ${rhs};\n", env);
    body_vec4 << format("${lhs_type} ${node} = ${rhs};\n", env);
  }
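
  // Each intermediate node thus becomes one declaration in the kernel body.
  // Illustrative sketch (value names hypothetical): an aten::exp node reading
  // a float input n0 emits
  //   float n5 = expf(n0);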

  // Generates writes to output tensors
  for (const auto& output : outputs) {
    env.d("formal", formal_count++);
    env.s("access", format("t${formal}.data[t${formal}_offset]", env));
    env.s("access_vec4", format("t${formal}_buf[i]", env));
    env.s("node", valueName(output.first));

    // Acquires and converts (if needed) outputs
    // Note: conversion to half is only supported for CUDA kernels.
    const auto is_half = (output.second.scalar_type == at::ScalarType::Half);
    const auto is_bfloat =
        (output.second.scalar_type == at::ScalarType::BFloat16);
    if (is_half) {
      AT_ASSERT(use_cuda);
      body << format("${access} = __float2half(${node});\n", env);
      body_vec4 << format("${access_vec4} = __float2half(${node});\n", env);
      has_half_tensor = true;
    } else if (is_bfloat) {
      AT_ASSERT(use_cuda);
      body << format("${access} = __float2bfloat16(${node});\n", env);
      body_vec4 << format("${access_vec4} = __float2bfloat16(${node});\n", env);
      has_bfloat_tensor = true;
    } else {
      body << format("${access} = ${node};\n", env);
      body_vec4 << format("${access_vec4} = ${node};\n", env);
    }

    // store output in vectorized code path
    auto ele_size = at::elementSize(output.second.scalar_type);
    if (ele_size == 1) {
      env.s(
          "store4",
          format(
              "*(reinterpret_cast<float*>(t${formal}.data + t${formal}_offset)) = *(reinterpret_cast<float*>(t${formal}_buf))",
              env));
    } else if (ele_size == 2) {
      env.s(
          "store4",
          format(
              "*(reinterpret_cast<float2*>(t${formal}.data + t${formal}_offset)) = *(reinterpret_cast<float2*>(t${formal}_buf))",
              env));
    } else if (ele_size == 4) {
      env.s(
          "store4",
          format(
              "*(reinterpret_cast<float4*>(t${formal}.data + t${formal}_offset)) = *(reinterpret_cast<float4*>(t${formal}_buf))",
              env));
    } else {
      env.s(
          "store4",
          format(
              "for(int i = 0; i<4; i++) t${formal}.data[t${formal}_offset + i] = t${formal}_buf[i]",
              env));
    }
    store << format("${store4};\n", env);
  }

  // Includes headers
  // Note: CUDA kernels support half tensors and random generation; CPU kernels
  // do not
640     env.s("HalfHeader", cuda::half_support_literal);
641   } else {
642     env.s("HalfHeader", "");
643   }
644   if (has_bfloat_tensor) {
645     env.s("BFloat16Header", cuda::bfloat16_support_literal);
646   } else {
647     env.s("BFloat16Header", "");
648   }
649 
650   if (has_random) {
651     env.s("RandHeader", cuda::rand_support_literal);
652     env.s("RandParam", cuda::rand_param);
653     env.s("RandInit", cuda::rand_init);
654   } else {
655     env.s("RandHeader", "");
656     env.s("RandParam", "");
657     env.s("RandInit", "");
658   }
659 
660   // clang-format on
661 
662   // Instantiates the CUDA or CPU-specific templates
663   env.s("tensorOffsets", tensorOffsets.str());
664   env.s("tensorChecks", tensorChecks.str());
665   env.s("kernelBody", body.str());
666   env.s("kernelBody_vec4", body_vec4.str());
667   env.s("kernelLoad", load.str());
668   env.s("kernelStore", store.str());
669   env.v("formals", formals);
670   env.v("argument_loads", argument_loads);
671   std::string code_string;
672   if (use_cuda) {
673     env.s("type_declarations", cuda::type_declarations_template.format(env));
674     code_string = cuda::cuda_compilation_unit_template.format(env);
675   } else {
676     env.s("type_declarations", cpu::type_declarations_template.format(env));
677     code_string = cpu::cpu_compilation_unit_template.format(env);
678   }
679 
680   if (debugFuser()) {
681     std::cerr << "fusion code:" << code_string << '\n';
682   }
683   return code_string;
684 }
685 
686 } // namespace torch::jit::fuser
687