#include <torch/csrc/jit/runtime/static/te_wrapper.h>

#include <ATen/CPUFunctions.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/operators/misc.h>
#include <torch/csrc/jit/tensorexpr/operators/operators.h>

#include <utility>

namespace torch::jit {

using namespace torch::jit::tensorexpr;

// Use the width of an AVX-512 vector by default; this happens to work OK for
// AVX2 as well. Some ops benefit from using multiple AVX ports, in which case
// they are vectorized by twice this constant. An exception is logit, since it
// contains an FP divide, which is single-ported.
static constexpr int kVectorWidth = 16;

#ifdef TORCH_ENABLE_LLVM

void TEWrapper::update(std::unique_ptr<LLVMCodeGen>&& cg_) {
  cg = std::move(cg_);
}

void TEWrapper::call(const std::vector<void*>& args) {
  cg->call_raw(args);
}

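// Split the outermost loop of a pointwise op by `width` and vectorize the
// resulting inner loop. Conceptually, splitWithTail(loop, w) rewrites
//   for (i = 0; i < n; i++) body(i)
// into
//   for (o = 0; o < n / w; o++)
//     for (j = 0; j < w; j++) body(o * w + j)  // vectorized below
//   for (t = (n / w) * w; t < n; t++) body(t)  // scalar tail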
static void optimizePointwise(LoopNest* ln, Tensor target, int width) {
  std::vector<ForPtr> loops = ln->getLoopStmtsFor(target);
  ForPtr inner, tail;
  TORCH_CHECK(!loops.empty(), "No loops created for pointwise op");
  ln->splitWithTail(loops[0], width, &inner, &tail);
  ln->vectorize(inner);
}

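// Lower a single pointwise Tensor to an LLVM kernel: build a LoopNest around
// `out`, vectorize it, simplify the IR, and hand the result to LLVMCodeGen.
// The output buffer is prepended to `args`, so generated kernels take the
// output pointer first, followed by the inputs and scalar arguments.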
static std::shared_ptr<TEWrapper> wrapTECompute(
    std::shared_ptr<TEWrapper> wrap,
    Tensor out,
    std::vector<CodeGen::BufferArg> args,
    int width = kVectorWidth) {
  LoopNest ln({out});
  optimizePointwise(&ln, out, width);
  ln.prepareForCodegen();
  StmtPtr s = ln.root_stmt();
  s = IRSimplifier::simplify(s);
  args.insert(args.begin(), out);
  auto cg = std::make_unique<LLVMCodeGen>(s, args);
  cg->cleanup_memory();
  wrap->update(std::move(cg));
  return wrap;
}

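// Overload for ops (e.g. signed_log1p below) that build and schedule their
// own LoopNest instead of using the default pointwise schedule.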
static std::shared_ptr<TEWrapper> wrapTECompute(
    std::shared_ptr<TEWrapper> wrap,
    LoopNest* ln,
    std::vector<CodeGen::BufferArg> args) {
  auto cg = std::make_unique<LLVMCodeGen>(ln->root_stmt(), args);
  wrap->update(std::move(cg));
  return wrap;
}

#else
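// Without LLVM there is no codegen backend: call() must never be reached, and
// the wrap helpers are no-ops that leave the wrapper without compiled code.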
68 
call(const std::vector<void * > & args)69 void TEWrapper::call(const std::vector<void*>& args) {
70   DCHECK(0 && "Invalid call");
71 }
72 
wrapTECompute(std::shared_ptr<TEWrapper> wrap,const Tensor & out,const std::vector<CodeGen::BufferArg> & args,int width=kVectorWidth)73 static std::shared_ptr<TEWrapper> wrapTECompute(
74     std::shared_ptr<TEWrapper> wrap,
75     const Tensor& out,
76     const std::vector<CodeGen::BufferArg>& args,
77     int width = kVectorWidth) {
78   return wrap;
79 }
80 
wrapTECompute(std::shared_ptr<TEWrapper> wrap,LoopNest * ln,const std::vector<CodeGen::BufferArg> & args)81 static std::shared_ptr<TEWrapper> wrapTECompute(
82     std::shared_ptr<TEWrapper> wrap,
83     LoopNest* ln,
84     const std::vector<CodeGen::BufferArg>& args) {
85   return wrap;
86 }
87 
88 #endif
89 
90 namespace {
91 
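// Process-wide cache of compiled NNC kernels, keyed by op kind. The mutex only
// guards the map itself; two threads may still race to compile the same
// kernel, in which case the last updateNNCCache() call wins.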
std::mutex& getNNCCacheMutex() {
  static std::mutex nncCacheMutex;
  return nncCacheMutex;
}

c10::FastMap<NodeKind, std::shared_ptr<TEWrapper>>& getNNCCache() {
  static c10::FastMap<NodeKind, std::shared_ptr<TEWrapper>> nncCache;
  return nncCache;
}

std::shared_ptr<TEWrapper> lookupNNCCache(NodeKind kind) {
  std::lock_guard<std::mutex> lock(getNNCCacheMutex());
  auto it = getNNCCache().find(kind);
  if (it != getNNCCache().end()) {
    return it->second;
  }
  return nullptr;
}

void updateNNCCache(NodeKind kind, std::shared_ptr<TEWrapper> code) {
  std::lock_guard<std::mutex> lock(getNNCCacheMutex());
  getNNCCache()[kind] = std::move(code);
}

} // namespace

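// Elementwise division with a runtime rounding mode matching aten::div:
// 0 = true division, 1 = "trunc", 2 = "floor". Passing the mode as a kernel
// argument lets one compiled kernel serve all three variants.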
std::shared_ptr<TEWrapper> createDiv() {
  auto wrap = lookupNNCCache(aten::div);
  if (wrap) {
    return wrap;
  }
  wrap = std::make_shared<TEWrapper>();

  auto dim = VarHandle("dim", kInt);
  auto mode = VarHandle("mode", kInt);
  BufHandle A("A", {dim}, kFloat);
  BufHandle B("B", {dim}, kFloat);

  using axis = const VarHandle&;
  Tensor C = Compute("C", {dim}, [&](axis x) {
    auto true_div_result = A.load(x) / B.load(x);

    auto mode_default = IntImm::make(0);
    auto mode_trunc = IntImm::make(1);
    auto mode_floor = IntImm::make(2);

    // Nested CompareSelects act as a chained ternary on `mode`:
    //   mode == 0 ? a / b : (mode == 1 ? trunc(a / b) : floor(a / b))
    return CompareSelect::make(
        mode,
        mode_default,
        true_div_result,
        CompareSelect::make(
            mode,
            mode_trunc,
            trunc(true_div_result),
            floor(true_div_result),
            kEQ),
        kEQ);
  });

  wrap = wrapTECompute(wrap, C, {A, B, mode, dim});

  updateNNCCache(aten::div, wrap);
  return wrap;
}

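// logit(x) = log(x / (1 - x)), with the input clamped to [C, 1 - C] so the
// ratio stays finite; C is the caller-supplied eps, and log_vml is NNC's log
// intrinsic lowered to a vectorized math-library call.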
std::shared_ptr<TEWrapper> createLogit() {
  auto wrap = lookupNNCCache(aten::logit);
  if (wrap) {
    return wrap;
  }
  wrap = std::make_shared<TEWrapper>();
  auto N = VarHandle("N", kInt);
  auto C = VarHandle("C", kFloat);
  BufHandle A("A", {N}, kFloat);
  Tensor B = Compute("B", {N}, [&](const VarHandle& i) {
    auto A_elem = [&]() {
      auto elem = A.load(i);
      auto one = FloatImm::make(1.0f);
      const auto& min = C;
      auto max = one - C;
      elem = CompareSelect::make(elem, min, min, elem, kLT);
      return CompareSelect::make(elem, max, max, elem, kGT);
    }();
    return log_vml(A_elem / (FloatImm::make(1.0f) - A_elem));
  });
  wrap = wrapTECompute(wrap, B, {A, N, C});
  updateNNCCache(aten::logit, wrap);
  return wrap;
}

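// CompareSelect::make(a, b, t, f, op) evaluates to t when `a op b` holds and
// to f otherwise, so the body below computes relu(a) = (a < 0 ? 0 : a).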
std::shared_ptr<TEWrapper> createRelu() {
  auto wrap = lookupNNCCache(aten::relu);
  if (wrap) {
    return wrap;
  }
  wrap = std::make_shared<TEWrapper>();
  auto N = VarHandle("N", kInt);
  BufHandle A("A", {N}, kFloat);
  Tensor B = Compute("B", {N}, [&](const VarHandle& i) {
    auto zero = FloatImm::make(0.f);
    auto a = A.load(i);
    return CompareSelect::make(a, zero, zero, a, kLT);
  });
  wrap = wrapTECompute(wrap, B, {A, N});
  updateNNCCache(aten::relu, wrap);
  return wrap;
}

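// fast_tanh (and fast_sigmoid below) are NNC's vectorization-friendly
// polynomial approximations of the exact functions, trading a little accuracy
// for throughput.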
std::shared_ptr<TEWrapper> createTanh() {
  auto wrap = lookupNNCCache(aten::tanh);
  if (wrap) {
    return wrap;
  }
  wrap = std::make_shared<TEWrapper>();
  auto N = VarHandle("N", kInt);
  BufHandle A("A", {N}, kFloat);
  Tensor B = Compute("B", {N}, [&](const VarHandle& i) {
    auto a = A.load(i);
    return fast_tanh(a);
  });
  wrap = wrapTECompute(wrap, B, {A, N});
  updateNNCCache(aten::tanh, wrap);
  return wrap;
}

std::shared_ptr<TEWrapper> createSigmoid() {
  auto wrap = lookupNNCCache(aten::sigmoid);
  if (wrap) {
    return wrap;
  }
  wrap = std::make_shared<TEWrapper>();
  auto N = VarHandle("N", kInt);
  BufHandle A("A", {N}, kFloat);
  Tensor B = Compute(
      "B", {N}, [&](const VarHandle& i) { return fast_sigmoid(A.load(i)); });
  wrap = wrapTECompute(wrap, B, {A, N});
  updateNNCCache(aten::sigmoid, wrap);
  return wrap;
}

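// Note the argument order: tensorexpr::clamp takes the bounds first, i.e.
// clamp(min, max, x), not clamp(x, min, max).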
std::shared_ptr<TEWrapper> createClamp() {
  static auto clamp_symbol = c10::Symbol::fromQualString("aten::clamp");
  auto wrap = lookupNNCCache(clamp_symbol);
  if (wrap) {
    return wrap;
  }
  wrap = std::make_shared<TEWrapper>();
  auto N = VarHandle("N", kInt);
  auto min_handle = VarHandle("min", kFloat);
  auto max_handle = VarHandle("max", kFloat);

  BufHandle A("A", {N}, kFloat);
  Tensor result = Compute("aten_clamp", {N}, [&](const VarHandle& i) {
    auto a = A.load(i);
    return tensorexpr::clamp(min_handle, max_handle, a);
  });
  wrap = wrapTECompute(wrap, result, {A, min_handle, max_handle, N});
  updateNNCCache(clamp_symbol, wrap);
  return wrap;
}

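// Fused clamp followed by nan_to_num: isnan() yields an integer flag, which
// the CompareSelect below tests against 1 to substitute nan_replace_val.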
std::shared_ptr<TEWrapper> createClampNanToNum() {
  static auto symbol =
      c10::Symbol::fromQualString("static_runtime::clamp_nan_to_num");
  auto wrap = lookupNNCCache(symbol);
  if (wrap) {
    return wrap;
  }
  wrap = std::make_shared<TEWrapper>();
  auto N = VarHandle("N", kInt);
  auto min_handle = VarHandle("min", kFloat);
  auto max_handle = VarHandle("max", kFloat);
  auto nan_replace_val = VarHandle("nan_replace_val", kFloat);

  BufHandle A("A", {N}, kFloat);
  Tensor result = Compute("aten_clamp", {N}, [&](const VarHandle& i) {
    auto a = A.load(i);
    auto clamp = tensorexpr::clamp(min_handle, max_handle, a);
    auto is_nan = tensorexpr::isnan(clamp);
    auto nans_replaced =
        tensorexpr::CompareSelect::make(is_nan, 1, nan_replace_val, clamp, kEQ);
    return nans_replaced;
  });
  wrap = wrapTECompute(
      wrap, result, {A, min_handle, max_handle, nan_replace_val, N});
  updateNNCCache(symbol, wrap);
  return wrap;
}

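// signed_log1p(x) = sign(x) * log1p(|x|), built from four stages (abs, log1p,
// sign, mul). Unlike the pointwise ops above, this op schedules its own
// LoopNest: the intermediate buffers are inlined away and the inner loops
// vectorized before the nest is handed to the LoopNest overload of
// wrapTECompute.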
std::shared_ptr<TEWrapper> createSignedLog1p() {
  static auto signed_log1p_symbol =
      c10::Symbol::fromQualString("static_runtime::signed_log1p");
  auto wrap = lookupNNCCache(signed_log1p_symbol);
  if (wrap) {
    return wrap;
  }
  wrap = std::make_shared<TEWrapper>();
  auto N = VarHandle("N", kInt);
  BufHandle A("A", {N}, kFloat);
  Tensor abs_result = Compute("aten_abs", {N}, [&](const VarHandle& i) {
    return tensorexpr::abs(A.load(i));
  });
  Tensor log1p_result = Compute("aten_log1p", {N}, [&](const VarHandle& i) {
    return log1p(abs_result.load(i));
  });
  Tensor sign = computeSign({A}, {N});
  Tensor output = Compute("aten_mul", {N}, [&](const VarHandle& i) {
    return sign.load(i) * log1p_result.load(i);
  });
  LoopNest ln({output}, {abs_result, log1p_result, sign, output});
  GRAPH_DEBUG("Original stmt: ", *ln.root_stmt());
  ln.inlineIntermediateBufs(true);
  ln.prepareForCodegen();
  ln.simplify();
  ln.vectorizeInnerLoops();
  ln.simplify();
  GRAPH_DEBUG("Final stmt: ", *ln.root_stmt());
  wrap = wrapTECompute(wrap, &ln, {output, A, N});
  updateNNCCache(signed_log1p_symbol, wrap);
  return wrap;
}

} // namespace torch::jit