#pragma once

#include <mutex>
#include <sstream>
#include <unordered_set>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/csrc/jit/resource_guard.h>
#include <torch/csrc/jit/tensorexpr/codegen.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/unique_name_manager.h>

namespace torch::jit::tensorexpr {

// A class that analyzes the given program for properties relevant to Cuda
// backends: which bufs are store targets, which are thread-local or
// cross-block, and the flattened launch extents per grid/block dimension.
class CudaAnalysis : public IRVisitor {
 public:
  CudaAnalysis() {
    gpu_block_extents_ = {alloc<IntImm>(1), alloc<IntImm>(1), alloc<IntImm>(1)};
    gpu_thread_extents_ = {
        alloc<IntImm>(1), alloc<IntImm>(1), alloc<IntImm>(1)};
  }

  bool is_buf_store_target(const BufPtr& buf) const {
    return store_targets_.count(buf) > 0;
  }

  const std::unordered_set<VarPtr>& thread_local_bufs() const {
    return thread_local_bufs_;
  }

  const std::unordered_set<VarPtr>& cross_block_bufs() const {
    return cross_block_bufs_;
  }

  const std::vector<ExprPtr>& gpu_block_extents() const {
    return gpu_block_extents_;
  }

  const std::vector<ExprPtr>& gpu_thread_extents() const {
    return gpu_thread_extents_;
  }

 private:
  void visit(const StorePtr& v) override {
    store_targets_.insert(v->buf());
  }

  void visit(const AllocatePtr& v) override;
  void visit(const FreePtr& v) override;
  void visit(const PlacementAllocatePtr& v) override;
  void visit(const ForPtr& v) override;

  std::unordered_set<BufPtr> store_targets_;
  std::unordered_set<VarPtr> thread_local_bufs_;
  std::unordered_set<VarPtr> cross_block_bufs_;

  std::vector<ExprPtr> gpu_block_extents_;
  std::vector<ExprPtr> gpu_thread_extents_;
};
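// Illustrative sketch (not part of the upstream API): running the analysis
// over a lowered statement. Each extent vector defaults to {1, 1, 1} and is
// refined as GPU-bound For loops are visited.
inline void exampleRunCudaAnalysis(const StmtPtr& stmt) {
  CudaAnalysis cuda_analysis;
  // Walk the IR; the visitor records store targets and launch extents.
  stmt->accept(&cuda_analysis);
  const std::vector<ExprPtr>& block_extents = cuda_analysis.gpu_block_extents();
  const std::vector<ExprPtr>& thread_extents =
      cuda_analysis.gpu_thread_extents();
  (void)block_extents;
  (void)thread_extents;
}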
// An IRMutator that replaces binding loop options with Cuda metavars, and
// masks statement blocks that should execute with less reach than the launch
// parameter extent.
//
// We do this by segmenting each block into chunks that share the same
// execution parameters, then masking each dim in which those params differ
// from the maximum extent.
class GPUMetaVarRewriter : public IRMutator {
 public:
  explicit GPUMetaVarRewriter(const CudaAnalysis* cuda_analysis)
      : cuda_analysis_(cuda_analysis) {
    gpu_block_vars_ = {
        alloc<Var>("blockIdx.x", kInt),
        alloc<Var>("blockIdx.y", kInt),
        alloc<Var>("blockIdx.z", kInt)};
    gpu_thread_vars_ = {
        alloc<Var>("threadIdx.x", kInt),
        alloc<Var>("threadIdx.y", kInt),
        alloc<Var>("threadIdx.z", kInt)};

    current_block_reach_ = {
        alloc<IntImm>(1), alloc<IntImm>(1), alloc<IntImm>(1)};
    current_thread_reach_ = {
        alloc<IntImm>(1), alloc<IntImm>(1), alloc<IntImm>(1)};
  }

  StmtPtr mutate(const ForPtr& v) override;
  StmtPtr mutate(const BlockPtr& v) override;

  const std::vector<VarPtr>& gpu_block_vars() const {
    return gpu_block_vars_;
  }

  const std::vector<VarPtr>& gpu_thread_vars() const {
    return gpu_thread_vars_;
  }

  const std::vector<ExprPtr>& gpu_block_extents() const {
    return cuda_analysis_->gpu_block_extents();
  }

  const std::vector<ExprPtr>& gpu_thread_extents() const {
    return cuda_analysis_->gpu_thread_extents();
  }

 private:
  // When processing a block, stores the contents of each sub-segment.
  class Segment {
   public:
    void reset(bool mask) {
      stmts_.clear();
      mask_ = mask;
    }

    bool empty() const {
      return stmts_.empty();
    }

    std::vector<StmtPtr>& stmts() {
      return stmts_;
    }

    bool mask() const {
      return mask_;
    }

   private:
    std::vector<StmtPtr> stmts_;
    bool mask_{true};
  };

  // Returns true if the current execution scope is equivalent to the launch
  // parameters.
  bool isFullExtent();

  std::vector<VarPtr> gpu_block_vars_;
  std::vector<VarPtr> gpu_thread_vars_;

  std::vector<ExprPtr> current_block_reach_;
  std::vector<ExprPtr> current_thread_reach_;

  const CudaAnalysis* cuda_analysis_;
};
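// Illustrative sketch (not part of the upstream API): the rewriter consumes
// the analysis above, replacing GPU-bound loop indices with the
// blockIdx/threadIdx metavars and masking segments whose reach is smaller
// than the launch extent.
inline StmtPtr exampleRewriteMetaVars(
    const StmtPtr& stmt,
    const CudaAnalysis* cuda_analysis) {
  GPUMetaVarRewriter metavar_rewriter(cuda_analysis);
  // Bound loops become guarded uses of blockIdx.{x,y,z} / threadIdx.{x,y,z}.
  return stmt->accept_mutator(&metavar_rewriter);
}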
// A class that overrides the underlying IRPrinter to produce Cuda C.
class CudaPrinter : public IRPrinter {
 public:
  explicit CudaPrinter(
      std::ostream* os,
      const CudaAnalysis* cuda_analysis,
      bool has_random)
      : IRPrinter(*os), cuda_analysis_(cuda_analysis) {
    if (has_random) {
      rand_func_ = alloc<Var>("rand", kHandle);
    }
  }

  void visit(const CastPtr& v) override;
  void visit(const IntrinsicsPtr& v) override;
  void visit(const ForPtr& v) override;

  void visit(const LoadPtr& v) override;
  void visit(const StorePtr& v) override;
  void visit(const AtomicAddPtr& v) override;
  void visit(const MaxPtr& v) override;
  void visit(const MinPtr& v) override;
  void visit(const IfThenElsePtr& v) override;
  void visit(const BlockPtr& v) override;
  void visit(const AllocatePtr& v) override;
  void visit(const FreePtr& v) override;
  void visit(const LetPtr& v) override;

  void visit(const ExternalCallPtr& v) override;

  VarPtr rand_func() const {
    return rand_func_;
  }

  std::string dtypeToCppString(const Dtype& dtype) override;

  using IRPrinter::name_manager;
  using IRPrinter::visit;

 private:
  VarPtr rand_func_;
  const CudaAnalysis* cuda_analysis_;

  void print_flat_alloc(const AllocatePtr& alloc);
};
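// Illustrative sketch (not part of the upstream API): printing a rewritten
// statement as Cuda C into a string. `has_random` controls whether a "rand"
// handle var is declared for random intrinsics.
inline std::string examplePrintCudaC(
    const StmtPtr& stmt,
    const CudaAnalysis* cuda_analysis,
    bool has_random) {
  std::ostringstream oss;
  CudaPrinter printer(&oss, cuda_analysis, has_random);
  stmt->accept(&printer);
  return oss.str();
}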
// Constructs Cuda C from the buffer and tensor inputs, and invokes the
// kernel when real arguments are provided.
class TORCH_CUDA_CU_API CudaCodeGen : public CodeGen {
 public:
  template <typename... Ts>
  CudaCodeGen(StmtPtr stmt, Ts... ts)
      : CodeGen(
            stmt,
            std::vector<BufferArg>({BufferArg(ts)...}),
            at::Device(at::kCUDA, at::cuda::current_device())) {
    Initialize();
  }

  CudaCodeGen(
      StmtPtr stmt,
      const std::vector<BufferArg>& buffer_args,
      at::Device device = at::Device(at::kCUDA, at::cuda::current_device()),
      const std::string& kernel_func_name = "func")
      : CodeGen(std::move(stmt), buffer_args, device, kernel_func_name) {
    Initialize();
  }

  ~CudaCodeGen() override;

  void call(const std::vector<CallArg>& args) override;
  void call_raw(const std::vector<void*>& args) override;
  void call_with_numel(void** args, int64_t numel) override;

  template <typename... Ts>
  void operator()(const Ts&... ts) {
    call(std::vector<CallArg>({CallArg(ts)...}));
  }

  at::Tensor empty_strided(
      c10::IntArrayRef size,
      c10::IntArrayRef stride,
      std::optional<c10::ScalarType> dtype_opt,
      std::optional<c10::Layout> layout_opt,
      std::optional<c10::Device> device_opt,
      std::optional<bool> pin_memory_opt) override;

  const std::vector<ExprPtr>& gpu_block_extents() const {
    return cuda_analysis_->gpu_block_extents();
  }

  const std::vector<ExprPtr>& gpu_thread_extents() const {
    return cuda_analysis_->gpu_thread_extents();
  }

  std::string getCodeText(const std::string& attr = "") override {
    return oss_.str();
  }

 private:
  void Initialize();

  void CompileToNVRTC(const std::string& code, const std::string& func_name);

  UniqueNameManager* name_manager() {
    if (!printer_) {
      throw std::runtime_error("Null IRPrinter is not expected");
    }
    return printer_->name_manager();
  }

  std::ostream& os() {
    return printer_->os();
  }

  std::ostringstream oss_;
  std::unique_ptr<CudaPrinter> printer_;
  std::unique_ptr<CudaAnalysis> cuda_analysis_;
  std::unique_ptr<GPUMetaVarRewriter> metavar_rewriter_;
  std::unordered_set<std::string> taken_func_names;
  std::mutex eval_lock_;
  CUfunction function_{nullptr};
  bool has_random_ = false;
  int thread_block_size_ = -1;

  std::vector<bool> arg_pos_in_extents_;
#ifdef TORCH_ENABLE_LLVM
  std::vector<ExprEval<LLVMCodeGen>> block_extents_eval_;
  std::vector<ExprEval<LLVMCodeGen>> thread_extents_eval_;
#else
  std::vector<ExprEval<SimpleIREvaluator>> block_extents_eval_;
  std::vector<ExprEval<SimpleIREvaluator>> thread_extents_eval_;
#endif

  std::string GetUniqueFuncName(const std::string& func_prefix);
};

} // namespace torch::jit::tensorexpr
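// Example usage (an illustrative sketch, not upstream documentation).
// Assumes `stmt` is a lowered statement over buffers described by BufferArgs
// `A` and `B`, with `a` and `b` matching CUDA tensors:
//
//   namespace te = torch::jit::tensorexpr;
//   te::CudaCodeGen codegen(stmt, {A, B});
//   // Construction compiles the generated Cuda C via NVRTC;
//   // getCodeText() exposes the source that was compiled.
//   std::string cuda_c = codegen.getCodeText();
//   // Launch the kernel with real device buffers.
//   codegen.call({a.data_ptr(), b.data_ptr()});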