#include <torch/csrc/jit/tensorexpr/block_codegen.h>

#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/exceptions.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>

namespace torch::jit::tensorexpr {

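// Returns the element-size string used in Block tensor declarations for a
// given scalar type (e.g. Half -> "2"); unknown types fall back to the C++
// type name. Note that Float is currently emitted as Half ("2").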
static std::string blockDtypeCppString(const Dtype& dtype) {
  switch (dtype.scalar_type()) {
    case ScalarType::Bool:
      return "1";
    // NOLINTNEXTLINE(bugprone-branch-clone)
    case ScalarType::Half:
      return "2";
    case ScalarType::BFloat16:
      return "2";
    // NOLINTNEXTLINE(bugprone-branch-clone)
    case ScalarType::Char:
      return "1";
    case ScalarType::Byte:
      return "1";
    case ScalarType::Short:
      return "4";
    case ScalarType::Long:
      return "8";
    case ScalarType::Float:
      return "2"; // Return Half for now
    default:
      return dtype.ToCppString();
  }
}

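// Returns true iff every buffer in `bufs` has an entry in the
// input-to-tensor-buffer map collected from the loop options.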
bool BlockAnalysis::areBufsInMap(const std::unordered_set<BufPtr>& bufs) const {
  for (auto const& arg : bufs) {
    auto got = map_input_to_tensor_bufs_.find(arg->name_hint());
    if (got == map_input_to_tensor_bufs_.end()) {
      return false;
    }
  }
  return true;
}

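// Look up the multi-dimensional buffer (or its name) corresponding to a
// flattened buffer; both throw if the buffer is not in the map.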
BufPtr BlockAnalysis::getMultiDimBuf(const BufPtr& buf) const {
  auto input_ = map_input_to_tensor_bufs_.find(buf->name_hint());
  if (input_ != map_input_to_tensor_bufs_.end()) {
    return input_->second;
  } else {
    throw std::runtime_error("BlockCodeGen: Entry not in input/Buffer map");
  }
}

std::string BlockAnalysis::getInputName(const BufPtr& buf) const {
  auto input_ = map_input_to_tensor_bufs_.find(buf->name_hint());
  if (input_ != map_input_to_tensor_bufs_.end()) {
    return input_->second->name_hint();
  } else {
    throw std::runtime_error("BlockCodeGen: Entry not in input/Buffer map");
  }
}

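// The analysis pass records store targets and loads, captures the
// input-to-tensor-buffer mapping from GPU block-index loops, and reads the
// block size off GPU thread-index loops.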
void BlockAnalysis::visit(const StorePtr& v) {
  store_targets_.insert(v->buf());
  v->value()->accept(this);
}

void BlockAnalysis::visit(const LoadPtr& v) {
  loads_.insert(v->buf());
}

void BlockAnalysis::visit(const ForPtr& v) {
  const LoopOptions& loop_options = v->loop_options();
  if (loop_options.is_gpu_block_index()) {
    map_input_to_tensor_bufs_ = loop_options.get_buffer_mapping();
    v->body()->accept(this);
  } else if (loop_options.is_gpu_thread_index()) {
    auto block_size = v->stop();
    block_size_ = *intValue(block_size);
    v->body()->accept(this);
  } else {
    IRVisitor::visit(v);
  }
}

// For both Add and Mul we print only the opening parenthesis. This handles
// the Block add op, where c = a + b becomes add(a, b, c); the closing
// parenthesis is emitted by the store statement.
// TODO: When handling fused ops such as d = a + b + c, the correct approach
// would be to mutate the expression into its Block version and print that.
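// For example, a store of c = a + b prints roughly as
//   add(a_flat.buffer, b_flat.buffer, c_flat.tensor)
// (names are illustrative; see the Load/Store printers below).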

void BlockPrinter::visit(const AddPtr& v) {
  emitIndent();
  os() << "add(";
  v->lhs()->accept(this);
  v->rhs()->accept(this);
}

void BlockPrinter::visit(const MulPtr& v) {
  emitIndent();
  os() << "mul(";
  v->lhs()->accept(this);
  v->rhs()->accept(this);
}

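// A GPU block-index loop prints as the kernel skeleton: tensor,
// distribution, buffer, and argument sections followed by a compute block.
// A GPU thread-index loop prints DMA-ins, an inner loop, and buffer
// adjustments. Any other loop prints as plain IR.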
void BlockPrinter::visit(const ForPtr& v) {
  const LoopOptions& loop_options = v->loop_options();

  auto buf_reads = block_analysis_->loads();
  auto buf_writes = block_analysis_->stores();
  std::unordered_set<BufPtr> bufs(buf_reads.begin(), buf_reads.end());
  bufs.insert(buf_writes.begin(), buf_writes.end());

  if (loop_options.is_gpu_block_index()) {
    emitIndent();
    PrintTensorInfo(bufs);
    PrintDistribution(bufs);
    PrintBufferInfo(buf_reads);
    PrintArguments(bufs);

    emitIndent();
    os() << "compute {" << '\n';

    PrintReshapeInfo(bufs);

    emitIndent();
    PrintLoop(bufs, true);
    v->body()->accept(this);

    os() << '\n';
    emitIndent();
    PrintReshapeInfo(buf_writes, true); // print reverse reshape
    os() << "}";
    os() << '\n';
  } else if (loop_options.is_gpu_thread_index()) {
    PrintDMAs(buf_reads);
    PrintLoop(buf_reads, false);
    v->body()->accept(this);
    os() << '\n';
    PrintAdjustBuffers(buf_reads);

  } else {
    IRPrinter::visit(v);
  }
}

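// Declares each tensor twice: in its multi-dimensional shape and in its
// flattened one-dimensional shape, each with its element-size annotation.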
void BlockPrinter::PrintTensorInfo(const std::unordered_set<BufPtr>& bufs) {
  os() << "tensors {";
  for (auto& buf : bufs) {
    os() << '\n';
    emitIndent();
    emitIndent();
    auto num_dims = block_analysis_->getMultiDimBuf(buf)->dims().size();
    os() << block_analysis_->getInputName(buf) << " = ";
    os() << "{";
    for (unsigned long d = 0; d < num_dims; d++) {
      os() << "{" << dim_names[d] << "};";
    }
    os() << " elem : " << blockDtypeCppString(buf->dtype());
    os() << "}";
  }

  for (auto& buf : bufs) {
    os() << '\n';
    emitIndent();
    emitIndent();
    auto num_dims = block_analysis_->getMultiDimBuf(buf)->dims().size();
    os() << block_analysis_->getFlatInputName(buf) << " = ";
    os() << "{";
    os() << "{" << flat_dim_names[num_dims - 1] << "};";
    os() << " elem : " << blockDtypeCppString(buf->dtype());
    os() << "}"
         << " // flattened tensor";
  }
  os() << '\n';
  emitIndent();
  os() << "}" << '\n' << '\n';
}

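// Emits the concrete dimension sizes of the multi-dim and flattened tensors
// as `var` arguments, plus the block sizes bs_N and bs_DPE.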
void BlockPrinter::PrintArguments(const std::unordered_set<BufPtr>& bufs) {
  for (auto& buf : bufs) {
    auto multidimbuf = block_analysis_->getMultiDimBuf(buf);
    auto num_dims = multidimbuf->dims().size();

    // The dims for the multi-dim tensors
    for (unsigned long d = 0; d < num_dims; d++) {
      auto dim_val = *intValue(multidimbuf->dim(d));
      this->dim_values_map.emplace(this->dim_names[d], dim_val);
    }

    // The dimensions for the flattened tensors
    auto val = *intValue(buf->dim(0));
    if (block_analysis_->is_buf_store_target(buf)) {
      this->dim_values_map.emplace(this->flat_dim_names[num_dims - 1], val);
    }
  }

  emitIndent();
  os() << "arguments {" << '\n';

  for (auto const& arg : this->dim_values_map) {
    emitIndent();
    os() << "var " << arg.first << " = " << arg.second << '\n';
  }

  emitIndent();
  emitIndent();
  auto blck_sz = block_analysis_->block_size();
  os() << "var bs_N = " << blck_sz << '\n';
  emitIndent();
  emitIndent();
  os() << "var bs_DPE = " << blck_sz << '\n';
  emitIndent();
  os() << "}" << '\n' << '\n';
}

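// Declares a local buffer sized {{bs_DPE}} per flattened read tensor.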
void BlockPrinter::PrintBufferInfo(const std::unordered_set<BufPtr>& bufs) {
  emitIndent();
  os() << "buffers {";
  for (auto& read : bufs) {
    os() << '\n';
    emitIndent();
    emitIndent();
    os() << block_analysis_->getFlatInputName(read) << " = ";
    os() << "{{"
         << "bs_DPE"
         << "}}";
  }
  os() << '\n';
  emitIndent();
  os() << "}" << '\n' << '\n';
}

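// Emits a fixed "{(0, 1, )}" distribution entry for each flattened buffer.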
void BlockPrinter::PrintDistribution(const std::unordered_set<BufPtr>& bufs) {
  emitIndent();
  os() << "distribution {" << '\n';
  for (auto& buf : bufs) {
    emitIndent();
    emitIndent();
    os() << block_analysis_->getFlatInputName(buf) << " = ";
    os() << "{(0, 1, )}" << '\n';
  }
  os() << " }" << '\n' << '\n';
}

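// Prints a loop header over dim 0 of each flattened buffer, blocked by bs_N
// at the block-index level and by bs_DPE at the thread-index level.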
void BlockPrinter::PrintLoop(
    const std::unordered_set<BufPtr>& bufs,
    bool block_idx) {
  emitIndent();
  os() << "loop (";
  auto trip = 0;
  for (auto& buf : bufs) {
    if (trip > 0) {
      os() << ",";
    }
    os() << "{dim : ";
    os() << block_analysis_->getFlatInputName(buf) << ".dim.0, ";
    os() << (block_idx ? "block: bs_N}" : "block: bs_DPE}");
    ++trip;
  }
  os() << ")";
}

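// Emits reshape() calls from the multi-dim tensors to their flattened
// views; with reverse == true, writes are reshaped back to multi-dim form.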
void BlockPrinter::PrintReshapeInfo(
    const std::unordered_set<BufPtr>& bufs,
    bool reverse) {
  for (auto& buf : bufs) {
    emitIndent();
    os() << "reshape("
         << (reverse ? block_analysis_->getFlatInputName(buf)
                     : block_analysis_->getInputName(buf))
         << ", "
         << (reverse ? block_analysis_->getInputName(buf)
                     : block_analysis_->getFlatInputName(buf))
         << ")" << '\n';
  }
}

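// Emits a dma_in() for each read buffer before the inner loop and a
// matching adjust_buffer() after it.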
void BlockPrinter::PrintDMAs(const std::unordered_set<BufPtr>& bufs) {
  for (auto& read : bufs) {
    emitIndent();
    os() << "dma_in(";
    os() << block_analysis_->getFlatInputName(read);
    os() << ")" << '\n';
  }
}

void BlockPrinter::PrintAdjustBuffers(const std::unordered_set<BufPtr>& bufs) {
  for (auto& read : bufs) {
    emitIndent();
    os() << "adjust_buffer(";
    os() << block_analysis_->getFlatInputName(read);
    os() << ")" << '\n';
  }
}

void BlockPrinter::visit(const LoadPtr& v) {
  os() << block_analysis_->getFlatInputName(v->buf()) << ".buffer, ";
}

void BlockPrinter::visit(const StorePtr& v) {
  emitIndent();
  os() << *v->value() << block_analysis_->getFlatInputName(v->buf())
       << ".tensor)" << '\n';
}

void BlockPrinter::visit(const BlockPtr& v) {
  os() << "{" << '\n';
  indent_++;
  for (const StmtPtr& s : v->stmts()) {
    s->accept(this);
  }
  indent_--;
  emitIndent();
  os() << "}";
}

std::string BlockCodeGen::GetUniqueFuncName(const std::string& func_prefix) {
  // We use a global counter here to make sure different instances of
  // BlockCodeGen get different names.
  static int64_t counter = 0;
  ++counter;
  int64_t value = counter;
  return func_prefix + "_" + std::to_string(value);
}

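// Runs the analysis pass over the statement, checks that every accessed
// buffer has a mapping, emits the kernel signature (writes, then reads,
// separated by ';'), and prints the body via the BlockPrinter.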
void BlockCodeGen::Initialize() {
  block_analysis_ = std::make_unique<BlockAnalysis>();
  printer_ = std::make_unique<BlockPrinter>(&oss_, block_analysis_.get());

  StmtPtr stmt_v = stmt();
  stmt_v->accept(block_analysis_.get());

  auto buf_reads = block_analysis_->loads();
  auto buf_writes = block_analysis_->stores();
  // Ensure all Bufs in reads/writes are in the map
  std::unordered_set<BufPtr> bufs(buf_reads.begin(), buf_reads.end());
  bufs.insert(buf_writes.begin(), buf_writes.end());
  if (!block_analysis_->areBufsInMap(bufs)) {
    throw std::runtime_error("BlockCodeGen: Entry not in input/Buffer map");
  }

  std::string func_name = GetUniqueFuncName("func");
  os() << "kernel " << func_name << "(";
  for (auto const& arg : buf_writes) {
    os() << block_analysis_->getInputName(arg);
  }
  for (auto const& arg : buf_reads) {
    os() << ";" << block_analysis_->getInputName(arg);
  }
  os() << ")";

  stmt_v->accept(printer_.get());

  GRAPH_DEBUG("Generated Block code: ", oss_.str(), "\n");
}

void BlockCodeGen::call(const std::vector<CallArg>& args) {
  throw std::runtime_error("BlockCodeGen: Cannot call Block code");
}

void BlockCodeGen::call_raw(const std::vector<void*>& args) {
  throw std::runtime_error("BlockCodeGen: Cannot call Block code");
}

BlockCodeGen::~BlockCodeGen() = default;
RegisterCodeGen<BlockCodeGen> block_codegen_reg("block_codegen");

} // namespace torch::jit::tensorexpr