#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/codegen.h>

#include <sstream>
#include <utility>

namespace torch::jit::tensorexpr {

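// Construction normalizes the statement before backend-specific lowering:
// eligible external calls are rewritten by ExtCallMemoryReuse, and intermediate
// buffers get their allocation statements inserted via allocIntermediateBufs().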
CodeGen::CodeGen(
    StmtPtr stmt,
    std::vector<BufferArg> buffer_args,
    at::Device device,
    std::string kernel_func_name)
    : stmt_(std::move(stmt)),
      buffer_args_(std::move(buffer_args)),
      device_(device),
      kernel_func_name_(std::move(kernel_func_name)) {
  ExtCallMemoryReuse extCallMemoryReuse(buffer_args_);
  apply_mutator(&extCallMemoryReuse);
  allocIntermediateBufs();
}

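// Singleton registry of codegen backends. Each backend registers a factory
// method under a name; FindStmtFactoryMethod looks a backend up by that name.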
RegisterCodeGenList& RegisterCodeGenList::GetInstance() {
  static RegisterCodeGenList codegen_list;
  return codegen_list;
}

RegisterCodeGenList::StmtFactoryMethod RegisterCodeGenList::
    FindStmtFactoryMethod(const std::string& name) {
  auto iter = stmt_factory_methods_.find(name);
  if (iter == stmt_factory_methods_.end()) {
    std::ostringstream oss;
    oss << "Invalid stmt codegen name: " << name << ". ";
    oss << "Existing codegen names: [";
    int index = 0;
    for (auto& entry : stmt_factory_methods_) {
      if (index != 0) {
        oss << ", ";
      }
      oss << entry.first;
      index++;
    }
    oss << "]";
    throw std::runtime_error(oss.str());
  }
  return iter->second;
}

void RegisterCodeGenList::AddStmtFactoryMethod(
    const std::string& name,
    const StmtFactoryMethod& stmt_factory_method) {
  stmt_factory_methods_[name] = stmt_factory_method;
}

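// Instantiates the codegen backend registered under `name`; throws a
// std::runtime_error if no backend with that name has been registered.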
std::unique_ptr<CodeGen> CreateCodeGen(
    const std::string& name,
    StmtPtr stmt,
    const std::vector<CodeGen::BufferArg>& params,
    at::Device device,
    const std::string& kernel_func_name) {
  RegisterCodeGenList::StmtFactoryMethod method =
      RegisterCodeGenList::GetInstance().FindStmtFactoryMethod(name);
  return method(std::move(stmt), params, device, kernel_func_name);
}

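// Expands intrinsics that a backend may not support natively into primitive
// expressions. Currently only kSigmoid is expanded, as 1 / (1 + exp(0 - x)),
// broadcast to the vector width of the intrinsic.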
ExprPtr GenericIntrinsicsExpander::mutate(const IntrinsicsPtr& v) {
  if (v->op_type() == kSigmoid) {
    auto x = v->param(0)->accept_mutator(this);
    auto one = expr_to_vec(
        ExprHandle(getImmediateByType(v->dtype(), 1.0)), v->dtype().lanes());
    auto zero = expr_to_vec(
        ExprHandle(getImmediateByType(v->dtype(), 0.0)), v->dtype().lanes());
    ExprHandle y = one / (one + exp(zero - ExprHandle(x)));
    return y.node();
  }
  return IRMutator::mutate(v);
}

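// Resolves the pointer to bind for a kernel argument: buffer arguments pass
// their data pointer through unchanged, while scalar (Var) arguments yield a
// typed pointer to the value stored inside the CallArg.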
void* CodeGen::argToPtr(const BufferArg& bufferArg, const CallArg& callArg) {
  if (!bufferArg.isVar()) {
    return callArg.data();
  }

  switch (bufferArg.dtype().scalar_type()) {
#define TYPE_CASE(_1, Name) \
  case ScalarType::Name:    \
    return callArg.Name##Ptr();

    AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE

    default:
      throw unsupported_dtype();
  }
  return nullptr;
}

void CodeGen::call_with_numel(void** args, int64_t numel) {
  TORCH_INTERNAL_ASSERT(
      false, "This codegen backend does not implement call_with_numel");
}

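// Size of a buffer in bytes when every dimension is a compile-time constant;
// returns std::nullopt for buffers with dynamic shapes.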
static std::optional<size_t> bufSize(const BufPtr& buf) {
  size_t size = elementSize(buf->dtype().scalar_type()) * buf->dtype().lanes();
  for (auto& d : buf->dims()) {
    if (!d->isConstant()) {
      return std::nullopt;
    }
    size = size * (*intValue(d));
  }
  return size;
}

// This algorithm takes the list of intermediate buffers and their liveness
// ranges, and returns the allocations of these buffers. A buffer 'A' can be
// allocated in memory (appearing as a pair of 'A's in the allocation results)
// or reuse another buffer such as 'B' (appearing as ('A', 'B')). Specifically,
// we linearly scan the intermediate buffers in the order they first appear and
// try to assign each one an existing, unoccupied memory allocation. If no such
// allocation is available, we create new memory for it. Once we are beyond the
// liveness range of a buffer, we mark its corresponding memory allocation as
// "up for grabs" for future reuse.
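// For illustration, suppose three same-sized buffers have (hypothetical)
// liveness ranges A:[0, 2], C:[1, 4], and B:[3, 5]. Scanning by start time,
// A and C each get fresh memory. By the time B appears, A's range has ended,
// so A's memory is up for grabs and B reuses it, producing the allocation
// pairs (A, A), (C, C), (B, A).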
static std::vector<std::pair<BufPtr, BufPtr>> AllocBufsWithMemReuse(
    const std::unordered_set<BufPtr>& bufs,
    const std::unordered_map<BufPtr, std::tuple<int32_t, int32_t>>& buf_ranges,
    const std::unordered_set<BufPtr>& bufs_external_allocs) {
  // Sort buffers by the time they appear.
  std::vector<BufPtr> bufs_sorted(bufs.begin(), bufs.end());
  auto sorting_function_by_start_time =
      [&buf_ranges](const BufPtr& b1, const BufPtr& b2) -> bool {
    return std::get<0>(buf_ranges.at(b1)) < std::get<0>(buf_ranges.at(b2));
  };
  std::sort(
      bufs_sorted.begin(), bufs_sorted.end(), sorting_function_by_start_time);

  // Map intermediate buffers to the most recently used memory if any.
  std::list<BufPtr> mem_up_for_grabs;
  std::unordered_map<BufPtr, BufPtr> buf_mem_map;
  std::vector<std::pair<BufPtr, BufPtr>> buf_allocs;

  auto sorting_function_by_end_time =
      [&buf_ranges](const BufPtr& b1, const BufPtr& b2) -> bool {
    return std::get<1>(buf_ranges.at(b1)) < std::get<1>(buf_ranges.at(b2));
  };
  for (const auto& buf : bufs_sorted) {
    // If the buf has dynamic shapes, we'll skip it (i.e., allocate memory for
    // it, and there are no future reuses of its memory).
    // TODO: reuse memory for bufs with dynamic shapes
    if (!bufSize(buf)) {
      buf_allocs.emplace_back(buf, buf);
      continue;
    }

    auto start = std::get<0>(buf_ranges.at(buf));

    // Release memory for buffers whose liveness range ends before the creation
    // time of this buf.
    // TODO: optimize in-place operations and copy operations
    std::vector<BufPtr> buf_to_release;
    for (auto& mapped : buf_mem_map) {
      auto buf_mapped = mapped.first;
      auto end_buf_mapped = std::get<1>(buf_ranges.at(buf_mapped));
      if (end_buf_mapped < start) {
        buf_to_release.push_back(buf_mapped);
      }
    }

    // Sort the buffers by their last-use time so the head of the release
    // list contains the most recently used buf.
    std::sort(
        buf_to_release.begin(),
        buf_to_release.end(),
        sorting_function_by_end_time);
    for (auto& buf_rl : buf_to_release) {
      mem_up_for_grabs.push_front(buf_mem_map.at(buf_rl));
      buf_mem_map.erase(buf_rl);
    }

    bool allocated = false;
    if (bufs_external_allocs.find(buf) == bufs_external_allocs.end()) {
      // Check whether there is free memory that this buf can reuse.
      for (auto it = mem_up_for_grabs.begin(); it != mem_up_for_grabs.end();
           it++) {
        auto m = *it;
        if (bufSize(m) >= bufSize(buf)) {
          buf_mem_map[buf] = m;
          buf_allocs.emplace_back(buf, m);
          allocated = true;
          mem_up_for_grabs.erase(it);
          break;
        }
      }
    }

    // If there is no memory to reuse, we'll have to allocate new memory for
    // it.
    if (!allocated) {
      buf_mem_map[buf] = buf;
      buf_allocs.emplace_back(buf, buf);
    }
  }

  return buf_allocs;
}

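// Wraps the statement in a Block if it is not one already, then prepends the
// Allocate (or PlacementAllocate, for buffers reusing another buffer's memory)
// statements and appends the matching Free statements. Buffers allocated by
// external calls are released together through a single trailing FreeExt.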
static StmtPtr insertAllocFree(
    std::vector<std::pair<BufPtr, BufPtr>>& buf_allocs,
    const std::unordered_set<BufPtr>& bufs_external_allocs,
    const StmtPtr& stmt) {
  BlockPtr b = to<Block>(stmt);
  if (!b) {
    b = alloc<Block>(std::vector<StmtPtr>({stmt}));
  }

  std::vector<BufPtr> bufs_ext_to_free;
  // Insert allocations and frees for temporary buffers at global scope.
  for (auto rit = buf_allocs.rbegin(); rit != buf_allocs.rend(); ++rit) {
    if (rit->first == rit->second) {
      BufPtr buf = rit->first;
      if (bufs_external_allocs.find(buf) == bufs_external_allocs.end()) {
        b->prepend_stmt(alloc<Allocate>(buf));
        b->append_stmt(alloc<Free>(buf));
      } else {
        bufs_ext_to_free.push_back(buf);
      }
    } else {
      b->prepend_stmt(alloc<PlacementAllocate>(rit->first, rit->second));
    }
  }

  b->append_stmt(alloc<FreeExt>(bufs_ext_to_free));
  return b;
}

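// External call functions that have an *_out variant. Calls to a key are
// rewritten to call the mapped name instead (see ExtCallMemoryReuse::mutate).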
std::unordered_map<std::string, std::string> ExtCallMemoryReuse::
    makeExtCallFuncNameMap() {
  return {
      {"nnc_aten_quantize_per_tensor", "nnc_aten_quantize_per_tensor_out"},
      {"nnc_aten_dequantize", "nnc_aten_dequantize_out"},
      {"nnc_aten_quantized_mul", "nnc_aten_quantized_mul_out"},
      {"nnc_aten_quantized_conv2d", "nnc_aten_quantized_conv2d_out"},
      {"nnc_aten_quantized_conv2d_relu", "nnc_aten_quantized_conv2d_relu_out"},
      {"nnc_aten_quantized_sigmoid", "nnc_aten_quantized_sigmoid_out"},
      {"nnc_aten_upsample_nearest2d", "nnc_aten_upsample_nearest2d_out"},
      {"nnc_aten_quantized_linear", "nnc_aten_quantized_linear_out"},
      {"nnc_aten_quantized_conv1d", "nnc_aten_quantized_conv1d_out"},
      {"nnc_aten_quantized_mul_scalar", "nnc_aten_quantized_mul_scalar_out"},
      {"nnc_aten_max_red", "nnc_aten_max_red_out"},
      {"nnc_aten_conv1d", "nnc_aten_conv1d_out"},
  };
}

const std::unordered_map<std::string, std::string>
    ExtCallMemoryReuse::extCallFuncNameMap_ = makeExtCallFuncNameMap();

ExtCallMemoryReuse::ExtCallMemoryReuse(
    const std::vector<CodeGen::BufferArg>& bufferArgs) {
  for (const auto& ba : bufferArgs) {
    if (ba.buf()) {
      bufferArgs_.insert(ba.buf());
    }
  }
}

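// Rewrites an ExternalCall into an ExternalCallWithAlloc on the *_out variant
// when one exists and the call's output buffer is not one of the kernel's
// arguments; such output buffers are then treated as externally allocated and
// released via FreeExt (see insertAllocFree above).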
StmtPtr ExtCallMemoryReuse::mutate(const ExternalCallPtr& v) {
  if (extCallFuncNameMap_.count(v->func_name()) &&
      bufferArgs_.count(v->buf()) == 0) {
    std::vector<BufPtr> buf_out_args = {v->buf()};
    return alloc<ExternalCallWithAlloc>(
        extCallFuncNameMap_.at(v->func_name()),
        buf_out_args,
        v->buf_args(),
        v->args());
  }
  return v;
}

// We allocate intermediate buffers by inserting Allocate/Free or
// PlacementAllocate stmts. Allocate/Free stmts allocate memory at runtime,
// while a PlacementAllocate stmt reuses the memory of one buffer for another.
// The current implementation uses a linear scan to find memory reuses.
// TODO: try more memory reuse algorithms and compare their memory efficiency.
void CodeGen::allocIntermediateBufs() {
  // Identify intermediate buffers that are not allocated yet.
  auto bufs = NodeFinder<Buf>::find(stmt_);
  std::unordered_set<BufPtr> bufs_allocated;
  for (const auto& b : buffer_args_) {
    bufs_allocated.insert(b.buf());
  }
  auto allocs = NodeFinder<Allocate>::find(stmt_);
  for (const auto& a : allocs) {
    bufs_allocated.insert(a->buf());
  }

  std::unordered_set<BufPtr> interm_bufs;
  std::unordered_map<BufPtr, std::tuple<int32_t, int32_t>> interm_buf_ranges;
  for (const auto& buf : bufs) {
    if (!bufs_allocated.count(buf) && !interm_bufs.count(buf)) {
      interm_bufs.insert(buf);

      // Identify the liveness range of each unallocated intermediate buffer.
      auto range = BufLiveRange::liveRange(stmt_, buf);
      interm_buf_ranges.emplace(buf, range);
    }
  }

  const auto bufs_external_allocs = ExternalAllocBufFinder::find(stmt_);

  // For each intermediate buffer, reuse the memory of an old buffer whose
  // liveness range does not overlap with the current buffer, or allocate new
  // memory if reuse is impossible.
  auto buf_allocs = AllocBufsWithMemReuse(
      interm_bufs, interm_buf_ranges, bufs_external_allocs);

  // Insert memory allocation/mapping nodes.
  if (!buf_allocs.empty()) {
    auto stmt_new = insertAllocFree(buf_allocs, bufs_external_allocs, stmt_);
    set_stmt(stmt_new);
  }

  GRAPH_DEBUG("\nMemory Allocation:\n\n", *stmt(), "\n");
}

} // namespace torch::jit::tensorexpr