xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/block_codegen.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <memory>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <ATen/ATen.h>
#include <torch/csrc/jit/resource_guard.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/codegen.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <torch/csrc/jit/tensorexpr/unique_name_manager.h>
16 
17 namespace torch::jit::tensorexpr {
18 
19 // A class that analyzes the given program relevant for Block backend.
20 class BlockAnalysis : public IRVisitor {
21  public:
is_buf_store_target(const BufPtr & buf)22   bool is_buf_store_target(const BufPtr& buf) const {
23     return store_targets_.count(buf) > 0;
24   }
25 
loads()26   const std::unordered_set<BufPtr>& loads() const {
27     return loads_;
28   }
29 
stores()30   const std::unordered_set<BufPtr>& stores() const {
31     return store_targets_;
32   }
33 
block_size()34   int64_t block_size() const {
35     return block_size_;
36   }
37 
38   bool areBufsInMap(const std::unordered_set<BufPtr>& bufs) const;
39 
40   BufPtr getMultiDimBuf(const BufPtr& buf) const;
41 
42   std::string getInputName(const BufPtr& buf) const;
43 
getFlatInputName(const BufPtr & buf)44   std::string getFlatInputName(const BufPtr& buf) const {
45     return getInputName(buf) + "_flat";
46   }
47 
getBufferMap()48   std::unordered_map<std::string, BufPtr> getBufferMap() const {
49     return map_input_to_tensor_bufs_;
50   }
51 
52  private:
53   void visit(const StorePtr& v) override;
54   void visit(const LoadPtr& v) override;
55   void visit(const ForPtr& v) override;
56 
57   std::unordered_map<std::string, BufPtr> map_input_to_tensor_bufs_;
58   std::unordered_set<BufPtr> store_targets_;
59   std::unordered_set<BufPtr> loads_;
60   int64_t block_size_ = 32;
61 };
62 
63 // A class that overrides the underlying IRPrinter to produce Block.
64 class BlockPrinter : public IRPrinter {
65  public:
BlockPrinter(std::ostream * os,BlockAnalysis * block_analysis)66   BlockPrinter(std::ostream* os, BlockAnalysis* block_analysis)
67       : IRPrinter(*os), block_analysis_(block_analysis) {}
68 
69   using IRPrinter::name_manager;
70   using IRPrinter::visit;
71 
72  private:
73   BlockAnalysis* block_analysis_;
74   std::unordered_map<std::string, int> dim_values_map;
75   std::vector<std::string> dim_names = {"N", "H", "W", "C"};
76   std::vector<std::string> flat_dim_names = {"N", "NH", "NHW", "NHWC"};
77   void PrintTensorInfo(const std::unordered_set<BufPtr>& bufs);
78   void PrintArguments(const std::unordered_set<BufPtr>& bufs);
79   void PrintBufferInfo(const std::unordered_set<BufPtr>& bufs);
80   void PrintDistribution(const std::unordered_set<BufPtr>& bufs);
81   void PrintLoop(const std::unordered_set<BufPtr>& bufs, bool block_idx = true);
82   void PrintReshapeInfo(
83       const std::unordered_set<BufPtr>& bufs,
84       bool reverse = false);
85   void PrintDMAs(const std::unordered_set<BufPtr>& bufs);
86   void PrintAdjustBuffers(const std::unordered_set<BufPtr>& bufs);
87 
88   void visit(const ForPtr& v) override;
89   void visit(const LoadPtr& v) override;
90   void visit(const StorePtr& v) override;
91   void visit(const BlockPtr& v) override;
92   void visit(const AddPtr& v) override;
93   void visit(const MulPtr& v) override;
94 };
95 
96 class TORCH_API BlockCodeGen : public CodeGen {
97  public:
98   template <typename... Ts>
99   /* implicit */
BlockCodeGen(StmtPtr stmt,Ts...ts)100   BlockCodeGen(StmtPtr stmt, Ts... ts)
101       : CodeGen(
102             stmt,
103             std::vector<BufferArg>({BufferArg(ts)...}),
104             at::Device(at::kCPU)) {
105     Initialize();
106   }
107 
108   BlockCodeGen(
109       StmtPtr stmt,
110       const std::vector<BufferArg>& buffer_args,
111       at::Device device = at::Device(at::kCPU),
112       const std::string& kernel_func_name = "func")
CodeGen(std::move (stmt),buffer_args,device,kernel_func_name)113       : CodeGen(std::move(stmt), buffer_args, device, kernel_func_name) {
114     Initialize();
115   }
116 
117   ~BlockCodeGen() override;
118 
119   void call(const std::vector<CallArg>& args) override;
120   void call_raw(const std::vector<void*>& args) override;
121 
122   void Initialize();
123 
124   std::string getCodeText(const std::string& attr = "") override {
125     return oss_.str();
126   }
127 
128  private:
name_manager()129   UniqueNameManager* name_manager() {
130     if (!printer_) {
131       throw std::runtime_error("Null IRPrinter is not expected");
132     }
133     return printer_->name_manager();
134   }
135 
os()136   std::ostream& os() {
137     return printer_->os();
138   }
139 
140   std::ostringstream oss_;
141   std::unique_ptr<BlockPrinter> printer_;
142   std::unique_ptr<BlockAnalysis> block_analysis_;
143 
144   std::string GetUniqueFuncName(const std::string& func_prefix);
145 };
146 } // namespace torch::jit::tensorexpr
147