1 #include <torch/csrc/jit/runtime/static/te_wrapper.h>
2
3 #include <ATen/CPUFunctions.h>
4 #include <torch/csrc/jit/ir/ir.h>
5 #include <torch/csrc/jit/jit_log.h>
6 #include <torch/csrc/jit/runtime/static/impl.h>
7 #include <torch/csrc/jit/tensorexpr/expr.h>
8 #include <torch/csrc/jit/tensorexpr/operators/misc.h>
9 #include <torch/csrc/jit/tensorexpr/operators/operators.h>
10
11 #include <utility>
12
13 namespace torch::jit {
14
15 using namespace torch::jit::tensorexpr;
16
// Default vectorization width: the number of float lanes in one AVX-512
// register (16). This choice also performs acceptably on AVX2. Ops that can
// keep multiple AVX ports busy are vectorized at twice this width; logit is
// an exception because it contains an FP divide, which is single-ported.
static constexpr int kVectorWidth = 16;
22
23 #ifdef TORCH_ENABLE_LLVM
24
update(std::unique_ptr<LLVMCodeGen> && cg_)25 void TEWrapper::update(std::unique_ptr<LLVMCodeGen>&& cg_) {
26 cg = std::move(cg_);
27 }
28
call(const std::vector<void * > & args)29 void TEWrapper::call(const std::vector<void*>& args) {
30 cg->call_raw(args);
31 }
32
optimizePointwise(LoopNest * ln,Tensor target,int width)33 static void optimizePointwise(LoopNest* ln, Tensor target, int width) {
34 std::vector<ForPtr> loops = ln->getLoopStmtsFor(target);
35 ForPtr inner, tail;
36 TORCH_CHECK(loops.size() > 0, "No loops created for pointwise op");
37 ln->splitWithTail(loops[0], width, &inner, &tail);
38 ln->vectorize(inner);
39 }
40
wrapTECompute(std::shared_ptr<TEWrapper> wrap,Tensor out,std::vector<CodeGen::BufferArg> args,int width=kVectorWidth)41 static std::shared_ptr<TEWrapper> wrapTECompute(
42 std::shared_ptr<TEWrapper> wrap,
43 Tensor out,
44 std::vector<CodeGen::BufferArg> args,
45 int width = kVectorWidth) {
46 LoopNest ln({out});
47 optimizePointwise(&ln, out, width);
48 ln.prepareForCodegen();
49 StmtPtr s = ln.root_stmt();
50 s = IRSimplifier::simplify(s);
51 args.insert(args.begin(), out);
52 auto cg = std::make_unique<LLVMCodeGen>(s, args);
53 cg->cleanup_memory();
54 wrap->update(std::move(cg));
55 return wrap;
56 }
57
wrapTECompute(std::shared_ptr<TEWrapper> wrap,LoopNest * ln,std::vector<CodeGen::BufferArg> args)58 static std::shared_ptr<TEWrapper> wrapTECompute(
59 std::shared_ptr<TEWrapper> wrap,
60 LoopNest* ln,
61 std::vector<CodeGen::BufferArg> args) {
62 auto cg = std::make_unique<LLVMCodeGen>(ln->root_stmt(), args);
63 wrap->update(std::move(cg));
64 return wrap;
65 }
66
67 #else
68
call(const std::vector<void * > & args)69 void TEWrapper::call(const std::vector<void*>& args) {
70 DCHECK(0 && "Invalid call");
71 }
72
wrapTECompute(std::shared_ptr<TEWrapper> wrap,const Tensor & out,const std::vector<CodeGen::BufferArg> & args,int width=kVectorWidth)73 static std::shared_ptr<TEWrapper> wrapTECompute(
74 std::shared_ptr<TEWrapper> wrap,
75 const Tensor& out,
76 const std::vector<CodeGen::BufferArg>& args,
77 int width = kVectorWidth) {
78 return wrap;
79 }
80
wrapTECompute(std::shared_ptr<TEWrapper> wrap,LoopNest * ln,const std::vector<CodeGen::BufferArg> & args)81 static std::shared_ptr<TEWrapper> wrapTECompute(
82 std::shared_ptr<TEWrapper> wrap,
83 LoopNest* ln,
84 const std::vector<CodeGen::BufferArg>& args) {
85 return wrap;
86 }
87
88 #endif
89
90 namespace {
91
getNNCCacheMutex()92 std::mutex& getNNCCacheMutex() {
93 static std::mutex nncCacheMutex;
94 return nncCacheMutex;
95 }
96
getNNCCache()97 c10::FastMap<NodeKind, std::shared_ptr<TEWrapper>>& getNNCCache() {
98 static c10::FastMap<NodeKind, std::shared_ptr<TEWrapper>> nncCache;
99 return nncCache;
100 }
101
lookupNNCCache(NodeKind kind)102 std::shared_ptr<TEWrapper> lookupNNCCache(NodeKind kind) {
103 std::lock_guard<std::mutex> lock(getNNCCacheMutex());
104 auto it = getNNCCache().find(kind);
105 if (it != getNNCCache().end()) {
106 return it->second;
107 }
108 return nullptr;
109 }
110
updateNNCCache(NodeKind kind,std::shared_ptr<TEWrapper> code)111 void updateNNCCache(NodeKind kind, std::shared_ptr<TEWrapper> code) {
112 std::lock_guard<std::mutex> lock(getNNCCacheMutex());
113 getNNCCache()[kind] = std::move(code);
114 }
115
116 } // namespace
117
createDiv()118 std::shared_ptr<TEWrapper> createDiv() {
119 auto wrap = lookupNNCCache(aten::div);
120 if (wrap) {
121 return wrap;
122 }
123 wrap = std::make_shared<TEWrapper>();
124
125 auto dim = VarHandle("dim", kInt);
126 auto mode = VarHandle("mode", kInt);
127 BufHandle A("A", {dim}, kFloat);
128 BufHandle B("B", {dim}, kFloat);
129
130 using axis = const VarHandle&;
131 Tensor C = Compute("C", {dim}, [&](axis x) {
132 auto true_div_result = A.load(x) / B.load(x);
133
134 auto mode_default = IntImm::make(0);
135 auto mode_trunc = IntImm::make(1);
136 auto mode_floor = IntImm::make(2);
137
138 // this is a glorified ternary choice operator train
139 return CompareSelect::make(
140 mode,
141 mode_default,
142 true_div_result,
143 CompareSelect::make(
144 mode,
145 mode_trunc,
146 trunc(true_div_result),
147 floor(true_div_result),
148 kEQ),
149 kEQ);
150 });
151
152 wrap = wrapTECompute(wrap, C, {A, B, mode, dim});
153
154 updateNNCCache(aten::div, wrap);
155 return wrap;
156 }
157
createLogit()158 std::shared_ptr<TEWrapper> createLogit() {
159 auto wrap = lookupNNCCache(aten::logit);
160 if (wrap) {
161 return wrap;
162 }
163 wrap = std::make_shared<TEWrapper>();
164 auto N = VarHandle("N", kInt);
165 auto C = VarHandle("C", kFloat);
166 BufHandle A("A", {N}, kFloat);
167 Tensor B = Compute("B", {N}, [&](const VarHandle& i) {
168 auto A_elem = [&]() {
169 auto elem = A.load(i);
170 auto one = FloatImm::make(1.0f);
171 const auto& min = C;
172 auto max = one - C;
173 elem = CompareSelect::make(elem, min, min, elem, kLT);
174 return CompareSelect::make(elem, max, max, elem, kGT);
175 }();
176 return log_vml(A_elem / (FloatImm::make(1.0f) - A_elem));
177 });
178 wrap = wrapTECompute(wrap, B, {A, N, C});
179 updateNNCCache(aten::logit, wrap);
180 return wrap;
181 }
182
createRelu()183 std::shared_ptr<TEWrapper> createRelu() {
184 auto wrap = lookupNNCCache(aten::relu);
185 if (wrap) {
186 return wrap;
187 }
188 wrap = std::make_shared<TEWrapper>();
189 auto N = VarHandle("N", kInt);
190 BufHandle A("A", {N}, kFloat);
191 Tensor B = Compute("B", {N}, [&](const VarHandle& i) {
192 auto zero = FloatImm::make(0.f);
193 auto a = A.load(i);
194 return CompareSelect::make(a, zero, zero, a, kLT);
195 });
196 wrap = wrapTECompute(wrap, B, {A, N});
197 updateNNCCache(aten::relu, wrap);
198 return wrap;
199 }
200
createTanh()201 std::shared_ptr<TEWrapper> createTanh() {
202 auto wrap = lookupNNCCache(aten::tanh);
203 if (wrap) {
204 return wrap;
205 }
206 wrap = std::make_shared<TEWrapper>();
207 auto N = VarHandle("N", kInt);
208 BufHandle A("A", {N}, kFloat);
209 Tensor B = Compute("B", {N}, [&](const VarHandle& i) {
210 auto a = A.load(i);
211 return fast_tanh(a);
212 });
213 wrap = wrapTECompute(wrap, B, {A, N});
214 updateNNCCache(aten::tanh, wrap);
215 return wrap;
216 }
217
createSigmoid()218 std::shared_ptr<TEWrapper> createSigmoid() {
219 auto wrap = lookupNNCCache(aten::sigmoid);
220 if (wrap) {
221 return wrap;
222 }
223 wrap = std::make_shared<TEWrapper>();
224 auto N = VarHandle("N", kInt);
225 BufHandle A("A", {N}, kFloat);
226 Tensor B = Compute(
227 "B", {N}, [&](const VarHandle& i) { return fast_sigmoid(A.load(i)); });
228 wrap = wrapTECompute(wrap, B, {A, N});
229 updateNNCCache(aten::sigmoid, wrap);
230 return wrap;
231 }
232
createClamp()233 std::shared_ptr<TEWrapper> createClamp() {
234 static auto clamp_symbol = c10::Symbol::fromQualString("aten::clamp");
235 auto wrap = lookupNNCCache(clamp_symbol);
236 if (wrap) {
237 return wrap;
238 }
239 wrap = std::make_shared<TEWrapper>();
240 auto N = VarHandle("N", kInt);
241 auto min_handle = VarHandle("min", kFloat);
242 auto max_handle = VarHandle("max", kFloat);
243
244 BufHandle A("A", {N}, kFloat);
245 Tensor result = Compute("aten_clamp", {N}, [&](const VarHandle& i) {
246 auto a = A.load(i);
247 return tensorexpr::clamp(min_handle, max_handle, a);
248 });
249 wrap = wrapTECompute(wrap, result, {A, min_handle, max_handle, N});
250 updateNNCCache(clamp_symbol, wrap);
251 return wrap;
252 }
253
createClampNanToNum()254 std::shared_ptr<TEWrapper> createClampNanToNum() {
255 static auto symbol =
256 c10::Symbol::fromQualString("static_runtime::clamp_nan_to_num");
257 auto wrap = lookupNNCCache(symbol);
258 if (wrap) {
259 return wrap;
260 }
261 wrap = std::make_shared<TEWrapper>();
262 auto N = VarHandle("N", kInt);
263 auto min_handle = VarHandle("min", kFloat);
264 auto max_handle = VarHandle("max", kFloat);
265 auto nan_replace_val = VarHandle("nan_replace_val", kFloat);
266
267 BufHandle A("A", {N}, kFloat);
268 Tensor result = Compute("aten_clamp", {N}, [&](const VarHandle& i) {
269 auto a = A.load(i);
270 auto clamp = tensorexpr::clamp(min_handle, max_handle, a);
271 auto is_nan = tensorexpr::isnan(clamp);
272 auto nans_replaced =
273 tensorexpr::CompareSelect::make(is_nan, 1, nan_replace_val, clamp, kEQ);
274 return nans_replaced;
275 });
276 wrap = wrapTECompute(
277 wrap, result, {A, min_handle, max_handle, nan_replace_val, N});
278 updateNNCCache(symbol, wrap);
279 return wrap;
280 }
281
createSignedLog1p()282 std::shared_ptr<TEWrapper> createSignedLog1p() {
283 static auto signed_log1p_symbol =
284 c10::Symbol::fromQualString("static_runtime::signed_log1p");
285 auto wrap = lookupNNCCache(signed_log1p_symbol);
286 if (wrap) {
287 return wrap;
288 }
289 wrap = std::make_shared<TEWrapper>();
290 auto N = VarHandle("N", kInt);
291 BufHandle A("A", {N}, kFloat);
292 Tensor abs_result = Compute("aten_abs", {N}, [&](const VarHandle& i) {
293 return tensorexpr::abs(A.load(i));
294 });
295 Tensor log1p_result = Compute("aten_log1p", {N}, [&](const VarHandle& i) {
296 return log1p(abs_result.load(i));
297 });
298 Tensor sign = computeSign({A}, {N});
299 Tensor output = Compute("aten_mul", {N}, [&](const VarHandle& i) {
300 return sign.load(i) * log1p_result.load(i);
301 });
302 LoopNest ln({output}, {abs_result, log1p_result, sign, output});
303 GRAPH_DEBUG("Original stmt: ", *ln.root_stmt());
304 ln.inlineIntermediateBufs(true);
305 ln.prepareForCodegen();
306 ln.simplify();
307 ln.vectorizeInnerLoops();
308 ln.simplify();
309 GRAPH_DEBUG("Final stmt: ", *ln.root_stmt());
310 wrap = wrapTECompute(wrap, &ln, {output, A, N});
311 updateNNCCache(signed_log1p_symbol, wrap);
312 return wrap;
313 }
314
315 } // namespace torch::jit
316