xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/cuda_codegen.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/tensorexpr/cuda_codegen.h>
2 #include <torch/csrc/jit/tensorexpr/half_support.h>
3 
4 #include <ATen/cuda/CUDAContext.h>
5 #include <ATen/cuda/CUDAGeneratorImpl.h>
6 #include <ATen/native/cuda/jit_utils.h>
7 #include <c10/cuda/CUDAFunctions.h>
8 #include <c10/util/irange.h>
9 #include <torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h>
10 #include <torch/csrc/jit/codegen/fuser/cuda/resource_strings.h>
11 #include <torch/csrc/jit/jit_log.h>
12 #include <torch/csrc/jit/tensorexpr/analysis.h>
13 #include <torch/csrc/jit/tensorexpr/cuda_random.h>
14 #include <torch/csrc/jit/tensorexpr/eval.h>
15 #include <torch/csrc/jit/tensorexpr/exceptions.h>
16 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
17 #include <torch/csrc/jit/tensorexpr/registerizer.h>
18 
19 #ifndef AT_PER_OPERATOR_HEADERS
20 #include <ATen/NativeFunctions.h>
21 #else
22 #include <ATen/ops/empty_strided_native.h>
23 #endif
24 
25 #include <unordered_map>
26 #include <utility>
27 
28 namespace torch::jit::tensorexpr {
29 
30 // A RAII wrapper to manage a variable and name pair in the look-up table.
31 // TODO: move this to a more shared place.
32 class ScopedVarName {
33  public:
34   ScopedVarName(VarNameMap* mapping, const VarPtr& var, const std::string& name)
35       : mapping_(mapping), var_(var) {
36     auto iter = mapping->find(var);
37     if (iter != mapping->end()) {
38       throw std::runtime_error("Duplicate var entry: " + var->name_hint());
39     }
40     mapping->insert(std::make_pair(var, name));
41   }
42 
43   ScopedVarName(
44       UniqueNameManager* manager,
45       const VarPtr& var,
46       const std::string& name)
47       : ScopedVarName(&manager->unique_name_mapping_, var, name) {}
48 
49   ScopedVarName(const ScopedVarName&) = delete;
50   ScopedVarName& operator=(const ScopedVarName&) = delete;
51 
52   ~ScopedVarName() noexcept(false) {
53     mapping_->erase(var_);
54   }
55 
56  private:
57   VarNameMap* mapping_ = nullptr;
58   VarPtr var_ = nullptr;
59 };
60 
61 static bool is_zero(const ExprPtr& expr) {
62   auto v = intValue(expr);
63   return v && *v == 0;
64 }
65 
66 static const at::cuda::NVRTC& nvrtc() {
67   return at::globalContext().getNVRTC();
68 }
69 
70 std::string CudaPrinter::dtypeToCppString(const Dtype& dtype) {
71   switch (dtype.scalar_type()) {
72     case ScalarType::Bool:
73       return "bool";
74     case ScalarType::Half:
75       return "half";
76     case ScalarType::BFloat16:
77       return fuser::cuda::bfloat16_type_string;
78     case ScalarType::Char:
79       return "char";
80     case ScalarType::Byte:
81       return "unsigned char";
82     case ScalarType::Short:
83       return "short";
84     case ScalarType::Long:
85       return "long long";
86     default:
87       return dtype.ToCppString();
88   }
89 }
90 
91 void CudaAnalysis::visit(const FreePtr& v) {
92   if (thread_local_bufs_.count(v->buffer_var()) == 0 &&
93       cross_block_bufs_.count(v->buffer_var()) == 0) {
94     throw std::runtime_error("Global free not supported yet");
95   }
96 }
97 
98 void CudaAnalysis::visit(const AllocatePtr& v) {
99   StmtPtr p = v->get_parent();
100   while (p) {
101     ForPtr for_v = to<For>(p);
102     if (for_v) {
103       if (for_v->loop_options().is_gpu_block_index()) {
104         // TODO: This isn't right if there's a thread index at a higher level
105         // than this.
106         cross_block_bufs_.insert(v->buffer_var());
107         return;
108       } else if (for_v->loop_options().is_gpu_thread_index()) {
109         thread_local_bufs_.insert(v->buffer_var());
110         return;
111       }
112     }
113     p = p->get_parent();
114   }
115   throw std::runtime_error("Global alloc not supported yet");
116 }
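// For illustration (hypothetical buffer "buf", not from the original source):
// an Allocate whose nearest GPU-bound enclosing loop is a threadIdx loop is
// classified as thread-local and later printed as a plain local array, e.g.
// "float buf[4];", while one whose nearest GPU-bound ancestor is a blockIdx
// loop becomes a cross-block buffer and is printed as "__shared__ float
// buf[4];". An Allocate with no GPU-bound ancestor is rejected above.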
117 
118 void CudaAnalysis::visit(const PlacementAllocatePtr& v) {
119   throw std::runtime_error("Memory reuse not supported yet");
120 }
121 
122 void CudaAnalysis::visit(const ForPtr& v) {
123   // Recurse first.
124   v->body()->accept(this);
125 
126   const LoopOptions& loop_options = v->loop_options();
127   if (loop_options.is_gpu_block_index()) {
128     int gpu_block_index = loop_options.gpu_block_index();
129     if (gpu_block_index >= 3) {
130       throw std::runtime_error("support only 3D gpu_block_index");
131     }
132     ExprPtr prev = nullptr;
133     if (gpu_block_extents_.size() <= static_cast<size_t>(gpu_block_index)) {
134       gpu_block_extents_.resize(gpu_block_index + 1);
135     } else {
136       prev = gpu_block_extents_[gpu_block_index];
137     }
138     if (!is_zero(v->start())) {
139       throw std::runtime_error(
140           "start must be zero for gpu_block_index: " +
141           std::to_string(v->start()));
142     }
143 
144     // NOLINTNEXTLINE(bugprone-branch-clone)
145     if (prev == nullptr) {
146       gpu_block_extents_[gpu_block_index] = v->stop();
147     } else if (prev->isConstant() && immediateEquals(prev, 1)) {
148       // extents must be positive so if the current extent is 1 then even if the
149       // stop is symbolic it's the max.
150       gpu_block_extents_[gpu_block_index] = v->stop();
151     } else {
152       gpu_block_extents_[gpu_block_index] =
153           IRSimplifier::simplify(alloc<Max>(prev, v->stop(), true));
154     }
155   } else if (loop_options.is_gpu_thread_index()) {
156     int gpu_thread_index = loop_options.gpu_thread_index();
157     if (gpu_thread_index >= 3) {
158       throw std::runtime_error("support only 3D gpu_thread_index");
159     }
160     ExprPtr prev = nullptr;
161     if (gpu_thread_extents_.size() <= static_cast<size_t>(gpu_thread_index)) {
162       gpu_thread_extents_.resize(gpu_thread_index + 1);
163     } else {
164       prev = gpu_thread_extents_[gpu_thread_index];
165     }
166     if (!is_zero(v->start())) {
167       throw std::runtime_error(
168           "start must be zero for gpu_thread_index: " +
169           std::to_string(v->start()));
170     }
171 
172     // NOLINTNEXTLINE(bugprone-branch-clone)
173     if (prev == nullptr) {
174       gpu_thread_extents_[gpu_thread_index] = v->stop();
175     } else if (prev->isConstant() && immediateEquals(prev, 1)) {
176       // extents must be positive so if the current extent is 1 then even if the
177       // stop is symbolic it's the max.
178       gpu_thread_extents_[gpu_thread_index] = v->stop();
179     } else {
180       gpu_thread_extents_[gpu_thread_index] =
181           IRSimplifier::simplify(alloc<Max>(prev, v->stop(), true));
182     }
183   }
184 }
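// For illustration (hypothetical extents): if two loops are bound to
// blockIdx.x, one with stop 128 and one with a symbolic stop n, the recorded
// extent ends up roughly as gpu_block_extents_[0] = simplify(Max(128, n)),
// i.e. the launch extent for an axis is the maximum stop over all loops bound
// to it, and every such loop must start at zero.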
185 
186 void CudaPrinter::print_flat_alloc(const AllocatePtr& alloc) {
187   std::vector<ExprPtr> dims = alloc->dims();
188   // TODO: this should be merged with the storage flattener.
189   int64_t flat_size = 1;
190   for (const auto& dim : dims) {
191     auto dim_i = intValue(dim);
192     if (dim_i) {
193       flat_size *= *dim_i;
194     } else {
195       throw std::runtime_error("Only integer dimensions are supported for now");
196     }
197   }
198   os() << dtypeToCppString(alloc->dtype()) << " " << (*alloc->buffer_var())
199        << "[" << flat_size << "];" << '\n';
200 }
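// For illustration (hypothetical buffer "buf"): an Allocate of a float buffer
// with dims {4, 8} is flattened and printed as "float buf[32];".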
201 
202 void CudaPrinter::visit(const AllocatePtr& v) {
203   // TODO: handle dynamic shapes here.
204   if (cuda_analysis_->cross_block_bufs().count(v->buffer_var()) != 0) {
205     emitIndent();
206     os() << "__shared__ ";
207     print_flat_alloc(v);
208     return;
209   }
210 
211   if (cuda_analysis_->thread_local_bufs().count(v->buffer_var()) != 0) {
212     emitIndent();
213     print_flat_alloc(v);
214     return;
215   }
216 
217   throw std::runtime_error("Encountered Alloc not local to block or thread");
218 }
219 
220 void CudaPrinter::visit(const FreePtr& v) {
221   // do nothing
222 }
223 
224 void CudaPrinter::visit(const ForPtr& v) {
225   IRPrinter::visit(v);
226 }
227 
228 void CudaPrinter::visit(const CastPtr& v) {
229   std::string castFn = v->dtype().scalar_type() == ScalarType::Half
230       ? "__float2half"
231       : v->dtype().scalar_type() == ScalarType::BFloat16 ? "__float2bfloat16"
232       : v->src_value()->dtype().scalar_type() == ScalarType::Half
233       ? "__half2float"
234       : v->src_value()->dtype().scalar_type() == ScalarType::BFloat16
235       ? "__bfloat162float"
236       : ("(" + dtypeToCppString(v->dtype()) + ")");
237   os() << castFn << "(";
238   v->src_value()->accept(this);
239   os() << ")";
240 }
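// For illustration: a Cast to Half prints as "__float2half(x)", a Cast from
// Half prints as "__half2float(x)", a Cast to BFloat16 prints as
// "__float2bfloat16(x)", and any other cast falls back to a C-style cast such
// as "(int)(x)".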
241 
242 void CudaPrinter::visit(const IntrinsicsPtr& v) {
243   if (v->op_type() == IntrinsicsOp::kRand) {
244     os() << "Uint32ToFloat(" << *rand_func_ << "())";
245     return;
246   }
247 
248   std::string func_name = v->func_name();
249 
250   // get type of resulting expression.
251   ScalarType returnType = v->param(0)->dtype().scalar_type();
252   for (size_t i = 1; i < v->nparams(); ++i) {
253     returnType = promoteTypes(returnType, v->param(i)->dtype().scalar_type());
254   }
255 
256   if (returnType == ScalarType::Half || returnType == ScalarType::Float) {
257     func_name = func_name + "f";
258   }
259   if (v->op_type() == IntrinsicsOp::kAbs &&
260       !c10::isIntegralType(returnType, true)) {
261     // since kAbs's func_name is `abs`, prefix `f` for floating point
262     func_name = "f" + func_name;
263   }
264   if (v->op_type() == IntrinsicsOp::kIsNan) {
265     func_name = "isnan";
266   }
267 
268   os() << func_name << "(";
269   for (const auto i : c10::irange(v->nparams())) {
270     if (i > 0) {
271       os() << ", ";
272     }
273     os() << *v->param(i);
274   }
275   os() << ")";
276 }
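// For illustration (hypothetical float operands a and b): kLog prints as
// "logf(a)", kPow as "powf(a, b)", and kAbs as "fabsf(a)" (the "f" prefix is
// added above for floating-point abs); kIsNan always prints as "isnan(a)".
// With double operands the unsuffixed names ("log", "pow", "fabs") are used.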
277 
278 void CudaPrinter::visit(const ExternalCallPtr& v) {
279   throw unimplemented_lowering(v);
280 }
281 
282 void CudaPrinter::visit(const LoadPtr& v) {
283   // TODO: find a better metric for deciding when to use ldg. Support different dtypes.
284   // Detects whether the load target is also a store target.
285   // TODO: this is currently too wide. It detects whether a store-target
286   // exists within the program. In fact, this check is only necessary within a
287   // kernel.
288   if (v->indices().empty()) {
289     os() << *v->base_handle();
290     return;
291   }
292   if (v->dtype().scalar_type() == ScalarType::Bool ||
293       v->dtype().scalar_type() == ScalarType::Half ||
294       v->dtype().scalar_type() == ScalarType::BFloat16) {
295     // There's no __ldg overload for bool, half, or bfloat16.
296     os() << *v->base_handle() << "[" << *v->flat_index() << "]";
297     return;
298   }
299   if (cuda_analysis_->is_buf_store_target(v->buf())) {
300     // CUDA __ldg can only be applied to read-only buffers.
301     os() << *v->base_handle() << "[" << *v->flat_index() << "]";
302     return;
303   }
304   os() << "__ldg(" << *v->base_handle() << " + " << *v->flat_index() << ")";
305 }
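// For illustration (hypothetical buffer "A"): a load from a read-only float
// buffer prints as "__ldg(A + i)", while a load from a buffer that is also a
// store target, or from a bool/half/bfloat16 buffer, prints as a plain
// subscript "A[i]".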
306 
307 // TODO: maybe this should be a more shared location?
308 // TODO: investigate how "ExprPtr" can be implicitly converted to "ExprHandle"
309 // as a bool.
310 static bool CheckEqual(const ExprPtr& lhs, const ExprPtr& rhs) {
311   // The fast path. Checks if the pointers are the same.
312   if (lhs == rhs) {
313     return true;
314   }
315   ExprHandle diff = Sub::make(ExprHandle(lhs), ExprHandle(rhs));
316   ExprHandle diff_s = IRSimplifier::simplify(diff);
317   return immediateEquals(diff_s.node(), 0);
318 }
319 
320 class AtomicAddFuser : public IRMutator {
321  public:
322   AtomicAddFuser(
323       const std::unordered_set<VarPtr>& thread_local_bufs,
324       const GPUMetaVarRewriter& metavars)
325       : thread_local_bufs_(thread_local_bufs) {
326     const std::vector<ExprPtr>& block_extents = metavars.gpu_block_extents();
327     const std::vector<VarPtr>& block_vars = metavars.gpu_block_vars();
328     for (size_t i = 0; i < block_extents.size(); ++i) {
329       MetaVarExtent extent{block_extents[i], false};
330       if (extent.expr->isConstant() && immediateEquals(extent.expr, 1)) {
331         extent.trivial = true;
332       } else {
333         nontrivial_metavars_.insert(block_vars[i]);
334       }
335       metavars_[block_vars[i]] = extent;
336     }
337 
338     const std::vector<ExprPtr>& thread_extents = metavars.gpu_thread_extents();
339     const std::vector<VarPtr>& thread_vars = metavars.gpu_thread_vars();
340     for (size_t i = 0; i < thread_extents.size(); ++i) {
341       MetaVarExtent extent{thread_extents[i], false};
342       if (extent.expr->isConstant() && immediateEquals(extent.expr, 1)) {
343         extent.trivial = true;
344       } else {
345         nontrivial_metavars_.insert(thread_vars[i]);
346       }
347       metavars_[thread_vars[i]] = extent;
348     }
349   }
350 
351   StmtPtr mutate(const StorePtr& v) override {
352     BufPtr buf = v->buf();
353 
354     // Thread locals never need to be atomic.
355     if (thread_local_bufs_.count(buf->base_handle()) != 0) {
356       return v;
357     }
358 
359     ScalarType dtype = v->value()->dtype().scalar_type();
360     if (dtype != ScalarType::Float && dtype != ScalarType::Double) {
361       return v;
362     }
363     AddPtr add_v = to<Add>(v->value());
364     if (!add_v) {
365       return v;
366     }
367     LoadPtr load_v = to<Load>(add_v->lhs());
368     if (!load_v) {
369       return v;
370     }
371     if (v->base_handle() != load_v->base_handle()) {
372       return v;
373     }
374     if (v->indices().empty() && load_v->indices().empty()) {
375       return v;
376     }
377     bool index_equal = CheckEqual(v->flat_index(), load_v->flat_index());
378     if (!index_equal) {
379       return v;
380     }
381 
382     // TODO: this checks that the metavars occur directly as an index, but this
383     // is pessimistic; blockIdx.x + 1 is fine too if there is no overlap.
384     std::unordered_set<VarPtr> vars_to_find = nontrivial_metavars_;
385     for (const ExprPtr& e : v->indices()) {
386       if (VarPtr v = to<Var>(e)) {
387         vars_to_find.erase(v);
388       }
389     }
390 
391     if (vars_to_find.empty()) {
392       // All metavars accounted for.
393       return v;
394     }
395 
396     return alloc<AtomicAdd>(buf, v->indices(), add_v->rhs());
397   }
398 
399  private:
400   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
401   const std::unordered_set<VarPtr>& thread_local_bufs_;
402   struct MetaVarExtent {
403     ExprPtr expr{nullptr};
404     bool trivial{false};
405   };
406   std::unordered_map<VarPtr, MetaVarExtent> metavars_;
407   std::unordered_set<VarPtr> nontrivial_metavars_;
408 };
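// For illustration (hypothetical buffer "out"): a store of the form
// out[j] = out[j] + x whose indices do not mention every nontrivial
// block/thread metavar (so different threads may write the same element) is
// rewritten to AtomicAdd(out, {j}, x), which the printer below emits as
// "atomicAdd(&out[j], x);" unless "out" is thread-local.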
409 
410 void CudaPrinter::visit(const StorePtr& v) {
411   emitIndent();
412   if (v->indices().empty()) {
413     os() << *v->base_handle() << " = ";
414   } else {
415     os() << *v->base_handle() << "[" << *v->flat_index() << "] = ";
416   }
417   os() << *v->value() << ";";
418   os() << '\n';
419 }
420 
421 void CudaPrinter::visit(const AtomicAddPtr& v) {
422   emitIndent();
423   if (cuda_analysis_->thread_local_bufs().count(v->base_handle()) > 0) {
424     // atomicAdd only works on global and shared memory
425     os() << *v->base_handle() << "[" << *v->flat_index()
426          << "] += " << *v->value() << ";";
427   } else {
428     os() << "atomicAdd(&" << *v->base_handle() << "[" << *v->flat_index() << "]"
429          << ", " << *v->value() << ");";
430   }
431   os() << '\n';
432 }
433 
434 void CudaPrinter::visit(const MaxPtr& v) {
435   if (v->dtype().is_integral()) {
436     os() << "max(";
437   } else {
438     os() << "maximum(";
439   }
440   v->lhs()->accept(this);
441   os() << ",";
442   v->rhs()->accept(this);
443   os() << ")";
444 }
445 
446 void CudaPrinter::visit(const MinPtr& v) {
447   if (v->dtype().is_integral()) {
448     os() << "min(";
449   } else {
450     os() << "minimum(";
451   }
452   v->lhs()->accept(this);
453   os() << ",";
454   v->rhs()->accept(this);
455   os() << ")";
456 }
457 
458 void CudaPrinter::visit(const IfThenElsePtr& v) {
459   os() << "((";
460   v->condition()->accept(this);
461   os() << ") ? ";
462   v->true_value()->accept(this);
463   os() << " : ";
464   v->false_value()->accept(this);
465   os() << ")";
466 }
467 
468 void CudaPrinter::visit(const BlockPtr& v) {
469   os() << "{" << '\n';
470   indent_++;
471 
472   for (const StmtPtr& s : v->stmts()) {
473     s->accept(this);
474   }
475 
476   indent_--;
477   emitIndent();
478   os() << "}";
479 }
480 
481 void CudaPrinter::visit(const LetPtr& v) {
482   emitIndent();
483   os() << dtypeToCppString(v->var()->dtype());
484   os() << " " << *v->var() << " = ";
485   v->value()->accept(this);
486   os() << ";" << '\n';
487 }
488 
489 class PrioritizeLoad : public IRMutator {
490  public:
491   ExprPtr mutate(const LoadPtr& v) override {
492     // Look at the declaration of this variable for more details.
493     if (nested_if_then_else_ > 0) {
494       return IRMutator::mutate(v);
495     }
496     if (nested_let_) {
497       return IRMutator::mutate(v);
498     }
499     if (thread_local_bufs_.count(v->base_handle()) > 0) {
500       return IRMutator::mutate(v);
501     }
502     if (v->indices().empty()) {
503       return IRMutator::mutate(v);
504     }
505     if (nested_store_) {
506       if (v->base_handle() == nested_store_->buf()->base_handle() &&
507           v->indices().size() == nested_store_->indices().size()) {
508         // also check indices
509         bool same = true;
510         for (const auto i : c10::irange(v->indices().size())) {
511           if (!exprEquals(v->indices()[i], nested_store_->indices()[i])) {
512             same = false;
513             break;
514           }
515         }
516         if (same) {
517           return IRMutator::mutate(v);
518         }
519       } else if (nested_store_->indices().empty()) {
520         return IRMutator::mutate(v);
521       }
522     }
523 
524     MemLoadList& load_list = load_stack_.back();
525     VarPtr load_new_var = alloc<Var>("v", v->dtype());
526     ExprPtr new_value = IRMutator::mutate(v);
527     load_list.emplace_back(load_new_var, new_value);
528 
529     return load_new_var;
530   }
531 
532   ExprPtr mutate(const CastPtr& v) override {
533     LoadPtr src_load = to<Load>(v->src_value());
534     ExprPtr new_src = v->src_value()->accept_mutator(this);
535     VarPtr new_var = to<Var>(new_src);
536     if (!src_load || !new_var) {
537       return alloc<Cast>(v->dtype(), new_src);
538     }
539 
540     // We just did the prioritize load, let's fold in the Cast.
541     MemLoadList& load_list = load_stack_.back();
542     assert(!load_list.empty());
543     auto pair = load_list.back();
544     assert(pair.first == new_var);
545     load_list.pop_back();
546 
547     new_var = alloc<Var>("v", v->dtype());
548     ExprPtr new_value = alloc<Cast>(v->dtype(), pair.second);
549     load_list.emplace_back(new_var, new_value);
550     return new_var;
551   }
552 
553   StmtPtr mutate(const StorePtr& v) override {
554     StorePtr last = nested_store_;
555     nested_store_ = v;
556     StmtPtr s = IRMutator::mutate(v);
557     nested_store_ = last;
558     return s;
559   }
560 
561   StmtPtr mutate(const LetPtr& v) override {
562     nested_let_ = true;
563     StmtPtr s = IRMutator::mutate(v);
564     nested_let_ = false;
565     return s;
566   }
567 
568   StmtPtr mutate(const BlockPtr& v) override {
569     std::list<StmtPtr> stmts = v->stmts();
570     for (const StmtPtr& stmt : stmts) {
571       PushList();
572       StmtPtr stmt_new = stmt->accept_mutator(this);
573 
574       AddMemLoadsFromList(v, stmt);
575       PopList();
576 
577       if (stmt_new == stmt) {
578         continue;
579       }
580       v->replace_stmt(stmt, stmt_new);
581     }
582     return v;
583   }
584 
585   ExprPtr mutate(const IfThenElsePtr& v) override {
586     nested_if_then_else_++;
587     ExprPtr new_v = IRMutator::mutate(v);
588     nested_if_then_else_--;
589     return new_v;
590   }
591 
592  private:
593   using MemLoadEntry = std::pair<VarPtr, ExprPtr>;
594   using MemLoadList = std::vector<MemLoadEntry>;
595   using MemoryLoadStack = std::vector<MemLoadList>;
596 
597   void PushList() {
598     load_stack_.emplace_back();
599   }
600 
601   void PopList() {
602     load_stack_.pop_back();
603   }
604 
605   void AddMemLoadsFromList(const BlockPtr& block, const StmtPtr& last) {
606     MemLoadList& load_list = load_stack_.back();
607     if (load_list.empty()) {
608       return;
609     }
610 
611     for (auto& pair : load_list) {
612       StmtPtr news = alloc<Let>(pair.first, pair.second);
613       block->insert_stmt_before(news, last);
614     }
615   }
616 
617   MemoryLoadStack load_stack_;
618   // TODO: For now, we are not moving the loads with the IfThenElse.
619   // Eventually, we should switch to a more generic structure like:
620   // int v2 = IfThenElse(cond, true_v, false_v) + 2 ->
621   //
622   // int v;
623   // if (cond) {
624   //   v = true_v;
625   // } else {
626   //   v = false_v;
627   // }
628   // int v2 = v + 2;
629   int nested_if_then_else_{0};
630   StorePtr nested_store_{nullptr};
631   bool nested_let_{false};
632   std::unordered_set<VarPtr> thread_local_bufs_;
633 };
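// For illustration (hypothetical buffers A, B, out): the mutator above
// roughly rewrites
//   out[i] = A[i] + B[i];
// into
//   float v = A[i];
//   float v_1 = B[i];
//   out[i] = v + v_1;
// by hoisting each global load into a Let inserted just before the enclosing
// statement. Loads through thread-local buffers, loads inside IfThenElse, and
// loads that read the store's own target element are left in place.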
634 
635 std::string CudaCodeGen::GetUniqueFuncName(const std::string& func_prefix) {
636   int64_t counter = 0;
637   std::string name = func_prefix;
638   while (taken_func_names.count(name)) {
639     name = func_prefix + "_" + std::to_string(counter++);
640   }
641 
642   taken_func_names.insert(name);
643   return name;
644 }
645 
646 bool GPUMetaVarRewriter::isFullExtent() {
647   {
648     auto& extents = cuda_analysis_->gpu_block_extents();
649     for (int i = 0; i < 3; ++i) {
650       if (!exprEquals(current_block_reach_[i], extents[i])) {
651         return false;
652       }
653     }
654   }
655 
656   {
657     auto& extents = cuda_analysis_->gpu_thread_extents();
658     for (int i = 0; i < 3; ++i) {
659       if (!exprEquals(current_thread_reach_[i], extents[i])) {
660         return false;
661       }
662     }
663   }
664 
665   return true;
666 }
667 
668 StmtPtr GPUMetaVarRewriter::mutate(const ForPtr& v) {
669   StmtPtr body = v->body();
670   ExprPtr old_reach = nullptr;
671   const LoopOptions& loop_options = v->loop_options();
672   if (loop_options.is_gpu_block_index()) {
673     int gpu_block_index = loop_options.gpu_block_index();
674     if (gpu_block_index >= 3) {
675       throw std::runtime_error("support only 3D gpu_block_index");
676     }
677     old_reach = current_block_reach_[gpu_block_index];
678 
679     // Extents must be positive, assume >= 1.
680     if (old_reach->isConstant() && immediateEquals(old_reach, 1)) {
681       current_block_reach_[gpu_block_index] = v->stop();
682     } else {
683       current_block_reach_[gpu_block_index] =
684           IRSimplifier::simplify(alloc<Max>(old_reach, v->stop(), true));
685     }
686 
687     VarPtr metaVar = gpu_block_vars_[gpu_block_index];
688     body = Substitute(Stmt::clone(body), {{v->var(), metaVar}});
689   } else if (loop_options.is_gpu_thread_index()) {
690     int gpu_thread_index = loop_options.gpu_thread_index();
691     if (gpu_thread_index >= 3) {
692       throw std::runtime_error("support only 3D gpu_thread_index");
693     }
694     old_reach = current_thread_reach_[gpu_thread_index];
695 
696     // Extents must be positive, assume >= 1.
697     if (old_reach->isConstant() && immediateEquals(old_reach, 1)) {
698       current_thread_reach_[gpu_thread_index] = v->stop();
699     } else {
700       current_thread_reach_[gpu_thread_index] =
701           IRSimplifier::simplify(alloc<Max>(old_reach, v->stop(), true));
702     }
703 
704     VarPtr metaVar = gpu_thread_vars_[gpu_thread_index];
705     body = Substitute(Stmt::clone(body), {{v->var(), metaVar}});
706   }
707 
708   // Recurse into body block.
709   body = Stmt::clone(body->accept_mutator(this));
710 
711   // pop the internal reach off the stack.
712   if (loop_options.is_gpu_block_index()) {
713     current_block_reach_[loop_options.gpu_block_index()] = old_reach;
714     return body;
715   } else if (loop_options.is_gpu_thread_index()) {
716     current_thread_reach_[loop_options.gpu_thread_index()] = old_reach;
717     return body;
718   }
719 
720   return v->cloneWithNewBody(body);
721 }
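// For illustration (hypothetical loop): a loop
//   for (int i = 0; i < n; i++) { body(i) }
// bound to blockIdx.x is erased and replaced by its body with i substituted
// by the blockIdx.x metavar; its stop n becomes the recorded block reach for
// dimension 0 while the body is being rewritten.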
722 
723 StmtPtr GPUMetaVarRewriter::mutate(const BlockPtr& v) {
724   std::vector<Segment> innerSegments;
725   Segment current;
726 
727   auto pushAndReset = [&](bool mask) {
728     if (!current.empty()) {
729       innerSegments.push_back(current);
730     }
731     current.reset(mask);
732   };
733 
734   // Here we're slicing the Block's contents into segments that should have the
735   // same launch reach. Segments are composed of all statements that aren't
736   // loops - loops are their own segments. Some operations, such as threading
737   // and memory ops, should never be masked and so also get their own segment.
738   for (const StmtPtr& stmt : *v) {
739     StmtPtr stmt_new = stmt->accept_mutator(this);
740     if (stmt == stmt_new) {
741       stmt_new = Stmt::clone(stmt_new);
742     }
743 
744     // Likewise, Allocate and Free should never be masked.
745     if (to<Allocate>(stmt) || to<Free>(stmt)) {
746       pushAndReset(false);
747     }
748 
749     // If the current stmt *was* a loop, it's a segment boundary.
750     if (ForPtr f = to<For>(stmt)) {
751       pushAndReset(false);
752     }
753 
754     current.stmts().push_back(stmt_new);
755     // if the current segment should not be masked, it's a segment boundary on
756     // the far side as well.
757     if (!current.mask()) {
758       pushAndReset(true);
759     }
760   }
761 
762   if (!current.empty()) {
763     innerSegments.push_back(current);
764   }
765 
766   // We are at max extent in all dimensions, so we need no masks at this level.
767   if (isFullExtent()) {
768     // flatten inner segments.
769     std::vector<StmtPtr> stmts;
770     for (auto& v : innerSegments) {
771       for (const auto& s : v.stmts()) {
772         stmts.push_back(s);
773       }
774     }
775 
776     return alloc<Block>(stmts);
777   }
778 
779   std::vector<StmtPtr> stmts;
780   for (auto& segment : innerSegments) {
781     bool need_sync = false;
782     // We never mask loops, they'll mask their contents.
783     if (!segment.mask()) {
784       TORCH_INTERNAL_ASSERT(segment.stmts().size() == 1, buildErrorMessage());
785       stmts.push_back(segment.stmts()[0]);
786       continue;
787     }
788 
789     // If we get here, we must mask since we're not full reach and our direct
790     // child isn't a For.
791     StmtPtr inner = alloc<Block>(segment.stmts());
792     // threads inside blocks.
793     auto& thread_extents = cuda_analysis_->gpu_thread_extents();
794     for (size_t i = 0; i < gpu_thread_vars_.size(); ++i) {
795       if (!exprEquals(current_thread_reach_[i], thread_extents[i])) {
796         need_sync = true;
797         // Mask it against the current dimensions.
798         inner = alloc<Cond>(
799             alloc<CompareSelect>(
800                 gpu_thread_vars_[i],
801                 current_thread_reach_[i],
802                 CompareSelectOperation::kLT),
803             inner,
804             nullptr);
805       }
806     }
807     auto& block_extents = cuda_analysis_->gpu_block_extents();
808     for (size_t i = 0; i < gpu_block_vars_.size(); ++i) {
809       if (!exprEquals(current_block_reach_[i], block_extents[i])) {
810         // Mask it against the current dimensions.
811         inner = alloc<Cond>(
812             alloc<CompareSelect>(
813                 gpu_block_vars_[i],
814                 current_block_reach_[i],
815                 CompareSelectOperation::kLT),
816             inner,
817             nullptr);
818       }
819     }
820 
821     if (need_sync) {
822       stmts.push_back(alloc<SyncThreads>());
823     }
824     stmts.push_back(inner);
825     if (need_sync) {
826       stmts.push_back(alloc<SyncThreads>());
827     }
828   }
829 
830   return alloc<Block>(stmts);
831 }
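// For illustration (hypothetical extents): if a segment only reaches
// threadIdx.x < 64 while the kernel-wide thread extent is 128, the segment is
// wrapped roughly as
//   __syncthreads();
//   if (threadIdx.x < 64) { ... }
//   __syncthreads();
// Block-dimension masks are added the same way, but without the syncs.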
832 
833 static std::ostream& operator<<(
834     std::ostream& out,
835     const std::vector<ExprPtr>& exprs) {
836   size_t i = 0;
837   for (const auto& expr : exprs) {
838     if (i++ > 0) {
839       out << ", ";
840     }
841     out << *expr;
842   }
843   return out;
844 }
845 
846 static const char* device_resource_string = R"(
847 #define NAN __int_as_float(0x7fffffff)
848 #define POS_INFINITY __int_as_float(0x7f800000)
849 #define NEG_INFINITY __int_as_float(0xff800000)
850 
851 )";
852 
853 static const char* shared_resource_string = R"(
854 template<typename T>
855 __device__ T maximum(T a, T b) {
856   return isnan(a) ? a : (a > b ? a : b);
857 }
858 
859 template<typename T>
860 __device__ T minimum(T a, T b) {
861   return isnan(a) ? a : (a < b ? a : b);
862 }
863 
864 )";
865 
866 void CudaCodeGen::Initialize() {
867   // TODO: handle multiple kernels.
868   // TODO: handle dynamic dimension.
869   // TODO: call nvrtc.
870   // TODO: merge HasRand with CudaAnalysis.
871   GenericIntrinsicsExpander intrinsics_expander;
872   apply_mutator(&intrinsics_expander);
873 
874   HasRand has_rand_func(stmt());
875   has_random_ = has_rand_func.has_rand();
876   cuda_analysis_ = std::make_unique<CudaAnalysis>();
877   printer_ =
878       std::make_unique<CudaPrinter>(&oss_, cuda_analysis_.get(), has_random_);
879   metavar_rewriter_ =
880       std::make_unique<GPUMetaVarRewriter>(cuda_analysis_.get());
881 
882   // Check whether the statement uses the Half type, if so add the
883   // half_support_literal.
884   StmtPtr stmt_v = stmt();
885   HalfChecker halfChecker(buffer_args());
886   stmt_v->accept(&halfChecker);
887 
888   os() << device_resource_string << shared_resource_string;
889 
890   if (has_random_) {
891     os() << philox_random_string << '\n';
892   }
893 
894   if (halfChecker.hasHalf()) {
895     os() << fuser::cuda::half_support_literal << '\n';
896   }
897   if (halfChecker.hasBFloat16()) {
898     os() << fuser::cuda::bfloat16_support_literal << '\n';
899   }
900 
901   std::string func_name = GetUniqueFuncName(kernel_func_name());
902   os() << "extern \"C\" __global__" << '\n';
903 #if defined(USE_ROCM)
904   // CUDA has a default limit of threads per block (=flat work group size)
905   // of 1024, but ROCm uses 256 by default. At the time of writing
906   // (#45506), I am unaware of a stricter limit that TensorExpr imposes
907   // (maybe for perf), so I use 1024 as the maximum flat work group size.
908   // We put a minimum value of 1; this is also used by hip (ROCm 3.8) in
909   // the __launch_bound__ implementation. The arguments for the attribute
910   // are (min, max), for details see the documentation at
911   // https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
912   os() << "__attribute__((amdgpu_flat_work_group_size(1, 1024)))" << std::endl;
913 #endif
914   os() << "void " << func_name << "(";
915   const std::vector<BufferArg> buffer_args = this->buffer_args();
916   for (size_t i = 0; i < buffer_args.size(); i++) {
917     if (i > 0) {
918       os() << ", ";
919     }
920     const BufferArg& buffer_arg = buffer_args[i];
921     VarPtr var = buffer_arg.var();
922     Dtype dtype = buffer_arg.dtype();
923 
924     os() << printer_->dtypeToCppString(dtype)
925          << (buffer_arg.isVar() ? " " : "* ")
926          << name_manager()->get_unique_name(var);
927   }
928   VarPtr rand_seed;
929   VarPtr rand_offset;
930   if (has_random_) {
931     // TODO: switch to kUint64 when it is available.
932     rand_seed = alloc<Var>("rand_seed", kInt);
933     rand_offset = alloc<Var>("rand_offset", kInt);
934     std::string uint64_str = "unsigned long long";
935     os() << ", " << uint64_str << " " << *rand_seed << ", " << uint64_str << " "
936          << *rand_offset;
937   }
938   os() << ") {";
939   os() << '\n';
940 
941   if (has_random_) {
942     VarPtr idx = alloc<Var>("idx", kInt);
943     os() << "int " << *idx << " = blockIdx.x*blockDim.x + threadIdx.x;" << '\n';
944     VarPtr rand_func = printer_->rand_func();
945     os() << "Philox " << *rand_func << "(" << *rand_seed << ", " << *idx << ", "
946          << *rand_offset << ");" << '\n';
947     os() << '\n';
948   }
949 
950   stmt_v->accept(cuda_analysis_.get());
951 
952   stmt_v = stmt_v->accept_mutator(metavar_rewriter_.get());
953 
954   AtomicAddFuser atomic_add_fuser(
955       cuda_analysis_->thread_local_bufs(), *metavar_rewriter_);
956   stmt_v = stmt_v->accept_mutator(&atomic_add_fuser);
957 
958   stmt_v = registerize(stmt_v);
959 
960   PrioritizeLoad prioritize_load;
961   stmt_v = stmt_v->accept_mutator(&prioritize_load);
962 
963   // The registerizer might insert half-type scalars, we don't want this.
964   HalfRewriter hsFix;
965   stmt_v = stmt_v->accept_mutator(&hsFix);
966 
967   stmt_v = IRSimplifier::simplify(stmt_v);
968   set_stmt(stmt_v);
969 
970   stmt_v->accept(printer_.get());
971   os() << '\n';
972   os() << "}";
973 
974   // Check that all block extents had been set.
975   const std::vector<ExprPtr>& gpu_block_extents =
976       metavar_rewriter_->gpu_block_extents();
977   for (size_t i = 0; i < gpu_block_extents.size(); i++) {
978     if (!gpu_block_extents[i]) {
979       throw std::runtime_error("Missing gpu_block_index: " + std::to_string(i));
980     }
981   }
982 
983   // Precompute block and thread extents for call_with_numel().  If
984   // precomputation can't be done (block/thread extents aren't
985   // constant), then disallow call_with_numel.
986   auto block_extents = metavar_rewriter_->gpu_block_extents();
987   auto thread_extents = metavar_rewriter_->gpu_thread_extents();
988   bool canCallWithNumel =
989       !has_random_ && !block_extents.empty() && !thread_extents.empty();
990   for (size_t i = 1; i < block_extents.size() && canCallWithNumel; i++) {
991     canCallWithNumel = canCallWithNumel && block_extents[i]->isConstant() &&
992         immediateAs<int>(block_extents[i]) == 1;
993   }
994   for (size_t i = 1; i < thread_extents.size() && canCallWithNumel; i++) {
995     canCallWithNumel = canCallWithNumel && thread_extents[i]->isConstant() &&
996         immediateAs<int>(thread_extents[i]) == 1;
997   }
998   if (canCallWithNumel && thread_extents[0]->isConstant()) {
999     // We assume block_extents[0] is output.numel()/thread_block_size_.
1000     thread_block_size_ = immediateAs<int>(thread_extents[0]);
1001   } else {
1002     // Disable call_with_numel.
1003     thread_block_size_ = -1;
1004   }
1005 
1006   // Build an LLVM based eval expression for the extents
1007   block_extents_eval_.reserve(block_extents.size());
1008   std::vector<BufferArg> extents_buffer_args;
1009 
1010   // We need to extract the args that are used in the thread and block extents
1011   // from bufferArgs and only use those for the `ExprEval` below. Without this,
1012   // bufferArgs might contain arbitrary types that are not handled by LLVM and
1013   // hence would result in an error.
1014   std::unordered_set<VarPtr> vars_in_extents;
1015   for (const auto& be : block_extents) {
1016     auto v = VarFinder::find(be);
1017     vars_in_extents.insert(v.begin(), v.end());
1018   }
1019   for (const auto& te : thread_extents) {
1020     auto v = VarFinder::find(te);
1021     vars_in_extents.insert(v.begin(), v.end());
1022   }
1023   for (const size_t i : c10::irange(buffer_args.size())) {
1024     if (vars_in_extents.count(buffer_args[i].var())) {
1025       extents_buffer_args.push_back(buffer_args[i]);
1026       arg_pos_in_extents_.push_back(true);
1027     } else {
1028       arg_pos_in_extents_.push_back(false);
1029     }
1030   }
1031   for (const auto& be : block_extents) {
1032 #ifdef TORCH_ENABLE_LLVM
1033     block_extents_eval_.emplace_back(
1034         ExprEval<LLVMCodeGen>(ExprHandle(be), extents_buffer_args));
1035 #else
1036     block_extents_eval_.emplace_back(ExprHandle(be), extents_buffer_args);
1037 #endif
1038   }
1039   thread_extents_eval_.reserve(thread_extents.size());
1040   for (const auto& te : thread_extents) {
1041 #ifdef TORCH_ENABLE_LLVM
1042     thread_extents_eval_.emplace_back(
1043         ExprEval<LLVMCodeGen>(ExprHandle(te), extents_buffer_args));
1044 #else
1045     thread_extents_eval_.emplace_back(ExprHandle(te), extents_buffer_args);
1046 #endif
1047   }
1048 
1049   GRAPH_DEBUG(
1050       "Fused TE CUDA kernel:\n",
1051       oss_.str(),
1052       "\n",
1053       "gpu_block_extents: (",
1054       metavar_rewriter_->gpu_block_extents(),
1055       ")\n",
1056       "gpu_thread_extents: (",
1057       metavar_rewriter_->gpu_thread_extents(),
1058       ")");
1059 
1060   CompileToNVRTC(oss_.str(), func_name);
1061 }
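// For illustration (hypothetical arguments and kernel name): for one float
// buffer and one int scalar argument, Initialize() above emits a kernel that
// starts roughly with
//   extern "C" __global__
//   void func(float* a, int n) {
//     ...
//   }
// with the resource/support literals prepended, and then hands the whole
// string to CompileToNVRTC().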
1062 
1063 void CudaCodeGen::call_with_numel(void** args, int64_t numel) {
1064   if (C10_UNLIKELY(numel == 0)) {
1065     return;
1066   }
1067   if (C10_UNLIKELY(thread_block_size_ <= 0)) {
1068     TORCH_INTERNAL_ASSERT(
1069         thread_block_size_ >= 0,
1070         "call_with_numel() requires a precomputed thread block size");
1071   }
1072 
1073   auto const& buffer_args = this->buffer_args();
1074   size_t gpu_block_extents =
1075       (numel + thread_block_size_ - 1) / thread_block_size_;
1076   size_t gpu_thread_extents = thread_block_size_;
1077 
1078   // In CUDA we need to pass pointers to pointers for buffers, thus we need to
1079   // go over args and add an extra indirection for such non-scalar
1080   // arguments.
1081   // Why? See some details here:
1082   // https://stackoverflow.com/questions/34388712/cannot-understand-how-jcuda-culaunchkernel-work
1083   std::vector<void*> ptr_to_args(buffer_args.size());
1084   for (size_t i = 0; i < buffer_args.size(); i++) {
1085     ptr_to_args[i] =
1086         buffer_args[i].isVar() ? args[i] : const_cast<void**>(&args[i]);
1087   }
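  // For illustration (hypothetical args): for a scalar argument, args[i]
  // already points at the value and is passed through unchanged; for a buffer
  // argument, args[i] is the device pointer itself, so its address &args[i]
  // is what gets handed to cuLaunchKernel below.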
1088 
1089   const auto device = this->device().index();
1090   const auto prior_device = at::cuda::current_device();
1091   if (prior_device != device) {
1092     at::cuda::set_device(device);
1093   }
1094 
1095   auto stream = at::cuda::getCurrentCUDAStream();
1096   at::cuda::jit::initializeCudaContext();
1097   AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel(
1098       function_,
1099       gpu_block_extents,
1100       1,
1101       1,
1102       gpu_thread_extents,
1103       1,
1104       1,
1105       0,
1106       stream,
1107       ptr_to_args.data(),
1108       nullptr));
1109 
1110   if (prior_device != device) {
1111     at::cuda::set_device(prior_device);
1112   }
1113 }
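// For illustration (hypothetical numbers): with thread_block_size_ == 512 and
// numel == 1,000,000, the launch above uses a grid of
// (1,000,000 + 511) / 512 == 1954 blocks of 512 threads, with all other
// grid/block dimensions fixed at 1.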
1114 
1115 void CudaCodeGen::call_raw(const std::vector<void*>& raw_args) {
1116   auto const& buffer_args = this->buffer_args();
1117 
1118   // TODO: move as much of this into the constructors.
1119   const std::vector<ExprPtr>& gpu_block_extents =
1120       metavar_rewriter_->gpu_block_extents();
1121   const std::vector<ExprPtr>& gpu_thread_extents =
1122       metavar_rewriter_->gpu_thread_extents();
1123   if (gpu_block_extents.size() > 3 || gpu_thread_extents.size() > 3) {
1124     throw malformed_input(
1125         "cuda_codegen: block or thread extent greater than 3D");
1126   }
1127 
1128   std::vector<int64_t> gpu_block_extents_v(3, 1);
1129   std::vector<int64_t> gpu_thread_extents_v(3, 1);
1130 
1131   // evaluate all the block/thread extents into values
1132   // TODO: eventually, codegen these calculations and make them part of the
1133   // module.
1134   std::vector<void*> extent_args;
1135   size_t raw_args_size = raw_args.size();
1136   extent_args.reserve(raw_args_size);
1137   for (size_t i = 0; i < raw_args_size; ++i) {
1138     if (arg_pos_in_extents_[i]) {
1139       extent_args.push_back(raw_args[i]);
1140     }
1141   }
1142   for (size_t i = 0; i < gpu_block_extents.size(); i++) {
1143     if (gpu_block_extents[i]->isConstant()) {
1144       gpu_block_extents_v[i] = immediateAs<int64_t>(gpu_block_extents[i]);
1145       continue;
1146     }
1147     {
1148       // invocation of block_extents_eval_ isn't thread safe and this function
1149       // may be invoked by multiple threads
1150       std::lock_guard<std::mutex> guard(eval_lock_);
1151       gpu_block_extents_v[i] =
1152           block_extents_eval_[i].value<int64_t>(extent_args);
1153     }
1154   }
1155   for (size_t i = 0; i < gpu_thread_extents.size(); i++) {
1156     if (gpu_thread_extents[i]->isConstant()) {
1157       gpu_thread_extents_v[i] = immediateAs<int64_t>(gpu_thread_extents[i]);
1158       continue;
1159     }
1160     {
1161       std::lock_guard<std::mutex> guard(eval_lock_);
1162       gpu_thread_extents_v[i] =
1163           thread_extents_eval_[i].value<int64_t>(extent_args);
1164     }
1165   }
1166 
1167   // Skip launching the kernel if there are no elements to process.
1168   for (auto extent : gpu_block_extents_v) {
1169     if (extent == 0) {
1170       return;
1171     }
1172   }
1173 
1174   auto ptr_count = buffer_args.size();
1175   // If the kernel has a rand call in it, add two extra arguments for random
1176   // seed and offset.
1177   if (has_random_) {
1178     ptr_count += 2;
1179   }
1180   std::vector<void*> ptr_to_args(ptr_count);
1181 
1182   // In CUDA we need to pass pointers to pointers for buffers, thus we need to
1183   // go over raw_args and add an extra indirection for such non-scalar
1184   // arguments.
1185   // Why? See some details here:
1186   // https://stackoverflow.com/questions/34388712/cannot-understand-how-jcuda-culaunchkernel-work
1187   for (size_t i = 0; i < buffer_args.size(); i++) {
1188     ptr_to_args[i] =
1189         buffer_args[i].isVar() ? raw_args[i] : const_cast<void**>(&raw_args[i]);
1190   }
1191 
1192   if (has_random_) {
1193     uint64_t rand_seed = uint64_t(-1);
1194     uint64_t rand_offset = uint64_t(-1);
1195     auto gen = at::cuda::detail::getDefaultCUDAGenerator();
1196     // TODO: total hack. Switch to numel when it is available.
1197     int64_t total_elements_per_thread = (1LL << 28);
1198     {
1199       std::lock_guard<std::mutex> lock(gen.mutex());
1200       auto philox_engine_inputs =
1201           at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(
1202               total_elements_per_thread);
1203       rand_seed = philox_engine_inputs.first;
1204       rand_offset = philox_engine_inputs.second;
1205     }
1206     ptr_to_args[buffer_args.size()] = &rand_seed;
1207     ptr_to_args[buffer_args.size() + 1] = &rand_offset;
1208   }
1209 
1210   auto prior_device = at::cuda::current_device();
1211   if (prior_device != this->device().index()) {
1212     at::cuda::set_device(this->device().index());
1213   }
1214   // Launch the kernels
1215   auto stream = at::cuda::getCurrentCUDAStream();
1216   at::cuda::jit::initializeCudaContext();
1217   AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel(
1218       function_,
1219       gpu_block_extents_v[0],
1220       gpu_block_extents_v[1],
1221       gpu_block_extents_v[2],
1222       gpu_thread_extents_v[0],
1223       gpu_thread_extents_v[1],
1224       gpu_thread_extents_v[2],
1225       0,
1226       stream,
1227       ptr_to_args.data(),
1228       nullptr));
1229 
1230   if (prior_device != this->device().index()) {
1231     at::cuda::set_device(prior_device);
1232   }
1233 }
1234 
1235 void CudaCodeGen::call(const std::vector<CallArg>& args) {
1236   if (args.size() != buffer_args().size()) {
1237     throw malformed_input("cuda_codegen: wrong number of args in call");
1238   }
1239 
1240   auto const& buffer_args = this->buffer_args();
1241   std::vector<void*> raw_args(buffer_args.size());
1242   for (size_t i = 0; i < buffer_args.size(); i++) {
1243     auto const& bufferArg = buffer_args[i];
1244     auto const& callArg = args[i];
1245     raw_args[i] = argToPtr(bufferArg, callArg);
1246   }
1247   call_raw(raw_args);
1248 }
1249 
1250 at::Tensor CudaCodeGen::empty_strided(
1251     c10::IntArrayRef size,
1252     c10::IntArrayRef stride,
1253     std::optional<c10::ScalarType> dtype_opt,
1254     std::optional<c10::Layout> layout_opt,
1255     std::optional<c10::Device> device_opt,
1256     std::optional<bool> pin_memory_opt) {
1257   c10::DeviceGuard device_guard(device_opt.value());
1258   return at::native::empty_strided_cuda(
1259       size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
1260 }
1261 
1262 void CudaCodeGen::CompileToNVRTC(
1263     const std::string& code,
1264     const std::string& func_name) {
1265   at::cuda::jit::initializeCudaContext();
1266   // Note: hacked at::DeviceGuard since at::DeviceGuard was failing to work
1267   // properly in some scenarios.
1268   auto prior_device = at::cuda::current_device();
1269   if (prior_device != this->device().index()) {
1270     at::cuda::set_device(this->device().index());
1271   }
1272   // Acquires device and NVRTC properties (for compile arch and occupancy
1273   // calculations)
1274   cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
1275   int major = 0, minor = 0;
1276   bool compile_to_sass = false;
1277   fuser::cuda::codegenOutputQuery(prop, major, minor, compile_to_sass);
1278 
1279   // Creates the NVRTC program
1280   nvrtcProgram program{nullptr};
1281   AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcCreateProgram(
1282       &program, code.c_str(), nullptr, 0, nullptr, nullptr));
1283 
1284 #if defined(USE_ROCM)
1285   std::vector<const char*> args = {"--std=c++17"};
1286   args.push_back("-hip-pch");
1287 #else
1288   const std::string compute = std::string("--gpu-architecture=") +
1289 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
1290       // CUDA 11.1 allows going directly to SASS (sm_) instead of PTX (compute_),
1291       // which gives better backwards compatibility when running on an older
1292       // driver (since an older driver doesn't necessarily recognize PTX emitted
1293       // by a newer toolkit).
1294       // Meanwhile, for forward compatibility (a future device with
1295       // `compile_to_sass==false`), since SASS is not necessarily compatible,
1296       // we fall back to PTX instead.
1297       (compile_to_sass ? "sm_" : "compute_") +
1298 #else
1299       "compute_" +
1300 #endif
1301       std::to_string(major) + std::to_string(minor);
1302   const std::vector<const char*> args = {
1303       "--std=c++17", compute.c_str(), "-default-device"};
1304 #endif
1305 
1306   auto result = nvrtc().nvrtcCompileProgram(
1307       program, static_cast<int>(args.size()), args.data());
1308   if (result != NVRTC_SUCCESS) {
1309     size_t logsize = 0;
1310     AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetProgramLogSize(program, &logsize));
1311     std::vector<char> log(logsize);
1312     AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetProgramLog(program, log.data()));
1313     std::stringstream cu;
1314     cu << log.data() << '\n';
1315     cu << "nvrtc compilation failed: " << '\n';
1316     cu << code << '\n';
1317     throw std::runtime_error(cu.str());
1318   }
1319   ResourceGuard holdProgram(
1320       [&] { AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcDestroyProgram(&program)); });
1321   AT_CUDA_NVRTC_CHECK(result);
1322   size_t ptx_size = 0;
1323   std::vector<char> ptx;
1324 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
1325   // compile_to_sass determines whether we are generating SASS or PTX, hence
1326   // the different API.
1327   auto getSize = compile_to_sass
1328       ? at::globalContext().getNVRTC().nvrtcGetCUBINSize
1329       : at::globalContext().getNVRTC().nvrtcGetPTXSize;
1330   auto getFunc = compile_to_sass ? at::globalContext().getNVRTC().nvrtcGetCUBIN
1331                                  : at::globalContext().getNVRTC().nvrtcGetPTX;
1332 #else
1333   auto getSize = at::globalContext().getNVRTC().nvrtcGetPTXSize;
1334   auto getFunc = at::globalContext().getNVRTC().nvrtcGetPTX;
1335 #endif
1336   AT_CUDA_NVRTC_CHECK(getSize(program, &ptx_size));
1337   ptx.resize(ptx_size);
1338   AT_CUDA_NVRTC_CHECK(getFunc(program, ptx.data()));
1339 
1340   CUmodule module{nullptr};
1341   AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoadData(&module, ptx.data()));
1342   AT_CUDA_DRIVER_CHECK(
1343       nvrtc().cuModuleGetFunction(&function_, module, func_name.c_str()));
1344 
1345   if (prior_device != this->device().index()) {
1346     at::cuda::set_device(prior_device);
1347   }
1348 }
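// For illustration (hypothetical device): on an sm_80 GPU with a recent CUDA
// toolkit, the program above is compiled with roughly
//   --std=c++17 --gpu-architecture=sm_80 -default-device
// (or "--gpu-architecture=compute_80" when compile_to_sass is false), and the
// resulting SASS or PTX is loaded with cuModuleLoadData before the kernel is
// looked up by name.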
1349 
1350 CudaCodeGen::~CudaCodeGen() = default;
1351 
1352 RegisterCodeGen<CudaCodeGen> cuda_codegen_reg("cuda_codegen");
1353 
1354 } // namespace torch::jit::tensorexpr
1355