#pragma once

#include <unordered_set>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/csrc/jit/resource_guard.h>
#include <torch/csrc/jit/tensorexpr/codegen.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/unique_name_manager.h>

namespace torch::jit::tensorexpr {

// A class that analyzes the given program for properties relevant to Cuda
// backends.
class CudaAnalysis : public IRVisitor {
 public:
  CudaAnalysis() {
    gpu_block_extents_ = {alloc<IntImm>(1), alloc<IntImm>(1), alloc<IntImm>(1)};
    gpu_thread_extents_ = {
        alloc<IntImm>(1), alloc<IntImm>(1), alloc<IntImm>(1)};
  }
  bool is_buf_store_target(const BufPtr& buf) const {
    return store_targets_.count(buf) > 0;
  }

  const std::unordered_set<VarPtr>& thread_local_bufs() const {
    return thread_local_bufs_;
  }

  const std::unordered_set<VarPtr>& cross_block_bufs() const {
    return cross_block_bufs_;
  }

  const std::vector<ExprPtr>& gpu_block_extents() const {
    return gpu_block_extents_;
  }

  const std::vector<ExprPtr>& gpu_thread_extents() const {
    return gpu_thread_extents_;
  }

 private:
  void visit(const StorePtr& v) override {
    store_targets_.insert(v->buf());
  }

  void visit(const AllocatePtr& v) override;
  void visit(const FreePtr& v) override;
  void visit(const PlacementAllocatePtr& v) override;
  void visit(const ForPtr& v) override;

  std::unordered_set<BufPtr> store_targets_;
  std::unordered_set<VarPtr> thread_local_bufs_;
  std::unordered_set<VarPtr> cross_block_bufs_;

  std::vector<ExprPtr> gpu_block_extents_;
  std::vector<ExprPtr> gpu_thread_extents_;
};
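
// Example (illustrative sketch, not part of this header): the analysis is
// driven like any other IRVisitor, by accepting it on the root statement and
// then querying the recorded launch extents. The `stmt` variable below is an
// assumption standing in for a lowered TensorExpr statement.
//
//   CudaAnalysis analysis;
//   stmt->accept(&analysis);
//   const std::vector<ExprPtr>& block_extents = analysis.gpu_block_extents();
//   const std::vector<ExprPtr>& thread_extents = analysis.gpu_thread_extents();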

// An IRMutator that replaces binding loop options with Cuda metavars, and masks
// statement blocks which should execute with less reach than the launch
// parameter extent.
//
// We do this by segmenting each block into chunks which should have the same
// execution parameters, then masking each dimension whose parameters differ
// from the maximum extent.
class GPUMetaVarRewriter : public IRMutator {
 public:
  explicit GPUMetaVarRewriter(const CudaAnalysis* cuda_analysis)
      : cuda_analysis_(cuda_analysis) {
    gpu_block_vars_ = {
        alloc<Var>("blockIdx.x", kInt),
        alloc<Var>("blockIdx.y", kInt),
        alloc<Var>("blockIdx.z", kInt)};
    gpu_thread_vars_ = {
        alloc<Var>("threadIdx.x", kInt),
        alloc<Var>("threadIdx.y", kInt),
        alloc<Var>("threadIdx.z", kInt)};

    current_block_reach_ = {
        alloc<IntImm>(1), alloc<IntImm>(1), alloc<IntImm>(1)};
    current_thread_reach_ = {
        alloc<IntImm>(1), alloc<IntImm>(1), alloc<IntImm>(1)};
  }

  StmtPtr mutate(const ForPtr& v) override;
  StmtPtr mutate(const BlockPtr& v) override;

  const std::vector<VarPtr>& gpu_block_vars() const {
    return gpu_block_vars_;
  }

  const std::vector<VarPtr>& gpu_thread_vars() const {
    return gpu_thread_vars_;
  }

  const std::vector<ExprPtr>& gpu_block_extents() const {
    return cuda_analysis_->gpu_block_extents();
  }

  const std::vector<ExprPtr>& gpu_thread_extents() const {
    return cuda_analysis_->gpu_thread_extents();
  }

 private:
  // When processing a block, stores the contents of each sub-segment.
  class Segment {
   public:
    void reset(bool mask) {
      stmts_.clear();
      mask_ = mask;
    }

    bool empty() const {
      return stmts_.empty();
    }

    std::vector<StmtPtr>& stmts() {
      return stmts_;
    }
    bool mask() {
      return mask_;
    }

   private:
    std::vector<StmtPtr> stmts_;
    bool mask_{true};
  };

  // Returns true if the current execution scope is equivalent to the launch
  // parameters.
  bool isFullExtent();

  std::vector<VarPtr> gpu_block_vars_;
  std::vector<VarPtr> gpu_thread_vars_;

  std::vector<ExprPtr> current_block_reach_;
  std::vector<ExprPtr> current_thread_reach_;

  const CudaAnalysis* cuda_analysis_;
};
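
// Example (illustrative sketch, not part of this header): the rewriter is
// applied after the analysis pass, replacing GPU-bound loops with the
// blockIdx/threadIdx metavars allocated above. `stmt` is an assumption
// standing in for the statement that was previously analyzed.
//
//   CudaAnalysis analysis;
//   stmt->accept(&analysis);
//   GPUMetaVarRewriter rewriter(&analysis);
//   StmtPtr rewritten = stmt->accept_mutator(&rewriter);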

// A class that overrides the underlying IRPrinter to produce Cuda C.
class CudaPrinter : public IRPrinter {
 public:
  explicit CudaPrinter(
      std::ostream* os,
      const CudaAnalysis* cuda_analysis,
      bool has_random)
      : IRPrinter(*os), cuda_analysis_(cuda_analysis) {
    if (has_random) {
      rand_func_ = alloc<Var>("rand", kHandle);
    }
  }

  void visit(const CastPtr& v) override;
  void visit(const IntrinsicsPtr& v) override;
  void visit(const ForPtr& v) override;

  void visit(const LoadPtr& v) override;
  void visit(const StorePtr& v) override;
  void visit(const AtomicAddPtr& v) override;
  void visit(const MaxPtr& v) override;
  void visit(const MinPtr& v) override;
  void visit(const IfThenElsePtr& v) override;
  void visit(const BlockPtr& v) override;
  void visit(const AllocatePtr& v) override;
  void visit(const FreePtr& v) override;
  void visit(const LetPtr& v) override;

  void visit(const ExternalCallPtr& v) override;

  VarPtr rand_func() const {
    return rand_func_;
  }

  std::string dtypeToCppString(const Dtype& dtype) override;

  using IRPrinter::name_manager;
  using IRPrinter::visit;

 private:
  VarPtr rand_func_;
  const CudaAnalysis* cuda_analysis_;

  void print_flat_alloc(const AllocatePtr& alloc);
};
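
// Example (illustrative sketch, not part of this header): the printer streams
// Cuda C for a statement into an ostream; `has_random` controls whether the
// `rand` handle var is created. `analysis` and `rewritten` are assumptions
// standing in for the analysis result and the metavar-rewritten statement
// sketched above.
//
//   std::ostringstream oss;
//   CudaPrinter printer(&oss, &analysis, /*has_random=*/false);
//   rewritten->accept(&printer);
//   std::string cuda_c = oss.str();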

// Construct Cuda C from the buffer and tensor input, and invoke the
// kernel when real arguments are provided.
class TORCH_CUDA_CU_API CudaCodeGen : public CodeGen {
 public:
  template <typename... Ts>
  CudaCodeGen(StmtPtr stmt, Ts... ts)
      : CodeGen(
            stmt,
            std::vector<BufferArg>({BufferArg(ts)...}),
            at::Device(at::kCUDA, at::cuda::current_device())) {
    Initialize();
  }

  CudaCodeGen(
      StmtPtr stmt,
      const std::vector<BufferArg>& buffer_args,
      at::Device device = at::Device(at::kCUDA, at::cuda::current_device()),
      const std::string& kernel_func_name = "func")
      : CodeGen(std::move(stmt), buffer_args, device, kernel_func_name) {
    Initialize();
  }

  ~CudaCodeGen() override;

  void call(const std::vector<CallArg>& args) override;
  void call_raw(const std::vector<void*>& args) override;
  void call_with_numel(void** args, int64_t numel) override;

  template <typename... Ts>
  void operator()(const Ts&... ts) {
    call(std::vector<CallArg>({CallArg(ts)...}));
  }

  at::Tensor empty_strided(
      c10::IntArrayRef size,
      c10::IntArrayRef stride,
      std::optional<c10::ScalarType> dtype_opt,
      std::optional<c10::Layout> layout_opt,
      std::optional<c10::Device> device_opt,
      std::optional<bool> pin_memory_opt) override;

  const std::vector<ExprPtr>& gpu_block_extents() const {
    return cuda_analysis_->gpu_block_extents();
  }

  const std::vector<ExprPtr>& gpu_thread_extents() const {
    return cuda_analysis_->gpu_thread_extents();
  }

  std::string getCodeText(const std::string& attr = "") override {
    return oss_.str();
  }

 private:
  void Initialize();

  void CompileToNVRTC(const std::string& code, const std::string& func_name);

  UniqueNameManager* name_manager() {
    if (!printer_) {
      throw std::runtime_error("Null IRPrinter is not expected");
    }
    return printer_->name_manager();
  }

  std::ostream& os() {
    return printer_->os();
  }

  std::ostringstream oss_;
  std::unique_ptr<CudaPrinter> printer_;
  std::unique_ptr<CudaAnalysis> cuda_analysis_;
  std::unique_ptr<GPUMetaVarRewriter> metavar_rewriter_;
  std::unordered_set<std::string> taken_func_names;
  std::mutex eval_lock_;
  CUfunction function_{nullptr};
  bool has_random_ = false;
  int thread_block_size_ = -1;

  std::vector<bool> arg_pos_in_extents_;
#ifdef TORCH_ENABLE_LLVM
  std::vector<ExprEval<LLVMCodeGen>> block_extents_eval_;
  std::vector<ExprEval<LLVMCodeGen>> thread_extents_eval_;
#else
  std::vector<ExprEval<SimpleIREvaluator>> block_extents_eval_;
  std::vector<ExprEval<SimpleIREvaluator>> thread_extents_eval_;
#endif

  std::string GetUniqueFuncName(const std::string& func_prefix);
};
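
// Example (illustrative sketch, not part of this header): typical usage
// constructs the codegen from a lowered statement plus its buffer arguments,
// then invokes the compiled kernel with device pointers. `stmt`, `A`, `B`,
// `C` (output and input buffer handles) and `N` are assumptions made for the
// sake of the example.
//
//   CudaCodeGen codegen(stmt, C, A, B);
//   at::Tensor a = at::rand({N}, at::kCUDA);
//   at::Tensor b = at::rand({N}, at::kCUDA);
//   at::Tensor c = at::empty({N}, at::kCUDA);
//   codegen(c.data_ptr<float>(), a.data_ptr<float>(), b.data_ptr<float>());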

} // namespace torch::jit::tensorexpr