#include <torch/csrc/jit/tensorexpr/block_codegen.h>

#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/exceptions.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>

namespace torch::jit::tensorexpr {

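// Maps a scalar dtype to the element-size string used in the Block tensor
// declarations ("elem : <size>"); unhandled types fall back to the dtype's
// C++ type name.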
static std::string blockDtypeCppString(const Dtype& dtype) {
  switch (dtype.scalar_type()) {
    case ScalarType::Bool:
      return "1";
    // NOLINTNEXTLINE(bugprone-branch-clone)
    case ScalarType::Half:
      return "2";
    case ScalarType::BFloat16:
      return "2";
    // NOLINTNEXTLINE(bugprone-branch-clone)
    case ScalarType::Char:
      return "1";
    case ScalarType::Byte:
      return "1";
    case ScalarType::Short:
      return "4";
    case ScalarType::Long:
      return "8";
    case ScalarType::Float:
      return "2"; // Return Half for now
    default:
      return dtype.ToCppString();
  }
}

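// Returns true only if every buffer in `bufs` has an entry in the
// input-to-tensor map collected from the GPU block-index loop options.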
bool BlockAnalysis::areBufsInMap(const std::unordered_set<BufPtr>& bufs) const {
  for (auto const& arg : bufs) {
    auto got = map_input_to_tensor_bufs_.find(arg->name_hint());
    if (got == map_input_to_tensor_bufs_.end()) {
      return false;
    }
  }
  return true;
}

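// Returns the multi-dimensional tensor buffer mapped to `buf`; throws if the
// buffer is not in the input-to-tensor map. getInputName() below does the
// same lookup but returns the mapped buffer's name.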
BufPtr BlockAnalysis::getMultiDimBuf(const BufPtr& buf) const {
  auto input_ = map_input_to_tensor_bufs_.find(buf->name_hint());
  if (input_ != map_input_to_tensor_bufs_.end()) {
    return input_->second;
  } else {
    throw std::runtime_error("BlockCodeGen: Entry not in input/Buffer map");
  }
}

std::string BlockAnalysis::getInputName(const BufPtr& buf) const {
  auto input_ = map_input_to_tensor_bufs_.find(buf->name_hint());
  if (input_ != map_input_to_tensor_bufs_.end()) {
    return input_->second->name_hint();
  } else {
    throw std::runtime_error("BlockCodeGen: Entry not in input/Buffer map");
  }
}

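// The analysis collects the buffers written by stores and read by loads;
// these sets drive the declarations emitted by BlockPrinter.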
void BlockAnalysis::visit(const StorePtr& v) {
  store_targets_.insert(v->buf());
  v->value()->accept(this);
}

void BlockAnalysis::visit(const LoadPtr& v) {
  loads_.insert(v->buf());
}

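// GPU block-index loops carry the input-to-tensor buffer mapping in their
// loop options; GPU thread-index loops provide the block size from their
// stop expression. All other loops are traversed normally.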
void BlockAnalysis::visit(const ForPtr& v) {
  const LoopOptions& loop_options = v->loop_options();
  if (loop_options.is_gpu_block_index()) {
    map_input_to_tensor_bufs_ = loop_options.get_buffer_mapping();
    v->body()->accept(this);
  } else if (loop_options.is_gpu_thread_index()) {
    auto block_size = v->stop();
    block_size_ = *intValue(block_size);
    v->body()->accept(this);
  } else {
    IRVisitor::visit(v);
  }
}

// For both Add and Mul we only print the opening
// parenthesis. This is done to handle the Block add op,
// where c = a + b becomes add(a, b, c). The closing parenthesis
// is added in the store statement.
// TODO: When handling fused ops such as d = a + b + c, the correct
// way would be to mutate the expression to its Block version and print that.
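//
// Illustrative sketch (hypothetical buffers a, b and output c): a Store
// whose value is Add(Load(a), Load(b)) is emitted roughly as
//   add(<a_flat>.buffer, <b_flat>.buffer, <c_flat>.tensor)
// where <x_flat> stands for getFlatInputName(x); the Load visitor prints
// the ".buffer, " operands and the Store visitor prints the output tensor
// and the closing ")".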

void BlockPrinter::visit(const AddPtr& v) {
  emitIndent();
  os() << "add(";
  v->lhs()->accept(this);
  v->rhs()->accept(this);
}

void BlockPrinter::visit(const MulPtr& v) {
  emitIndent();
  os() << "mul(";
  v->lhs()->accept(this);
  v->rhs()->accept(this);
}

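// The GPU block-index loop emits the kernel preamble (tensors, distribution,
// buffers, arguments) and wraps its body in a "compute" block with the
// forward/reverse reshapes and the outer loop header. The GPU thread-index
// loop emits the DMA-ins, the inner loop header, and the buffer adjustments.
// All other loops are printed with the default IRPrinter.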
void BlockPrinter::visit(const ForPtr& v) {
  const LoopOptions& loop_options = v->loop_options();

  auto buf_reads = block_analysis_->loads();
  auto buf_writes = block_analysis_->stores();
  std::unordered_set<BufPtr> bufs(buf_reads.begin(), buf_reads.end());
  bufs.insert(buf_writes.begin(), buf_writes.end());

  if (loop_options.is_gpu_block_index()) {
    emitIndent();
    PrintTensorInfo(bufs);
    PrintDistribution(bufs);
    PrintBufferInfo(buf_reads);
    PrintArguments(bufs);

    emitIndent();
    os() << "compute {" << '\n';

    PrintReshapeInfo(bufs);

    emitIndent();
    PrintLoop(bufs, true);
    v->body()->accept(this);

    os() << '\n';
    emitIndent();
    PrintReshapeInfo(buf_writes, true); // print reverse reshape
    os() << "}";
    os() << '\n';
  } else if (loop_options.is_gpu_thread_index()) {
    PrintDMAs(buf_reads);
    PrintLoop(buf_reads, false);
    v->body()->accept(this);
    os() << '\n';
    PrintAdjustBuffers(buf_reads);

  } else {
    IRPrinter::visit(v);
  }
}

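// Emits the "tensors" declaration block: each buffer is declared once with
// the named dims of its multi-dimensional view and once as its flattened
// counterpart, together with the element-size string from
// blockDtypeCppString(). Illustrative shape of one entry (placeholder names):
//   <name> = {{<dim0>};{<dim1>}; elem : <size>}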
void BlockPrinter::PrintTensorInfo(const std::unordered_set<BufPtr>& bufs) {
  os() << "tensors {";
  for (auto& buf : bufs) {
    os() << '\n';
    emitIndent();
    emitIndent();
    auto num_dims = block_analysis_->getMultiDimBuf(buf)->dims().size();
    os() << block_analysis_->getInputName(buf) << " = ";
    os() << "{";
    for (unsigned long d = 0; d < num_dims; d++) {
      os() << "{" << dim_names[d] << "};";
    }
    os() << " elem : " << blockDtypeCppString(buf->dtype());
    os() << "}";
  }

  for (auto& buf : bufs) {
    os() << '\n';
    emitIndent();
    emitIndent();
    auto num_dims = block_analysis_->getMultiDimBuf(buf)->dims().size();
    os() << block_analysis_->getFlatInputName(buf) << " = ";
    os() << "{";
    os() << "{" << flat_dim_names[num_dims - 1] << "};";
    os() << " elem : " << blockDtypeCppString(buf->dtype());
    os() << "}"
         << " // flattened tensor";
  }
  os() << '\n';
  emitIndent();
  os() << "}" << '\n' << '\n';
}

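// Emits the "arguments" block: one "var" per named dimension of the
// multi-dimensional tensors, one per flattened dimension of the store
// targets, and the bs_N / bs_DPE block sizes taken from the thread-index
// loop extent recorded by BlockAnalysis.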
void BlockPrinter::PrintArguments(const std::unordered_set<BufPtr>& bufs) {
  for (auto& buf : bufs) {
    auto multidimbuf = block_analysis_->getMultiDimBuf(buf);
    auto num_dims = multidimbuf->dims().size();

    // The dims for the multi-dim tensors
    for (unsigned long d = 0; d < num_dims; d++) {
      auto dim_val = *intValue(multidimbuf->dim(d));
      this->dim_values_map.emplace(this->dim_names[d], dim_val);
    }

    // The dimensions for the flattened tensors
    auto val = *intValue(buf->dim(0));
    if (block_analysis_->is_buf_store_target(buf)) {
      this->dim_values_map.emplace(this->flat_dim_names[num_dims - 1], val);
    }
  }

  emitIndent();
  os() << "arguments {" << '\n';

  for (auto const& arg : this->dim_values_map) {
    emitIndent();
    os() << "var " << arg.first << " = " << arg.second << '\n';
  }

  emitIndent();
  emitIndent();
  auto blck_sz = block_analysis_->block_size();
  os() << "var bs_N = " << blck_sz << '\n';
  emitIndent();
  emitIndent();
  os() << "var bs_DPE = " << blck_sz << '\n';
  emitIndent();
  os() << "}" << '\n' << '\n';
}

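// Emits the "buffers" block: one {{bs_DPE}}-sized buffer per flattened
// input tensor.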
void BlockPrinter::PrintBufferInfo(const std::unordered_set<BufPtr>& bufs) {
  emitIndent();
  os() << "buffers {";
  for (auto& read : bufs) {
    os() << '\n';
    emitIndent();
    emitIndent();
    os() << block_analysis_->getFlatInputName(read) << " = ";
    os() << "{{"
         << "bs_DPE"
         << "}}";
  }
  os() << '\n';
  emitIndent();
  os() << "}" << '\n' << '\n';
}

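// Emits the "distribution" block: every flattened tensor gets the same
// default (0, 1, ) distribution.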
void BlockPrinter::PrintDistribution(const std::unordered_set<BufPtr>& bufs) {
  emitIndent();
  os() << "distribution {" << '\n';
  for (auto& buf : bufs) {
    emitIndent();
    emitIndent();
    os() << block_analysis_->getFlatInputName(buf) << " = ";
    os() << "{(0, 1, )}" << '\n';
  }
  os() << "  }" << '\n' << '\n';
}

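// Emits the "loop (...)" header over dim 0 of every flattened tensor,
// blocked by bs_N for the block-index loop and by bs_DPE for the
// thread-index loop.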
void BlockPrinter::PrintLoop(
    const std::unordered_set<BufPtr>& bufs,
    bool block_idx) {
  emitIndent();
  os() << "loop (";
  auto trip = 0;
  for (auto& buf : bufs) {
    if (trip > 0) {
      os() << ",";
    }
    os() << "{dim : ";
    os() << block_analysis_->getFlatInputName(buf) << ".dim.0, ";
    os() << (block_idx ? "block: bs_N}" : "block: bs_DPE}");
    ++trip;
  }
  os() << ")";
}

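// Emits reshape() calls between the multi-dimensional and flattened views;
// with reverse == true the direction is flipped (flattened back to
// multi-dimensional), which is used for the write buffers at the end of the
// compute block.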
void BlockPrinter::PrintReshapeInfo(
    const std::unordered_set<BufPtr>& bufs,
    bool reverse) {
  for (auto& buf : bufs) {
    emitIndent();
    os() << "reshape("
         << (reverse ? block_analysis_->getFlatInputName(buf)
                     : block_analysis_->getInputName(buf))
         << ", "
         << (reverse ? block_analysis_->getInputName(buf)
                     : block_analysis_->getFlatInputName(buf))
         << ")" << '\n';
  }
}

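// Emits one dma_in() per input buffer, and one adjust_buffer() per input
// buffer after the inner loop body.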
void BlockPrinter::PrintDMAs(const std::unordered_set<BufPtr>& bufs) {
  for (auto& read : bufs) {
    emitIndent();
    os() << "dma_in(";
    os() << block_analysis_->getFlatInputName(read);
    os() << ")" << '\n';
  }
}

void BlockPrinter::PrintAdjustBuffers(const std::unordered_set<BufPtr>& bufs) {
  for (auto& read : bufs) {
    emitIndent();
    os() << "adjust_buffer(";
    os() << block_analysis_->getFlatInputName(read);
    os() << ")" << '\n';
  }
}

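// Loads print their flattened buffer as an operand ("<buf>.buffer, ");
// stores print their value expression, then the flattened output buffer with
// ".tensor" and the closing parenthesis, completing the add(...)/mul(...)
// call opened by the Add/Mul visitors above.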
void BlockPrinter::visit(const LoadPtr& v) {
  os() << block_analysis_->getFlatInputName(v->buf()) << ".buffer, ";
}

void BlockPrinter::visit(const StorePtr& v) {
  emitIndent();
  os() << *v->value() << block_analysis_->getFlatInputName(v->buf())
       << ".tensor)" << '\n';
}

void BlockPrinter::visit(const BlockPtr& v) {
  os() << "{" << '\n';
  indent_++;
  for (const StmtPtr& s : v->stmts()) {
    s->accept(this);
  }
  indent_--;
  emitIndent();
  os() << "}";
}

std::string BlockCodeGen::GetUniqueFuncName(const std::string& func_prefix) {
  // We use a global counter here to make sure different instances of
  // BlockCodeGen get different names.
  static int64_t counter = 0;
  ++counter;
  int64_t value = counter;
  return func_prefix + "_" + std::to_string(value);
}

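// Runs BlockAnalysis over the statement, checks that every buffer read or
// written is present in the input-to-tensor map, prints the kernel signature
// (outputs first, then inputs separated by ';'), and then prints the lowered
// body via BlockPrinter.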
void BlockCodeGen::Initialize() {
  block_analysis_ = std::make_unique<BlockAnalysis>();
  printer_ = std::make_unique<BlockPrinter>(&oss_, block_analysis_.get());

  StmtPtr stmt_v = stmt();
  stmt_v->accept(block_analysis_.get());

  auto buf_reads = block_analysis_->loads();
  auto buf_writes = block_analysis_->stores();
  // Ensure all Bufs in reads/writes are in the map
  std::unordered_set<BufPtr> bufs(buf_reads.begin(), buf_reads.end());
  bufs.insert(buf_writes.begin(), buf_writes.end());
  if (!block_analysis_->areBufsInMap(bufs)) {
    throw std::runtime_error("BlockCodeGen: Entry not in input/Buffer map");
  }

  std::string func_name = GetUniqueFuncName("func");
  os() << "kernel " << func_name << "(";
  for (auto const& arg : buf_writes) {
    os() << block_analysis_->getInputName(arg);
  }
  for (auto const& arg : buf_reads) {
    os() << ";" << block_analysis_->getInputName(arg);
  }
  os() << ")";

  stmt_v->accept(printer_.get());

  GRAPH_DEBUG("Generated Block code: ", oss_.str(), "\n");
}

void BlockCodeGen::call(const std::vector<CallArg>& args) {
  throw std::runtime_error("BlockCodeGen: Cannot call Block code ");
}

void BlockCodeGen::call_raw(const std::vector<void*>& args) {
  throw std::runtime_error("BlockCodeGen: Cannot call Block code ");
}

BlockCodeGen::~BlockCodeGen() = default;
RegisterCodeGen<BlockCodeGen> block_codegen_reg("block_codegen");

} // namespace torch::jit::tensorexpr