#include <torch/csrc/jit/tensorexpr/cuda_codegen.h>
#include <torch/csrc/jit/tensorexpr/half_support.h>

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/native/cuda/jit_utils.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h>
#include <torch/csrc/jit/codegen/fuser/cuda/resource_strings.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/cuda_random.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/exceptions.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/registerizer.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty_strided_native.h>
#endif

#include <unordered_map>
#include <utility>

namespace torch::jit::tensorexpr {

// A RAII wrapper to manage a variable and name pair in the look-up table.
// TODO: move this to a more shared place.
class ScopedVarName {
 public:
  ScopedVarName(VarNameMap* mapping, const VarPtr& var, const std::string& name)
      : mapping_(mapping), var_(var) {
    auto iter = mapping->find(var);
    if (iter != mapping->end()) {
      throw std::runtime_error("Duplicate var entry: " + var->name_hint());
    }
    mapping->insert(std::make_pair(var, name));
  }

  ScopedVarName(
      UniqueNameManager* manager,
      const VarPtr& var,
      const std::string& name)
      : ScopedVarName(&manager->unique_name_mapping_, var, name) {}

  ScopedVarName(const ScopedVarName&) = delete;
  ScopedVarName& operator=(const ScopedVarName&) = delete;

  ~ScopedVarName() noexcept(false) {
    mapping_->erase(var_);
  }

 private:
  VarNameMap* mapping_ = nullptr;
  VarPtr var_ = nullptr;
};
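// Intended use, sketched with hypothetical names (`names`, `v`, "t0"): the
// mapping entry lives exactly as long as the guard.
//
//   VarNameMap names;
//   {
//     ScopedVarName guard(&names, v, "t0");  // registers v -> "t0", throws on duplicates
//     ...                                    // "t0" is reserved here
//   }                                        // entry erased when the guard goes out of scope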

static bool is_zero(const ExprPtr& expr) {
  auto v = intValue(expr);
  return v && *v == 0;
}

static const at::cuda::NVRTC& nvrtc() {
  return at::globalContext().getNVRTC();
}

std::string CudaPrinter::dtypeToCppString(const Dtype& dtype) {
  switch (dtype.scalar_type()) {
    case ScalarType::Bool:
      return "bool";
    case ScalarType::Half:
      return "half";
    case ScalarType::BFloat16:
      return fuser::cuda::bfloat16_type_string;
    case ScalarType::Char:
      return "char";
    case ScalarType::Byte:
      return "unsigned char";
    case ScalarType::Short:
      return "short";
    case ScalarType::Long:
      return "long long";
    default:
      return dtype.ToCppString();
  }
}

void CudaAnalysis::visit(const FreePtr& v) {
  if (thread_local_bufs_.count(v->buffer_var()) == 0 &&
      cross_block_bufs_.count(v->buffer_var()) == 0) {
    throw std::runtime_error("Global free not supported yet");
  }
}

void CudaAnalysis::visit(const AllocatePtr& v) {
  StmtPtr p = v->get_parent();
  while (p) {
    ForPtr for_v = to<For>(p);
    if (for_v) {
      if (for_v->loop_options().is_gpu_block_index()) {
        // TODO: This isn't right if there's a thread index at a higher level
        // than this.
        cross_block_bufs_.insert(v->buffer_var());
        return;
      } else if (for_v->loop_options().is_gpu_thread_index()) {
        thread_local_bufs_.insert(v->buffer_var());
        return;
      }
    }
    p = p->get_parent();
  }
  throw std::runtime_error("Global alloc not supported yet");
}

void CudaAnalysis::visit(const PlacementAllocatePtr& v) {
  throw std::runtime_error("Memory reuse not supported yet");
}

void CudaAnalysis::visit(const ForPtr& v) {
  // Recurse first.
  v->body()->accept(this);

  const LoopOptions& loop_options = v->loop_options();
  if (loop_options.is_gpu_block_index()) {
    int gpu_block_index = loop_options.gpu_block_index();
    if (gpu_block_index >= 3) {
      throw std::runtime_error("support only 3D gpu_block_index");
    }
    ExprPtr prev = nullptr;
    if (gpu_block_extents_.size() <= static_cast<size_t>(gpu_block_index)) {
      gpu_block_extents_.resize(gpu_block_index + 1);
    } else {
      prev = gpu_block_extents_[gpu_block_index];
    }
    if (!is_zero(v->start())) {
      throw std::runtime_error(
          "start must be zero for gpu_block_index: " +
          std::to_string(v->start()));
    }

    // NOLINTNEXTLINE(bugprone-branch-clone)
    if (prev == nullptr) {
      gpu_block_extents_[gpu_block_index] = v->stop();
    } else if (prev->isConstant() && immediateEquals(prev, 1)) {
      // Extents must be positive, so if the current extent is 1 then even if
      // the stop is symbolic it's the max.
      gpu_block_extents_[gpu_block_index] = v->stop();
    } else {
      gpu_block_extents_[gpu_block_index] =
          IRSimplifier::simplify(alloc<Max>(prev, v->stop(), true));
    }
  } else if (loop_options.is_gpu_thread_index()) {
    int gpu_thread_index = loop_options.gpu_thread_index();
    if (gpu_thread_index >= 3) {
      throw std::runtime_error("support only 3D gpu_thread_index");
    }
    ExprPtr prev = nullptr;
    if (gpu_thread_extents_.size() <= static_cast<size_t>(gpu_thread_index)) {
      gpu_thread_extents_.resize(gpu_thread_index + 1);
    } else {
      prev = gpu_thread_extents_[gpu_thread_index];
    }
    if (!is_zero(v->start())) {
      throw std::runtime_error(
          "start must be zero for gpu_thread_index: " +
          std::to_string(v->start()));
    }

    // NOLINTNEXTLINE(bugprone-branch-clone)
    if (prev == nullptr) {
      gpu_thread_extents_[gpu_thread_index] = v->stop();
    } else if (prev->isConstant() && immediateEquals(prev, 1)) {
      // Extents must be positive, so if the current extent is 1 then even if
      // the stop is symbolic it's the max.
      gpu_thread_extents_[gpu_thread_index] = v->stop();
    } else {
      gpu_thread_extents_[gpu_thread_index] =
          IRSimplifier::simplify(alloc<Max>(prev, v->stop(), true));
    }
  }
}
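// Illustration with hypothetical loop bounds N and M: if two loops bound to
// gpu_block_index 0 have stops N and M, the recorded extent for block axis 0
// ends up as simplify(Max(N, M)); if the previously recorded extent is the
// constant 1, the new stop is taken directly since extents are at least 1.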

void CudaPrinter::print_flat_alloc(const AllocatePtr& alloc) {
  std::vector<ExprPtr> dims = alloc->dims();
  // TODO: this should be merged with the storage flattener.
  int64_t flat_size = 1;
  for (const auto& dim : dims) {
    auto dim_i = intValue(dim);
    if (dim_i) {
      flat_size *= *dim_i;
    } else {
      throw std::runtime_error("Only integer dimensions are supported for now");
    }
  }
  os() << dtypeToCppString(alloc->dtype()) << " " << (*alloc->buffer_var())
       << "[" << flat_size << "];" << '\n';
}

void CudaPrinter::visit(const AllocatePtr& v) {
  // TODO: handle dynamic shapes here.
  if (cuda_analysis_->cross_block_bufs().count(v->buffer_var()) != 0) {
    emitIndent();
    os() << "__shared__ ";
    print_flat_alloc(v);
    return;
  }

  if (cuda_analysis_->thread_local_bufs().count(v->buffer_var()) != 0) {
    emitIndent();
    print_flat_alloc(v);
    return;
  }

  throw std::runtime_error("Encountered Alloc not local to block or thread");
}

void CudaPrinter::visit(const FreePtr& v) {
  // do nothing
}

void CudaPrinter::visit(const ForPtr& v) {
  IRPrinter::visit(v);
}

void CudaPrinter::visit(const CastPtr& v) {
  std::string castFn = v->dtype().scalar_type() == ScalarType::Half
      ? "__float2half"
      : v->dtype().scalar_type() == ScalarType::BFloat16 ? "__float2bfloat16"
      : v->src_value()->dtype().scalar_type() == ScalarType::Half
      ? "__half2float"
      : v->src_value()->dtype().scalar_type() == ScalarType::BFloat16
      ? "__bfloat162float"
      : ("(" + dtypeToCppString(v->dtype()) + ")");
  os() << castFn << "(";
  v->src_value()->accept(this);
  os() << ")";
}
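// Example of the emitted text (assuming a float source expression `x`): a cast
// to Half prints as "__float2half(x)", a cast from Half prints as
// "__half2float(x)", and any other conversion falls back to a C-style cast
// such as "(int)(x)".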

void CudaPrinter::visit(const IntrinsicsPtr& v) {
  if (v->op_type() == IntrinsicsOp::kRand) {
    os() << "Uint32ToFloat(" << *rand_func_ << "())";
    return;
  }

  std::string func_name = v->func_name();

  // Get the type of the resulting expression.
  ScalarType returnType = v->param(0)->dtype().scalar_type();
  for (size_t i = 1; i < v->nparams(); ++i) {
    returnType = promoteTypes(returnType, v->param(i)->dtype().scalar_type());
  }

  if (returnType == ScalarType::Half || returnType == ScalarType::Float) {
    func_name = func_name + "f";
  }
  if (v->op_type() == IntrinsicsOp::kAbs &&
      !c10::isIntegralType(returnType, true)) {
    // Since kAbs's func_name is `abs`, prefix `f` for floating point.
    func_name = "f" + func_name;
  }
  if (v->op_type() == IntrinsicsOp::kIsNan) {
    func_name = "isnan";
  }

  os() << func_name << "(";
  for (const auto i : c10::irange(v->nparams())) {
    if (i > 0) {
      os() << ", ";
    }
    os() << *v->param(i);
  }
  os() << ")";
}

void CudaPrinter::visit(const ExternalCallPtr& v) {
  throw unimplemented_lowering(v);
}

void CudaPrinter::visit(const LoadPtr& v) {
  // TODO: find a better metric for using ldg or not. Support different dtypes.
  // Detects whether the load target is also a store target.
  // TODO: this is currently too wide. It detects whether a store-target
  // exists within the program. In fact, this check is only necessary within a
  // kernel.
  if (v->indices().empty()) {
    os() << *v->base_handle();
    return;
  }
  if (v->dtype().scalar_type() == ScalarType::Bool ||
      v->dtype().scalar_type() == ScalarType::Half ||
      v->dtype().scalar_type() == ScalarType::BFloat16) {
    // There's no __ldg overload for bool or half.
    os() << *v->base_handle() << "[" << *v->flat_index() << "]";
    return;
  }
  if (cuda_analysis_->is_buf_store_target(v->buf())) {
    // CUDA __ldg can only be applied to read-only buffers.
    os() << *v->base_handle() << "[" << *v->flat_index() << "]";
    return;
  }
  os() << "__ldg(" << *v->base_handle() << " + " << *v->flat_index() << ")";
}
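// For illustration (hypothetical buffer `A` and flat index `i`): a read-only
// float load prints as "__ldg(A + i)", while a load from a buffer that is also
// written, or a bool/half/bfloat16 load, prints as plain "A[i]".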

// TODO: maybe this should be a more shared location?
// TODO: investigate how "ExprPtr" can be implicitly converted to "ExprHandle"
// as a bool.
static bool CheckEqual(const ExprPtr& lhs, const ExprPtr& rhs) {
  // The fast path. Checks if the pointers are the same.
  if (lhs == rhs) {
    return true;
  }
  ExprHandle diff = Sub::make(ExprHandle(lhs), ExprHandle(rhs));
  ExprHandle diff_s = IRSimplifier::simplify(diff);
  return immediateEquals(diff_s.node(), 0);
}
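// A small worked example with hypothetical expressions: for lhs = (x + 1) and
// rhs = (1 + x), the simplifier is expected to reduce the difference to the
// constant 0, so CheckEqual returns true even though the pointers differ.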

class AtomicAddFuser : public IRMutator {
 public:
  AtomicAddFuser(
      const std::unordered_set<VarPtr>& thread_local_bufs,
      const GPUMetaVarRewriter& metavars)
      : thread_local_bufs_(thread_local_bufs) {
    const std::vector<ExprPtr>& block_extents = metavars.gpu_block_extents();
    const std::vector<VarPtr>& block_vars = metavars.gpu_block_vars();
    for (size_t i = 0; i < block_extents.size(); ++i) {
      MetaVarExtent extent{block_extents[i], false};
      if (extent.expr->isConstant() && immediateEquals(extent.expr, 1)) {
        extent.trivial = true;
      } else {
        nontrivial_metavars_.insert(block_vars[i]);
      }
      metavars_[block_vars[i]] = extent;
    }

    const std::vector<ExprPtr>& thread_extents = metavars.gpu_thread_extents();
    const std::vector<VarPtr>& thread_vars = metavars.gpu_thread_vars();
    for (size_t i = 0; i < thread_extents.size(); ++i) {
      MetaVarExtent extent{thread_extents[i], false};
      if (extent.expr->isConstant() && immediateEquals(extent.expr, 1)) {
        extent.trivial = true;
      } else {
        nontrivial_metavars_.insert(thread_vars[i]);
      }
      metavars_[thread_vars[i]] = extent;
    }
  }

  StmtPtr mutate(const StorePtr& v) override {
    BufPtr buf = v->buf();

    // Thread locals never need to be atomic.
    if (thread_local_bufs_.count(buf->base_handle()) != 0) {
      return v;
    }

    ScalarType dtype = v->value()->dtype().scalar_type();
    if (dtype != ScalarType::Float && dtype != ScalarType::Double) {
      return v;
    }
    AddPtr add_v = to<Add>(v->value());
    if (!add_v) {
      return v;
    }
    LoadPtr load_v = to<Load>(add_v->lhs());
    if (!load_v) {
      return v;
    }
    if (v->base_handle() != load_v->base_handle()) {
      return v;
    }
    if (v->indices().empty() && load_v->indices().empty()) {
      return v;
    }
    bool index_equal = CheckEqual(v->flat_index(), load_v->flat_index());
    if (!index_equal) {
      return v;
    }

    // TODO: this checks that the metavars occur directly as an index, but this
    // is pessimistic; blockIdx.x + 1 is fine too if there is no overlapping.
    std::unordered_set<VarPtr> vars_to_find = nontrivial_metavars_;
    for (const ExprPtr& e : v->indices()) {
      if (VarPtr v = to<Var>(e)) {
        vars_to_find.erase(v);
      }
    }

    if (vars_to_find.empty()) {
      // All metavars accounted for.
      return v;
    }

    return alloc<AtomicAdd>(buf, v->indices(), add_v->rhs());
  }

 private:
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const std::unordered_set<VarPtr>& thread_local_bufs_;
  struct MetaVarExtent {
    ExprPtr expr{nullptr};
    bool trivial{false};
  };
  std::unordered_map<VarPtr, MetaVarExtent> metavars_;
  std::unordered_set<VarPtr> nontrivial_metavars_;
};
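// Sketch of the rewrite this pass performs, on a hypothetical buffer `out`
// whose indices do not cover every nontrivial metavar (so distinct threads
// may hit the same element):
//
//   out[i] = out[i] + val;     // plain Store of Load + val
//     becomes an AtomicAdd node, which CudaPrinter later emits as
//   atomicAdd(&out[i], val);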

void CudaPrinter::visit(const StorePtr& v) {
  emitIndent();
  if (v->indices().empty()) {
    os() << *v->base_handle() << " = ";
  } else {
    os() << *v->base_handle() << "[" << *v->flat_index() << "] = ";
  }
  os() << *v->value() << ";";
  os() << '\n';
}

void CudaPrinter::visit(const AtomicAddPtr& v) {
  emitIndent();
  if (cuda_analysis_->thread_local_bufs().count(v->base_handle()) > 0) {
    // atomicAdd only works on global and shared memory
    os() << *v->base_handle() << "[" << *v->flat_index()
         << "] += " << *v->value() << ";";
  } else {
    os() << "atomicAdd(&" << *v->base_handle() << "[" << *v->flat_index() << "]"
         << ", " << *v->value() << ");";
  }
  os() << '\n';
}

void CudaPrinter::visit(const MaxPtr& v) {
  if (v->dtype().is_integral()) {
    os() << "max(";
  } else {
    os() << "maximum(";
  }
  v->lhs()->accept(this);
  os() << ",";
  v->rhs()->accept(this);
  os() << ")";
}

void CudaPrinter::visit(const MinPtr& v) {
  if (v->dtype().is_integral()) {
    os() << "min(";
  } else {
    os() << "minimum(";
  }
  v->lhs()->accept(this);
  os() << ",";
  v->rhs()->accept(this);
  os() << ")";
}

void CudaPrinter::visit(const IfThenElsePtr& v) {
  os() << "((";
  v->condition()->accept(this);
  os() << ") ? ";
  v->true_value()->accept(this);
  os() << " : ";
  v->false_value()->accept(this);
  os() << ")";
}

void CudaPrinter::visit(const BlockPtr& v) {
  os() << "{" << '\n';
  indent_++;

  for (const StmtPtr& s : v->stmts()) {
    s->accept(this);
  }

  indent_--;
  emitIndent();
  os() << "}";
}

void CudaPrinter::visit(const LetPtr& v) {
  emitIndent();
  os() << dtypeToCppString(v->var()->dtype());
  os() << " " << *v->var() << " = ";
  v->value()->accept(this);
  os() << ";" << '\n';
}

class PrioritizeLoad : public IRMutator {
 public:
  ExprPtr mutate(const LoadPtr& v) override {
    // Look at the declaration of this variable for more details.
    if (nested_if_then_else_ > 0) {
      return IRMutator::mutate(v);
    }
    if (nested_let_) {
      return IRMutator::mutate(v);
    }
    if (thread_local_bufs_.count(v->base_handle()) > 0) {
      return IRMutator::mutate(v);
    }
    if (v->indices().empty()) {
      return IRMutator::mutate(v);
    }
    if (nested_store_) {
      if (v->base_handle() == nested_store_->buf()->base_handle() &&
          v->indices().size() == nested_store_->indices().size()) {
        // also check indices
        bool same = true;
        for (const auto i : c10::irange(v->indices().size())) {
          if (!exprEquals(v->indices()[i], nested_store_->indices()[i])) {
            same = false;
            break;
          }
        }
        if (same) {
          return IRMutator::mutate(v);
        }
      } else if (nested_store_->indices().empty()) {
        return IRMutator::mutate(v);
      }
    }

    MemLoadList& load_list = load_stack_.back();
    VarPtr load_new_var = alloc<Var>("v", v->dtype());
    ExprPtr new_value = IRMutator::mutate(v);
    load_list.emplace_back(load_new_var, new_value);

    return load_new_var;
  }

  ExprPtr mutate(const CastPtr& v) override {
    LoadPtr src_load = to<Load>(v->src_value());
    ExprPtr new_src = v->src_value()->accept_mutator(this);
    VarPtr new_var = to<Var>(new_src);
    if (!src_load || !new_var) {
      return alloc<Cast>(v->dtype(), new_src);
    }

    // We just did the prioritize load, let's fold in the Cast.
    MemLoadList& load_list = load_stack_.back();
    assert(!load_list.empty());
    auto pair = load_list.back();
    assert(pair.first == new_var);
    load_list.pop_back();

    new_var = alloc<Var>("v", v->dtype());
    ExprPtr new_value = alloc<Cast>(v->dtype(), pair.second);
    load_list.emplace_back(new_var, new_value);
    return new_var;
  }

  StmtPtr mutate(const StorePtr& v) override {
    StorePtr last = nested_store_;
    nested_store_ = v;
    StmtPtr s = IRMutator::mutate(v);
    nested_store_ = last;
    return s;
  }

  StmtPtr mutate(const LetPtr& v) override {
    nested_let_ = true;
    StmtPtr s = IRMutator::mutate(v);
    nested_let_ = false;
    return s;
  }

  StmtPtr mutate(const BlockPtr& v) override {
    std::list<StmtPtr> stmts = v->stmts();
    for (const StmtPtr& stmt : stmts) {
      PushList();
      StmtPtr stmt_new = stmt->accept_mutator(this);

      AddMemLoadsFromList(v, stmt);
      PopList();

      if (stmt_new == stmt) {
        continue;
      }
      v->replace_stmt(stmt, stmt_new);
    }
    return v;
  }

  ExprPtr mutate(const IfThenElsePtr& v) override {
    nested_if_then_else_++;
    ExprPtr new_v = IRMutator::mutate(v);
    nested_if_then_else_--;
    return new_v;
  }

 private:
  using MemLoadEntry = std::pair<VarPtr, ExprPtr>;
  using MemLoadList = std::vector<MemLoadEntry>;
  using MemoryLoadStack = std::vector<MemLoadList>;

  void PushList() {
    load_stack_.emplace_back();
  }

  void PopList() {
    load_stack_.pop_back();
  }

  void AddMemLoadsFromList(const BlockPtr& block, const StmtPtr& last) {
    MemLoadList& load_list = load_stack_.back();
    if (load_list.empty()) {
      return;
    }

    for (auto& pair : load_list) {
      StmtPtr news = alloc<Let>(pair.first, pair.second);
      block->insert_stmt_before(news, last);
    }
  }

  MemoryLoadStack load_stack_;
  // TODO: For now, we are not moving the loads with the IfThenElse.
  // Eventually, we should switch to a more generic structure like:
  // int v2 = IfThenElse(cond, true_v, false_v) + 2 ->
  //
  // int v;
  // if (cond) {
  //   v = true_v;
  // } else {
  //   v = false_v;
  // }
  // int v2 = v + 2;
  int nested_if_then_else_{0};
  StorePtr nested_store_{nullptr};
  bool nested_let_{false};
  std::unordered_set<VarPtr> thread_local_bufs_;
};
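// Effect of PrioritizeLoad, sketched on a hypothetical statement: a global
// load feeding an expression is hoisted into its own Let so the memory read
// is issued before the arithmetic that consumes it, e.g.
//
//   c[i] = a[i] + 1;
//     becomes (roughly)
//   float v = a[i];
//   c[i] = v + 1;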

std::string CudaCodeGen::GetUniqueFuncName(const std::string& func_prefix) {
  int64_t counter = 0;
  std::string name = func_prefix;
  while (taken_func_names.count(name)) {
    name = func_prefix + "_" + std::to_string(counter++);
  }

  taken_func_names.insert(name);
  return name;
}

bool GPUMetaVarRewriter::isFullExtent() {
  {
    auto& extents = cuda_analysis_->gpu_block_extents();
    for (int i = 0; i < 3; ++i) {
      if (!exprEquals(current_block_reach_[i], extents[i])) {
        return false;
      }
    }
  }

  {
    auto& extents = cuda_analysis_->gpu_thread_extents();
    for (int i = 0; i < 3; ++i) {
      if (!exprEquals(current_thread_reach_[i], extents[i])) {
        return false;
      }
    }
  }

  return true;
}

StmtPtr GPUMetaVarRewriter::mutate(const ForPtr& v) {
  StmtPtr body = v->body();
  ExprPtr old_reach = nullptr;
  const LoopOptions& loop_options = v->loop_options();
  if (loop_options.is_gpu_block_index()) {
    int gpu_block_index = loop_options.gpu_block_index();
    if (gpu_block_index >= 3) {
      throw std::runtime_error("support only 3D gpu_block_index");
    }
    old_reach = current_block_reach_[gpu_block_index];

    // Extents must be positive, assume >= 1.
    if (old_reach->isConstant() && immediateEquals(old_reach, 1)) {
      current_block_reach_[gpu_block_index] = v->stop();
    } else {
      current_block_reach_[gpu_block_index] =
          IRSimplifier::simplify(alloc<Max>(old_reach, v->stop(), true));
    }

    VarPtr metaVar = gpu_block_vars_[gpu_block_index];
    body = Substitute(Stmt::clone(body), {{v->var(), metaVar}});
  } else if (loop_options.is_gpu_thread_index()) {
    int gpu_thread_index = loop_options.gpu_thread_index();
    if (gpu_thread_index >= 3) {
      throw std::runtime_error("support only 3D gpu_thread_index");
    }
    old_reach = current_thread_reach_[gpu_thread_index];

    // Extents must be positive, assume >= 1.
    if (old_reach->isConstant() && immediateEquals(old_reach, 1)) {
      current_thread_reach_[gpu_thread_index] = v->stop();
    } else {
      current_thread_reach_[gpu_thread_index] =
          IRSimplifier::simplify(alloc<Max>(old_reach, v->stop(), true));
    }

    VarPtr metaVar = gpu_thread_vars_[gpu_thread_index];
    body = Substitute(Stmt::clone(body), {{v->var(), metaVar}});
  }

  // Recurse into body block.
  body = Stmt::clone(body->accept_mutator(this));

  // Pop the internal reach off the stack.
  if (loop_options.is_gpu_block_index()) {
    current_block_reach_[loop_options.gpu_block_index()] = old_reach;
    return body;
  } else if (loop_options.is_gpu_thread_index()) {
    current_thread_reach_[loop_options.gpu_thread_index()] = old_reach;
    return body;
  }

  return v->cloneWithNewBody(body);
}
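// Rough before/after for this rewrite, on a hypothetical loop bound to axis 0
// of the block grid: the loop itself is dropped and its index variable is
// replaced by the corresponding metavar, which prints as blockIdx.x:
//
//   for (int i = 0; i < N; i++) { body(i); }   // loop_options: gpu_block_index 0
//     becomes
//   body(blockIdx.x);                          // guarded by masks added in mutate(Block)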

StmtPtr GPUMetaVarRewriter::mutate(const BlockPtr& v) {
  std::vector<Segment> innerSegments;
  Segment current;

  auto pushAndReset = [&](bool mask) {
    if (!current.empty()) {
      innerSegments.push_back(current);
    }
    current.reset(mask);
  };

  // Here we're slicing the Block's contents into segments that should have the
  // same launch reach. Segments are comprised of all statements that aren't
  // loops - which are their own segments. Some operations, such as threading
  // and memory ops, should never be masked and so also get their own segment.
  for (const StmtPtr& stmt : *v) {
    StmtPtr stmt_new = stmt->accept_mutator(this);
    if (stmt == stmt_new) {
      stmt_new = Stmt::clone(stmt_new);
    }

    // Likewise, Allocate and Free should never be masked.
    if (to<Allocate>(stmt) || to<Free>(stmt)) {
      pushAndReset(false);
    }

    // If the current stmt *was* a loop, it's a segment boundary.
    if (ForPtr f = to<For>(stmt)) {
      pushAndReset(false);
    }

    current.stmts().push_back(stmt_new);
    // If the current segment should not be masked, it's a segment boundary on
    // the far side as well.
    if (!current.mask()) {
      pushAndReset(true);
    }
  }

  if (!current.empty()) {
    innerSegments.push_back(current);
  }

  // We are at max extent in all dimensions, so we need no masks at this level.
  if (isFullExtent()) {
    // Flatten inner segments.
    std::vector<StmtPtr> stmts;
    for (auto& v : innerSegments) {
      for (const auto& s : v.stmts()) {
        stmts.push_back(s);
      }
    }

    return alloc<Block>(stmts);
  }

  std::vector<StmtPtr> stmts;
  for (auto& segment : innerSegments) {
    bool need_sync = false;
    // We never mask loops, they'll mask their contents.
    if (!segment.mask()) {
      TORCH_INTERNAL_ASSERT(segment.stmts().size() == 1, buildErrorMessage());
      stmts.push_back(segment.stmts()[0]);
      continue;
    }

    // If we get here, we must mask since we're not full reach and our direct
    // child isn't a For.
    StmtPtr inner = alloc<Block>(segment.stmts());
    // threads inside blocks.
    auto& thread_extents = cuda_analysis_->gpu_thread_extents();
    for (size_t i = 0; i < gpu_thread_vars_.size(); ++i) {
      if (!exprEquals(current_thread_reach_[i], thread_extents[i])) {
        need_sync = true;
        // Mask it against the current dimensions.
        inner = alloc<Cond>(
            alloc<CompareSelect>(
                gpu_thread_vars_[i],
                current_thread_reach_[i],
                CompareSelectOperation::kLT),
            inner,
            nullptr);
      }
    }
    auto& block_extents = cuda_analysis_->gpu_block_extents();
    for (size_t i = 0; i < gpu_block_vars_.size(); ++i) {
      if (!exprEquals(current_block_reach_[i], block_extents[i])) {
        // Mask it against the current dimensions.
        inner = alloc<Cond>(
            alloc<CompareSelect>(
                gpu_block_vars_[i],
                current_block_reach_[i],
                CompareSelectOperation::kLT),
            inner,
            nullptr);
      }
    }

    if (need_sync) {
      stmts.push_back(alloc<SyncThreads>());
    }
    stmts.push_back(inner);
    if (need_sync) {
      stmts.push_back(alloc<SyncThreads>());
    }
  }

  return alloc<Block>(stmts);
}
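// Masking, sketched on a hypothetical segment whose thread reach T is smaller
// than the launched blockDim.x: the segment ends up wrapped roughly as
//
//   __syncthreads();
//   if (threadIdx.x < T) {
//     ... segment statements ...
//   }
//   __syncthreads();
//
// so threads beyond the segment's reach simply skip it, and the syncs keep the
// block in step around the masked region.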

static std::ostream& operator<<(
    std::ostream& out,
    const std::vector<ExprPtr>& exprs) {
  size_t i = 0;
  for (const auto& expr : exprs) {
    if (i++ > 0) {
      out << ", ";
    }
    out << *expr;
  }
  return out;
}

static const char* device_resource_string = R"(
#define NAN __int_as_float(0x7fffffff)
#define POS_INFINITY __int_as_float(0x7f800000)
#define NEG_INFINITY __int_as_float(0xff800000)

)";

static const char* shared_resource_string = R"(
template<typename T>
__device__ T maximum(T a, T b) {
  return isnan(a) ? a : (a > b ? a : b);
}

template<typename T>
__device__ T minimum(T a, T b) {
  return isnan(a) ? a : (a < b ? a : b);
}

)";

void CudaCodeGen::Initialize() {
  // TODO: handle multiple kernels.
  // TODO: handle dynamic dimension.
  // TODO: call nvrtc.
  // TODO: merge HasRand with CudaAnalysis.
  GenericIntrinsicsExpander intrinsics_expander;
  apply_mutator(&intrinsics_expander);

  HasRand has_rand_func(stmt());
  has_random_ = has_rand_func.has_rand();
  cuda_analysis_ = std::make_unique<CudaAnalysis>();
  printer_ =
      std::make_unique<CudaPrinter>(&oss_, cuda_analysis_.get(), has_random_);
  metavar_rewriter_ =
      std::make_unique<GPUMetaVarRewriter>(cuda_analysis_.get());

  // Check whether the statement uses the Half type; if so, add the
  // half_support_literal.
  StmtPtr stmt_v = stmt();
  HalfChecker halfChecker(buffer_args());
  stmt_v->accept(&halfChecker);

  os() << device_resource_string << shared_resource_string;

  if (has_random_) {
    os() << philox_random_string << '\n';
  }

  if (halfChecker.hasHalf()) {
    os() << fuser::cuda::half_support_literal << '\n';
  }
  if (halfChecker.hasBFloat16()) {
    os() << fuser::cuda::bfloat16_support_literal << '\n';
  }

  std::string func_name = GetUniqueFuncName(kernel_func_name());
  os() << "extern \"C\" __global__" << '\n';
#if defined(USE_ROCM)
  // CUDA has a default limit of threads per block (=flat work group size)
  // of 1024, but ROCm uses 256 by default. At the time of writing
  // (#45506), I am unaware of a stricter limit that TensorExpr imposes
  // (maybe for perf), so I use 1024 as the maximum flat work group size.
  // We put a minimum value of 1; this is also used by hip (ROCm 3.8) in
  // the __launch_bound__ implementation. The arguments for the attribute
  // are (min, max); for details see the documentation at
  // https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
  os() << "__attribute__((amdgpu_flat_work_group_size(1, 1024)))" << std::endl;
#endif
  os() << "void " << func_name << "(";
  const std::vector<BufferArg> buffer_args = this->buffer_args();
  for (size_t i = 0; i < buffer_args.size(); i++) {
    if (i > 0) {
      os() << ", ";
    }
    const BufferArg& buffer_arg = buffer_args[i];
    VarPtr var = buffer_arg.var();
    Dtype dtype = buffer_arg.dtype();

    os() << printer_->dtypeToCppString(dtype)
         << (buffer_arg.isVar() ? " " : "* ")
         << name_manager()->get_unique_name(var);
  }
  VarPtr rand_seed;
  VarPtr rand_offset;
  if (has_random_) {
    // TODO: switch to kUint64 when it is available.
    rand_seed = alloc<Var>("rand_seed", kInt);
    rand_offset = alloc<Var>("rand_offset", kInt);
    std::string uint64_str = "unsigned long long";
    os() << ", " << uint64_str << " " << *rand_seed << ", " << uint64_str << " "
         << *rand_offset;
  }
  os() << ") {";
  os() << '\n';

  if (has_random_) {
    VarPtr idx = alloc<Var>("idx", kInt);
    os() << "int " << *idx << " = blockIdx.x*blockDim.x + threadIdx.x;" << '\n';
    VarPtr rand_func = printer_->rand_func();
    os() << "Philox " << *rand_func << "(" << *rand_seed << ", " << *idx << ", "
         << *rand_offset << ");" << '\n';
    os() << '\n';
  }

  stmt_v->accept(cuda_analysis_.get());

  stmt_v = stmt_v->accept_mutator(metavar_rewriter_.get());

  AtomicAddFuser atomic_add_fuser(
      cuda_analysis_->thread_local_bufs(), *metavar_rewriter_);
  stmt_v = stmt_v->accept_mutator(&atomic_add_fuser);

  stmt_v = registerize(stmt_v);

  PrioritizeLoad prioritize_load;
  stmt_v = stmt_v->accept_mutator(&prioritize_load);

  // The registerizer might insert half-type scalars; we don't want this.
  HalfRewriter hsFix;
  stmt_v = stmt_v->accept_mutator(&hsFix);

  stmt_v = IRSimplifier::simplify(stmt_v);
  set_stmt(stmt_v);

  stmt_v->accept(printer_.get());
  os() << '\n';
  os() << "}";

  // Check that all block extents have been set.
  const std::vector<ExprPtr>& gpu_block_extents =
      metavar_rewriter_->gpu_block_extents();
  for (size_t i = 0; i < gpu_block_extents.size(); i++) {
    if (!gpu_block_extents[i]) {
      throw std::runtime_error("Missing gpu_block_index: " + std::to_string(i));
    }
  }

  // Precompute block and thread extents for call_with_numel(). If
  // precomputation can't be done (block/thread extents aren't
  // constant), then disallow call_with_numel.
  auto block_extents = metavar_rewriter_->gpu_block_extents();
  auto thread_extents = metavar_rewriter_->gpu_thread_extents();
  bool canCallWithNumel =
      !has_random_ && !block_extents.empty() && !thread_extents.empty();
  for (size_t i = 1; i < block_extents.size() && canCallWithNumel; i++) {
    canCallWithNumel = canCallWithNumel && block_extents[i]->isConstant() &&
        immediateAs<int>(block_extents[i]) == 1;
  }
  for (size_t i = 1; i < thread_extents.size() && canCallWithNumel; i++) {
    canCallWithNumel = canCallWithNumel && thread_extents[i]->isConstant() &&
        immediateAs<int>(thread_extents[i]) == 1;
  }
  if (canCallWithNumel && thread_extents[0]->isConstant()) {
    // We assume block_extents[0] is output.numel()/thread_block_size_.
    thread_block_size_ = immediateAs<int>(thread_extents[0]);
  } else {
    // Disable call_with_numel.
    thread_block_size_ = -1;
  }

  // Build an LLVM based eval expression for the extents.
  block_extents_eval_.reserve(block_extents.size());
  std::vector<BufferArg> extents_buffer_args;

  // We need to extract the args that are used in the thread and block extents
  // from bufferArgs and only use those for the `ExprEval` below. Without this,
  // bufferArgs might contain arbitrary types that are not handled by LLVM and
  // hence would result in an error.
  std::unordered_set<VarPtr> vars_in_extents;
  for (const auto& be : block_extents) {
    auto v = VarFinder::find(be);
    vars_in_extents.insert(v.begin(), v.end());
  }
  for (const auto& te : thread_extents) {
    auto v = VarFinder::find(te);
    vars_in_extents.insert(v.begin(), v.end());
  }
  for (const size_t i : c10::irange(buffer_args.size())) {
    if (vars_in_extents.count(buffer_args[i].var())) {
      extents_buffer_args.push_back(buffer_args[i]);
      arg_pos_in_extents_.push_back(true);
    } else {
      arg_pos_in_extents_.push_back(false);
    }
  }
  for (const auto& be : block_extents) {
#ifdef TORCH_ENABLE_LLVM
    block_extents_eval_.emplace_back(
        ExprEval<LLVMCodeGen>(ExprHandle(be), extents_buffer_args));
#else
    block_extents_eval_.emplace_back(ExprHandle(be), extents_buffer_args);
#endif
  }
  thread_extents_eval_.reserve(thread_extents.size());
  for (const auto& te : thread_extents) {
#ifdef TORCH_ENABLE_LLVM
    thread_extents_eval_.emplace_back(
        ExprEval<LLVMCodeGen>(ExprHandle(te), extents_buffer_args));
#else
    thread_extents_eval_.emplace_back(ExprHandle(te), extents_buffer_args);
#endif
  }

  GRAPH_DEBUG(
      "Fused TE CUDA kernel:\n",
      oss_.str(),
      "\n",
      "gpu_block_extents: (",
      metavar_rewriter_->gpu_block_extents(),
      ")\n",
      "gpu_thread_extents: (",
      metavar_rewriter_->gpu_thread_extents(),
      ")");

  CompileToNVRTC(oss_.str(), func_name);
}
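// Shape of the emitted source, for orientation (names illustrative; the real
// parameter list comes from buffer_args()):
//
//   extern "C" __global__
//   void func_0(float* a, float* out, int n) {
//     ... printed TensorExpr statement ...
//   }
//
// preceded by the resource strings above (NAN/INFINITY defines and the
// maximum/minimum helpers) and, when needed, the Philox and half/bfloat16
// support literals.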

void CudaCodeGen::call_with_numel(void** args, int64_t numel) {
  if (C10_UNLIKELY(numel == 0)) {
    return;
  }
  if (C10_UNLIKELY(thread_block_size_ <= 0)) {
    TORCH_INTERNAL_ASSERT(
        thread_block_size_ >= 0,
        "call_with_numel() requires a precomputed thread block size");
  }

  auto const& buffer_args = this->buffer_args();
  size_t gpu_block_extents =
      (numel + thread_block_size_ - 1) / thread_block_size_;
  size_t gpu_thread_extents = thread_block_size_;

  // In CUDA we need to pass pointers to pointers for buffers, thus we need to
  // go over args and add an extra indirection for such non-scalar
  // arguments.
  // Why? See some details here:
  // https://stackoverflow.com/questions/34388712/cannot-understand-how-jcuda-culaunchkernel-work
  std::vector<void*> ptr_to_args(buffer_args.size());
  for (size_t i = 0; i < buffer_args.size(); i++) {
    ptr_to_args[i] =
        buffer_args[i].isVar() ? args[i] : const_cast<void**>(&args[i]);
  }

  const auto device = this->device().index();
  const auto prior_device = at::cuda::current_device();
  if (prior_device != device) {
    at::cuda::set_device(device);
  }

  auto stream = at::cuda::getCurrentCUDAStream();
  at::cuda::jit::initializeCudaContext();
  AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel(
      function_,
      gpu_block_extents,
      1,
      1,
      gpu_thread_extents,
      1,
      1,
      0,
      stream,
      ptr_to_args.data(),
      nullptr));

  if (prior_device != device) {
    at::cuda::set_device(prior_device);
  }
}
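// Grid-size arithmetic used above, on hypothetical numbers: with numel = 1000
// and a precomputed thread_block_size_ of 256, the launch is
// ceil(1000 / 256) = 4 blocks of 256 threads, i.e. a (4, 1, 1) x (256, 1, 1)
// configuration.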

void CudaCodeGen::call_raw(const std::vector<void*>& raw_args) {
  auto const& buffer_args = this->buffer_args();

  // TODO: move as much of this into the constructors.
  const std::vector<ExprPtr>& gpu_block_extents =
      metavar_rewriter_->gpu_block_extents();
  const std::vector<ExprPtr>& gpu_thread_extents =
      metavar_rewriter_->gpu_thread_extents();
  if (gpu_block_extents.size() > 3 || gpu_thread_extents.size() > 3) {
    throw malformed_input(
        "cuda_codegen: block or thread extent greater than 3D");
  }

  std::vector<int64_t> gpu_block_extents_v(3, 1);
  std::vector<int64_t> gpu_thread_extents_v(3, 1);

  // Evaluate all the block/thread extents into values.
  // TODO: eventually, codegen these calculations and make them part of the
  // module.
  std::vector<void*> extent_args;
  size_t raw_args_size = raw_args.size();
  extent_args.reserve(raw_args_size);
  for (size_t i = 0; i < raw_args_size; ++i) {
    if (arg_pos_in_extents_[i]) {
      extent_args.push_back(raw_args[i]);
    }
  }
  for (size_t i = 0; i < gpu_block_extents.size(); i++) {
    if (gpu_block_extents[i]->isConstant()) {
      gpu_block_extents_v[i] = immediateAs<int64_t>(gpu_block_extents[i]);
      continue;
    }
    {
      // Invocation of block_extents_eval_ isn't thread safe and this function
      // may be invoked by multiple threads.
      std::lock_guard<std::mutex> guard(eval_lock_);
      gpu_block_extents_v[i] =
          block_extents_eval_[i].value<int64_t>(extent_args);
    }
  }
  for (size_t i = 0; i < gpu_thread_extents.size(); i++) {
    if (gpu_thread_extents[i]->isConstant()) {
      gpu_thread_extents_v[i] = immediateAs<int64_t>(gpu_thread_extents[i]);
      continue;
    }
    {
      std::lock_guard<std::mutex> guard(eval_lock_);
      gpu_thread_extents_v[i] =
          thread_extents_eval_[i].value<int64_t>(extent_args);
    }
  }

  // Skip launching the kernel if there are no elements to process.
  for (auto extent : gpu_block_extents_v) {
    if (extent == 0) {
      return;
    }
  }

  auto ptr_count = buffer_args.size();
  // If the kernel has a rand call in it, add two extra arguments for random
  // seed and offset.
  if (has_random_) {
    ptr_count += 2;
  }
  std::vector<void*> ptr_to_args(ptr_count);

  // In CUDA we need to pass pointers to pointers for buffers, thus we need to
  // go over raw_args and add an extra indirection for such non-scalar
  // arguments.
  // Why? See some details here:
  // https://stackoverflow.com/questions/34388712/cannot-understand-how-jcuda-culaunchkernel-work
  for (size_t i = 0; i < buffer_args.size(); i++) {
    ptr_to_args[i] =
        buffer_args[i].isVar() ? raw_args[i] : const_cast<void**>(&raw_args[i]);
  }

  if (has_random_) {
    uint64_t rand_seed = uint64_t(-1);
    uint64_t rand_offset = uint64_t(-1);
    auto gen = at::cuda::detail::getDefaultCUDAGenerator();
    // TODO: total hack. Switch to numel when it is available.
    int64_t total_elements_per_thread = (1LL << 28);
    {
      std::lock_guard<std::mutex> lock(gen.mutex());
      auto philox_engine_inputs =
          at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(
              total_elements_per_thread);
      rand_seed = philox_engine_inputs.first;
      rand_offset = philox_engine_inputs.second;
    }
    ptr_to_args[buffer_args.size()] = &rand_seed;
    ptr_to_args[buffer_args.size() + 1] = &rand_offset;
  }

  auto prior_device = at::cuda::current_device();
  if (prior_device != this->device().index()) {
    at::cuda::set_device(this->device().index());
  }
  // Launch the kernel.
  auto stream = at::cuda::getCurrentCUDAStream();
  at::cuda::jit::initializeCudaContext();
  AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel(
      function_,
      gpu_block_extents_v[0],
      gpu_block_extents_v[1],
      gpu_block_extents_v[2],
      gpu_thread_extents_v[0],
      gpu_thread_extents_v[1],
      gpu_thread_extents_v[2],
      0,
      stream,
      ptr_to_args.data(),
      nullptr));

  if (prior_device != this->device().index()) {
    at::cuda::set_device(prior_device);
  }
}

void CudaCodeGen::call(const std::vector<CallArg>& args) {
  if (args.size() != buffer_args().size()) {
    throw malformed_input("cuda_codegen: wrong number of args in call");
  }

  auto const& buffer_args = this->buffer_args();
  std::vector<void*> raw_args(buffer_args.size());
  for (size_t i = 0; i < buffer_args.size(); i++) {
    auto const& bufferArg = buffer_args[i];
    auto const& callArg = args[i];
    raw_args[i] = argToPtr(bufferArg, callArg);
  }
  call_raw(raw_args);
}

at::Tensor CudaCodeGen::empty_strided(
    c10::IntArrayRef size,
    c10::IntArrayRef stride,
    std::optional<c10::ScalarType> dtype_opt,
    std::optional<c10::Layout> layout_opt,
    std::optional<c10::Device> device_opt,
    std::optional<bool> pin_memory_opt) {
  c10::DeviceGuard device_guard(device_opt.value());
  return at::native::empty_strided_cuda(
      size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
}

void CudaCodeGen::CompileToNVRTC(
    const std::string& code,
    const std::string& func_name) {
  at::cuda::jit::initializeCudaContext();
  // Note: hacked at::DeviceGuard since at::DeviceGuard was failing to work
  // properly in some scenarios.
  auto prior_device = at::cuda::current_device();
  if (prior_device != this->device().index()) {
    at::cuda::set_device(this->device().index());
  }
  // Acquires device and NVRTC properties (for compile arch and occupancy
  // calculations).
  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  int major = 0, minor = 0;
  bool compile_to_sass = false;
  fuser::cuda::codegenOutputQuery(prop, major, minor, compile_to_sass);

  // Creates the NVRTC program.
  nvrtcProgram program{nullptr};
  AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcCreateProgram(
      &program, code.c_str(), nullptr, 0, nullptr, nullptr));

#if defined(USE_ROCM)
  std::vector<const char*> args = {"--std=c++17"};
  args.push_back("-hip-pch");
#else
  const std::string compute = std::string("--gpu-architecture=") +
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
      // CUDA 11.1 allows going directly to SASS (sm_) instead of PTX
      // (compute_), which gives better backwards compatibility for running on
      // older drivers (since an older driver doesn't necessarily recognize PTX
      // emitted by a newer toolkit).
      // Meanwhile, for forward compatibility (a future device with
      // `compile_to_sass==false`), since SASS is not necessarily compatible,
      // we fall back to PTX instead.
      (compile_to_sass ? "sm_" : "compute_") +
#else
      "compute_" +
#endif
      std::to_string(major) + std::to_string(minor);
  const std::vector<const char*> args = {
      "--std=c++17", compute.c_str(), "-default-device"};
#endif

  auto result = nvrtc().nvrtcCompileProgram(
      program, static_cast<int>(args.size()), args.data());
  if (result != NVRTC_SUCCESS) {
    size_t logsize = 0;
    AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetProgramLogSize(program, &logsize));
    std::vector<char> log(logsize);
    AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetProgramLog(program, log.data()));
    std::stringstream cu;
    cu << log.data() << '\n';
    cu << "nvrtc compilation failed: " << '\n';
    cu << code << '\n';
    throw std::runtime_error(cu.str());
  }
  ResourceGuard holdProgram(
      [&] { AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcDestroyProgram(&program)); });
  AT_CUDA_NVRTC_CHECK(result);
  size_t ptx_size = 0;
  std::vector<char> ptx;
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
  // compile_to_sass determines whether we are generating SASS or PTX, hence
  // the different API.
  auto getSize = compile_to_sass
      ? at::globalContext().getNVRTC().nvrtcGetCUBINSize
      : at::globalContext().getNVRTC().nvrtcGetPTXSize;
  auto getFunc = compile_to_sass ? at::globalContext().getNVRTC().nvrtcGetCUBIN
                                 : at::globalContext().getNVRTC().nvrtcGetPTX;
#else
  auto getSize = at::globalContext().getNVRTC().nvrtcGetPTXSize;
  auto getFunc = at::globalContext().getNVRTC().nvrtcGetPTX;
#endif
  AT_CUDA_NVRTC_CHECK(getSize(program, &ptx_size));
  ptx.resize(ptx_size);
  AT_CUDA_NVRTC_CHECK(getFunc(program, ptx.data()));

  CUmodule module{nullptr};
  AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoadData(&module, ptx.data()));
  AT_CUDA_DRIVER_CHECK(
      nvrtc().cuModuleGetFunction(&function_, module, func_name.c_str()));

  if (prior_device != this->device().index()) {
    at::cuda::set_device(prior_device);
  }
}

CudaCodeGen::~CudaCodeGen() = default;

RegisterCodeGen<CudaCodeGen> cuda_codegen_reg("cuda_codegen");

} // namespace torch::jit::tensorexpr