#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/codegen.h>

#include <sstream>
#include <utility>

namespace torch::jit::tensorexpr {

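// CodeGen constructor: stores the kernel statement and its arguments, rewrites
// eligible external calls for memory reuse (ExtCallMemoryReuse), and allocates
// intermediate buffers.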
CodeGen::CodeGen(
    StmtPtr stmt,
    std::vector<BufferArg> buffer_args,
    at::Device device,
    std::string kernel_func_name)
    : stmt_(std::move(stmt)),
      buffer_args_(std::move(buffer_args)),
      device_(device),
      kernel_func_name_(std::move(kernel_func_name)) {
  ExtCallMemoryReuse extCallMemoryReuse(buffer_args_);
  apply_mutator(&extCallMemoryReuse);
  allocIntermediateBufs();
}

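// Returns the process-wide singleton registry of statement factory methods.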
RegisterCodeGenList& RegisterCodeGenList::GetInstance() {
  static RegisterCodeGenList codegen_list;
  return codegen_list;
}

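// Looks up the factory method registered under 'name'; throws an error listing
// the registered codegen names if no such entry exists.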
RegisterCodeGenList::StmtFactoryMethod RegisterCodeGenList::
    FindStmtFactoryMethod(const std::string& name) {
  auto iter = stmt_factory_methods_.find(name);
  if (iter == stmt_factory_methods_.end()) {
    std::ostringstream oss;
    oss << "Invalid stmt codegen name: " << name << ". ";
    oss << "Existing codegen names: [";
    int index = 0;
    for (auto& entry : stmt_factory_methods_) {
      if (index != 0) {
        oss << ", ";
      }
      oss << entry.first;
      index++;
    }
    oss << "]";
    throw std::runtime_error(oss.str());
  }
  return iter->second;
}

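// Registers (or overwrites) the factory method for the given codegen name.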
void RegisterCodeGenList::AddStmtFactoryMethod(
    const std::string& name,
    const StmtFactoryMethod& stmt_factory_method) {
  stmt_factory_methods_[name] = stmt_factory_method;
}

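// Creates a CodeGen instance by dispatching to the factory method registered
// under 'name'.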
std::unique_ptr<CodeGen> CreateCodeGen(
    const std::string& name,
    StmtPtr stmt,
    const std::vector<CodeGen::BufferArg>& params,
    at::Device device,
    const std::string& kernel_func_name) {
  RegisterCodeGenList::StmtFactoryMethod method =
      RegisterCodeGenList::GetInstance().FindStmtFactoryMethod(name);
  return method(std::move(stmt), params, device, kernel_func_name);
}

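// Expands selected intrinsics into generic IR: kSigmoid is rewritten as
// 1 / (1 + exp(-x)), broadcast over the vector lanes of the expression's
// dtype; all other intrinsics fall through to the default mutator.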
ExprPtr GenericIntrinsicsExpander::mutate(const IntrinsicsPtr& v) {
  if (v->op_type() == kSigmoid) {
    auto x = v->param(0)->accept_mutator(this);
    auto one = expr_to_vec(
        ExprHandle(getImmediateByType(v->dtype(), 1.0)), v->dtype().lanes());
    auto zero = expr_to_vec(
        ExprHandle(getImmediateByType(v->dtype(), 0.0)), v->dtype().lanes());
    ExprHandle y = one / (one + exp(zero - ExprHandle(x)));
    return y.node();
  }
  return IRMutator::mutate(v);
}

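// Converts a CallArg into the raw pointer expected by the kernel: buffer
// arguments pass their data pointer through unchanged, while scalar (Var)
// arguments are resolved to a typed pointer based on their dtype.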
void* CodeGen::argToPtr(const BufferArg& bufferArg, const CallArg& callArg) {
  if (!bufferArg.isVar()) {
    return callArg.data();
  }

  switch (bufferArg.dtype().scalar_type()) {
#define TYPE_CASE(_1, Name) \
  case ScalarType::Name:    \
    return callArg.Name##Ptr();

    AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE

    default:
      throw unsupported_dtype();
  }
  return nullptr;
}

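// Fast-path call entry that takes raw argument pointers and an element count;
// backends that support this path provide their own implementation, so
// reaching this base version is an internal error.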
void CodeGen::call_with_numel(void** args, int64_t numel) {
  TORCH_INTERNAL_ASSERT(
      false, "This codegen backend does not implement call_with_numel");
}

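// Computes the size of a buffer in bytes, or std::nullopt if any of its
// dimensions is not a compile-time constant.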
static std::optional<size_t> bufSize(const BufPtr& buf) {
  size_t size = elementSize(buf->dtype().scalar_type()) * buf->dtype().lanes();
  for (auto& d : buf->dims()) {
    if (!d->isConstant()) {
      return std::nullopt;
    }
    size = size * (*intValue(d));
  }
  return size;
}

// This algorithm takes the list of intermediate buffers and their liveness
// ranges, and returns an allocation decision for each buffer. A buffer 'A' can
// either get its own memory (appearing as the pair ('A', 'A') in the result)
// or reuse the memory of another buffer 'B' (appearing as ('A', 'B')).
// Specifically, we linearly scan the intermediate buffers in the order they
// first appear and try to assign each one an existing, currently unoccupied
// allocation. If no such allocation is available, we create new memory for it.
// Once we move past a buffer's liveness range, we mark its memory allocation
// as "up for grabs" for future reuse.
static std::vector<std::pair<BufPtr, BufPtr>> AllocBufsWithMemReuse(
    const std::unordered_set<BufPtr>& bufs,
    const std::unordered_map<BufPtr, std::tuple<int32_t, int32_t>>& buf_ranges,
    const std::unordered_set<BufPtr>& bufs_external_allocs) {
  // Sort buffers by the time they first appear.
  std::vector<BufPtr> bufs_sorted(bufs.begin(), bufs.end());
  auto sorting_function_by_start_time =
      [&buf_ranges](const BufPtr& b1, const BufPtr& b2) -> bool {
    return std::get<0>(buf_ranges.at(b1)) < std::get<0>(buf_ranges.at(b2));
  };
  std::sort(
      bufs_sorted.begin(), bufs_sorted.end(), sorting_function_by_start_time);

  // Map intermediate buffers to the most recently used memory if any.
  std::list<BufPtr> mem_up_for_grabs;
  std::unordered_map<BufPtr, BufPtr> buf_mem_map;
  std::vector<std::pair<BufPtr, BufPtr>> buf_allocs;

  auto sorting_function_by_end_time =
      [&buf_ranges](const BufPtr& b1, const BufPtr& b2) -> bool {
    return std::get<1>(buf_ranges.at(b1)) < std::get<1>(buf_ranges.at(b2));
  };
  for (const auto& buf : bufs_sorted) {
    // If the buf has dynamic shapes, skip memory reuse for it: allocate
    // dedicated memory for this buf, and do not offer that memory for future
    // reuse.
    // TODO: reuse memory for bufs with dynamic shapes
    if (!bufSize(buf)) {
      buf_allocs.emplace_back(buf, buf);
      continue;
    }

    auto start = std::get<0>(buf_ranges.at(buf));

    // Release memory for buffers whose liveness range ends before the creation
    // time of this buf.
    // TODO: optimize in-place operations and copy operations
    std::vector<BufPtr> buf_to_release;
    for (auto& mapped : buf_mem_map) {
      auto buf_mapped = mapped.first;
      auto end_buf_mapped = std::get<1>(buf_ranges.at(buf_mapped));
      if (end_buf_mapped < start) {
        buf_to_release.push_back(buf_mapped);
      }
    }

    // Sort the released buffers by their end time; after they are pushed to
    // the front of the free list below, its head holds the most recently used
    // memory.
    std::sort(
        buf_to_release.begin(),
        buf_to_release.end(),
        sorting_function_by_end_time);
    for (auto& buf_rl : buf_to_release) {
      mem_up_for_grabs.push_front(buf_mem_map.at(buf_rl));
      buf_mem_map.erase(buf_rl);
    }

    bool allocated = false;
    if (bufs_external_allocs.find(buf) == bufs_external_allocs.end()) {
      // Check whether there is a freed allocation this buf can reuse.
      for (auto it = mem_up_for_grabs.begin(); it != mem_up_for_grabs.end();
           it++) {
        auto m = *it;
        if (bufSize(m) >= bufSize(buf)) {
          buf_mem_map[buf] = m;
          buf_allocs.emplace_back(buf, m);
          allocated = true;
          mem_up_for_grabs.erase(it);
          break;
        }
      }
    }

    // If there is no memory to reuse, we have to allocate new memory for this
    // buf.
    if (!allocated) {
      buf_mem_map[buf] = buf;
      buf_allocs.emplace_back(buf, buf);
    }
  }

  return buf_allocs;
}

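// Materializes the allocation decisions: buffers mapped to themselves get
// Allocate/Free stmts (or a deferred FreeExt for externally allocated bufs),
// while buffers mapped to another buffer get a PlacementAllocate stmt.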
static StmtPtr insertAllocFree(
    std::vector<std::pair<BufPtr, BufPtr>>& buf_allocs,
    const std::unordered_set<BufPtr>& bufs_external_allocs,
    const StmtPtr& stmt) {
  BlockPtr b = to<Block>(stmt);
  if (!b) {
    b = alloc<Block>(std::vector<StmtPtr>({stmt}));
  }

  std::vector<BufPtr> bufs_ext_to_free;
  // Insert allocations and frees for temporary buffers at global scope.
  for (auto rit = buf_allocs.rbegin(); rit != buf_allocs.rend(); ++rit) {
    if (rit->first == rit->second) {
      BufPtr buf = rit->first;
      if (bufs_external_allocs.find(buf) == bufs_external_allocs.end()) {
        b->prepend_stmt(alloc<Allocate>(buf));
        b->append_stmt(alloc<Free>(buf));
      } else {
        bufs_ext_to_free.push_back(buf);
      }
    } else {
      b->prepend_stmt(alloc<PlacementAllocate>(rit->first, rit->second));
    }
  }

  b->append_stmt(alloc<FreeExt>(bufs_ext_to_free));
  return b;
}

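// Table of external call functions that have an '_out' variant; calls to these
// can be rewritten into ExternalCallWithAlloc nodes (see mutate below).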
std::unordered_map<std::string, std::string> ExtCallMemoryReuse::
    makeExtCallFuncNameMap() {
  return {
      {"nnc_aten_quantize_per_tensor", "nnc_aten_quantize_per_tensor_out"},
      {"nnc_aten_dequantize", "nnc_aten_dequantize_out"},
      {"nnc_aten_quantized_mul", "nnc_aten_quantized_mul_out"},
      {"nnc_aten_quantized_conv2d", "nnc_aten_quantized_conv2d_out"},
      {"nnc_aten_quantized_conv2d_relu", "nnc_aten_quantized_conv2d_relu_out"},
      {"nnc_aten_quantized_sigmoid", "nnc_aten_quantized_sigmoid_out"},
      {"nnc_aten_upsample_nearest2d", "nnc_aten_upsample_nearest2d_out"},
      {"nnc_aten_quantized_linear", "nnc_aten_quantized_linear_out"},
      {"nnc_aten_quantized_conv1d", "nnc_aten_quantized_conv1d_out"},
      {"nnc_aten_quantized_mul_scalar", "nnc_aten_quantized_mul_scalar_out"},
      {"nnc_aten_max_red", "nnc_aten_max_red_out"},
      {"nnc_aten_conv1d", "nnc_aten_conv1d_out"},
  };
}

const std::unordered_map<std::string, std::string>
    ExtCallMemoryReuse::extCallFuncNameMap_ = makeExtCallFuncNameMap();

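// Collects the buffers that are kernel arguments; mutate() below leaves
// external calls writing into these buffers unchanged.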
ExtCallMemoryReuse::ExtCallMemoryReuse(
    const std::vector<CodeGen::BufferArg>& bufferArgs) {
  for (const auto& ba : bufferArgs) {
    if (ba.buf()) {
      bufferArgs_.insert(ba.buf());
    }
  }
}

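// Rewrites an ExternalCall into an ExternalCallWithAlloc on its '_out' variant
// when such a variant exists and the destination buffer is not a kernel
// argument; otherwise the call is left as-is.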
StmtPtr ExtCallMemoryReuse::mutate(const ExternalCallPtr& v) {
  if (extCallFuncNameMap_.count(v->func_name()) &&
      bufferArgs_.count(v->buf()) == 0) {
    std::vector<BufPtr> buf_out_args = {v->buf()};
    return alloc<ExternalCallWithAlloc>(
        extCallFuncNameMap_.at(v->func_name()),
        buf_out_args,
        v->buf_args(),
        v->args());
  }
  return v;
}

// We allocate intermediate buffers by inserting Allocate/Free or
// PlacementAllocate stmts. Allocate/Free stmts allocate memory at runtime,
// while a PlacementAllocate stmt reuses the memory of one buffer for another
// buffer. The current implementation uses a linear-scan algorithm for memory
// reuse.
// TODO: try more memory reuse algorithms and compare their memory efficiency.
void CodeGen::allocIntermediateBufs() {
  // Identify intermediate buffers that are not allocated yet.
  auto bufs = NodeFinder<Buf>::find(stmt_);
  std::unordered_set<BufPtr> bufs_allocated;
  for (const auto& b : buffer_args_) {
    bufs_allocated.insert(b.buf());
  }
  auto allocs = NodeFinder<Allocate>::find(stmt_);
  for (const auto& a : allocs) {
    bufs_allocated.insert(a->buf());
  }

  std::unordered_set<BufPtr> interm_bufs;
  std::unordered_map<BufPtr, std::tuple<int32_t, int32_t>> interm_buf_ranges;
  for (const auto& buf : bufs) {
    if (!bufs_allocated.count(buf) && !interm_bufs.count(buf)) {
      interm_bufs.insert(buf);

      // Compute the liveness range of each unallocated intermediate buffer.
      auto range = BufLiveRange::liveRange(stmt_, buf);
      interm_buf_ranges.emplace(buf, range);
    }
  }

  const auto bufs_external_allocs = ExternalAllocBufFinder::find(stmt_);

  // For each intermediate buffer, reuse the memory of an old buffer whose
  // liveness range does not overlap with the current buffer, or allocate new
  // memory if no such buffer is available.
  auto buf_allocs = AllocBufsWithMemReuse(
      interm_bufs, interm_buf_ranges, bufs_external_allocs);

  // Insert memory allocation/mapping nodes.
  if (!buf_allocs.empty()) {
    auto stmt_new = insertAllocFree(buf_allocs, bufs_external_allocs, stmt_);
    set_stmt(stmt_new);
  }

  GRAPH_DEBUG("\nMemory Allocation:\n\n", *stmt(), "\n");
}

} // namespace torch::jit::tensorexpr