#include <torch/csrc/jit/runtime/static/passes.h>

#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/passes/variadic_ops.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <torch/csrc/jit/runtime/static/ops.h>

C10_DEFINE_bool(
    enable_clip_ranges_gather_fusions,
    true,
    "If on, static runtime or optimize_sparse_nn_model will fuse clip ranges gather ops.");

namespace torch::jit {
bool graphHasOp(std::shared_ptr<Graph>& graph, const char* op_name) {
  DepthFirstGraphNodeIterator graph_it(graph);
  for (auto node = graph_it.next(); node != nullptr; node = graph_it.next()) {
    const char* node_qual_string = node->kind().toQualString();
    if (strcmp(node_qual_string, op_name) == 0) {
      return true;
    }
  }
  return false;
}

bool forwardHasOp(
    const torch::jit::script::Module& module,
    const char* op_name) {
  using Method = ::torch::jit::Method;
  Method method = module.get_method("forward");
  auto graph = method.graph();
  return graphHasOp(graph, op_name);
}

namespace {
C10_UNUSED
void ConcatAddMulReplaceNaNClip(std::shared_ptr<torch::jit::Graph>& graph) {
  // TODO: check restrictions for inputs; outputs not used elsewhere
  std::string pattern = R"IR(
    graph(%a, %b, %c, %d, %e, %f, %g, %h, %i, %j):
        %y0 = aten::cat(%a, %b)
        %y1 = aten::add(%y0, %c, %d)
        %y2 = aten::mul(%y1, %e)
        %y3 = aten::nan_to_num(%y2, %f, %g, %h)
        %res = aten::clamp(%y3, %i, %j)
        return (%res))IR";
  std::string pattern2 = R"IR(
    graph(%a, %b, %c, %d, %e, %f, %g, %h, %i, %j):
        %y0 = aten::cat(%a, %b)
        %y1 = aten::add(%y0, %c, %d)
        %y2 = aten::mul(%y1, %e)
        %y3 = aten::nan_to_num_(%y2, %f, %g, %h)
        %res = aten::clamp(%y3, %i, %j)
        return (%res))IR";
  std::string pattern3 = R"IR(
    graph(%a, %b, %c, %d, %e, %f, %g, %h, %i, %j):
        %y0 = aten::cat(%a, %b)
        %y1 = aten::add(%y0, %c, %d)
        %y2 = aten::mul(%y1, %e)
        %y3 = aten::nan_to_num_(%y2, %f, %g, %h)
        %res = aten::clamp_(%y3, %i, %j)
        return (%res))IR";
  std::string pattern4 = R"IR(
    graph(%a, %b, %c, %d, %e, %f, %g, %h, %i, %j):
        %y0 = aten::cat(%a, %b)
        %y1 = aten::add(%y0, %c, %d)
        %y2 = aten::mul(%y1, %e)
        %y3 = aten::nan_to_num(%y2, %f, %g, %h)
        %res = aten::clamp_(%y3, %i, %j)
        return (%res))IR";
  std::string fused_pattern = R"IR(
    graph(%a, %b, %c, %d, %e, %f, %g, %h, %i, %j):
        %res = fb::concat_add_mul_replacenan_clip(%c, %e, %a, %i, %j, %b)
        return (%res))IR";

  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, fused_pattern);
  fuse.runOnGraph(graph);

  fuse.RegisterRewritePattern(pattern2, fused_pattern);
  fuse.runOnGraph(graph);

  fuse.RegisterRewritePattern(pattern3, fused_pattern);
  fuse.runOnGraph(graph);

  fuse.RegisterRewritePattern(pattern4, fused_pattern);
  fuse.runOnGraph(graph);
}

C10_UNUSED
void CastedBatchOneHotLengths(std::shared_ptr<torch::jit::Graph>& graph) {
  // TODO: check restrictions for inputs; outputs not used elsewhere
  std::string pattern = R"IR(
    graph(%a, %b, %c, %d, %e, %f, %g):
        %y0 : Tensor = aten::to(%a, %b, %c, %c, %d)
        %y1 : Tensor = fb::batch_one_hot_lengths(%y0, %e, %f)
        %res : Tensor = aten::to(%y1, %g, %c, %c, %d)
        return (%res))IR";
  std::string fused_pattern = R"IR(
    graph(%a, %b, %c, %d, %e, %f, %g):
        %res : Tensor = fb::casted_batch_one_hot_lengths(%a, %e, %f)
        return (%res))IR";
  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, fused_pattern);
  fuse.runOnGraph(graph);

  std::string pattern2 = R"IR(
    graph(%a, %b, %c, %d, %e, %f):
        %y0 : Tensor = aten::to(%a, %b, %c, %c)
        %y1 : Tensor = fb::batch_one_hot_lengths(%y0, %d, %e)
        %res : Tensor = aten::to(%y1, %f, %c, %c)
        return (%res))IR";
  std::string fused_pattern2 = R"IR(
    graph(%a, %b, %c, %d, %e, %f):
        %res : Tensor = fb::casted_batch_one_hot_lengths(%a, %d, %e)
        return (%res))IR";
  fuse.RegisterRewritePattern(pattern2, fused_pattern2);
  fuse.runOnGraph(graph);
}

C10_UNUSED
void ConcatBatchMatMulBatchGather(std::shared_ptr<torch::jit::Graph>& graph) {
  std::string pattern = R"IR(
    graph(%a, %b, %c, %d, %e, %f):
        %y0 : Tensor = aten::stack(%a, %b)
        %y1 : Tensor = aten::transpose(%y0, %b, %c)
        %y2 : Tensor = aten::bmm(%y0, %y1)
        %y3 : Tensor = aten::flatten(%y2, %d, %e)
        %res : Tensor = aten::index_select(%y3, %b, %f)
        return (%res))IR";
  std::string fused_pattern = R"IR(
    graph(%a, %b, %c, %d, %e, %f):
        %res : Tensor = fb::concat_batch_matmul_batch_gather(%f, %a)
        return (%res))IR";
  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, fused_pattern);

  // This pattern, found in several models, has a redundant second `flatten`.
  std::string pattern_broadcast = R"IR(
    graph(%a, %b, %c, %d, %e, %indices):
        %y0 : Tensor = fb::broadcast_stack(%a, %b)
        %y1 : Tensor = aten::transpose(%y0, %b, %c)
        %y2 : Tensor = aten::matmul(%y0, %y1)
        %y3 : Tensor = aten::flatten(%y2, %b, %e)
        %y4 : Tensor = aten::flatten(%y3, %d, %d)
        %res : Tensor = aten::index_select(%y4, %b, %indices)
        return (%res))IR";
  std::string fused_pattern_broadcast = R"IR(
    graph(%a, %b, %c, %d, %e, %indices):
        %res : Tensor = fb::broadcast_concat_batch_matmul_batch_gather(%indices, %a)
        return (%res))IR";
  fuse.RegisterRewritePattern(pattern_broadcast, fused_pattern_broadcast);

  std::string pattern_broadcast2 = R"IR(
    graph(%a, %b, %c, %d, %indices):
        %y0 : Tensor = fb::broadcast_stack(%a, %b)
        %y1 : Tensor = aten::transpose(%y0, %b, %c)
        %y2 : Tensor = aten::matmul(%y0, %y1)
        %y3 : Tensor = aten::flatten(%y2, %b, %d)
        %res : Tensor = aten::index_select(%y3, %b, %indices)
        return (%res))IR";
  std::string fused_pattern_broadcast2 = R"IR(
    graph(%a, %b, %c, %d, %indices):
        %res : Tensor = fb::broadcast_concat_batch_matmul_batch_gather(%indices, %a)
        return (%res))IR";
  fuse.RegisterRewritePattern(pattern_broadcast2, fused_pattern_broadcast2);
  fuse.runOnGraph(graph);
}

C10_UNUSED void ClipRangesGatherRangesLengthsToOffsets(
    std::shared_ptr<torch::jit::Graph>& graph) {
  // TODO: check restrictions for inputs; outputs not used elsewhere
  std::string pattern = R"IR(
    graph(%a, %b, %c, %d):
        %y0 : Tensor = fb::clip_ranges(%b, %c)
        %y1 : Tensor, %y2 : Tensor = fb::gather_ranges(%a, %y0)
        %y3 : Tensor = fb::lengths_to_offsets(%y2, %d)
        return (%y3, %y1))IR";
  std::string fused_pattern = R"IR(
    graph(%a, %b, %c, %d):
        %y0 : Tensor, %y1 : Tensor = fb::clip_ranges_gather_lengths_to_offsets(%a, %b, %c, %d)
        return (%y1, %y0))IR";
  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, fused_pattern);
  fuse.runOnGraph(graph);
}

C10_UNUSED void ClipRangesGather(std::shared_ptr<torch::jit::Graph>& graph) {
  // TODO: check restrictions for inputs; outputs not used elsewhere
  // fuse without lengths-to-offsets
  std::string pattern = R"IR(
    graph(%a, %b, %c):
        %y0 : Tensor = fb::clip_ranges(%b, %c)
        %y1 : Tensor, %y2 : Tensor = fb::gather_ranges(%a, %y0)
        return (%y2, %y1))IR";
  std::string fused_pattern = R"IR(
    graph(%a, %b, %c):
        %y0 : Tensor, %y1 : Tensor = fb::clip_ranges_gather(%a, %b, %c)
        return (%y1, %y0))IR";
  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, fused_pattern);
  fuse.runOnGraph(graph);
}

C10_UNUSED void PrecomputeMultiplierShiftForSigridHash(
    std::shared_ptr<torch::jit::Graph>& graph) {
  std::string pattern = R"IR(
    graph(%a, %b, %c, %d, %e):
        %y0 : Tensor = fb::sigrid_hash(%a, %b, %c, %d, %e)
        return (%y0)
  )IR";
  std::string split_pattern = R"IR(
    graph(%a, %b, %c, %d, %e):
        %y0 : Tensor = fb::sigrid_hash_compute_multipler_shift(%c)
        %y2 : Tensor = fb::sigrid_hash_precompute(%a, %b, %c, %y0, %d, %e)
        return (%y2)
  )IR";
  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, split_pattern);
  fuse.runOnGraph(graph);
}

C10_UNUSED void ClipRangesToGatherToOffsets(
    std::shared_ptr<torch::jit::Graph>& graph) {
  std::string pattern = R"IR(
    graph(%a, %b, %c, %d, %to0_in0, %to0_in1, %to0_in2):
        %y0 : Tensor, %y1 : Tensor = fb::clip_ranges_gather(%a, %b, %c)
        %y2 : Tensor = aten::to(%y1, %to0_in0, %to0_in1, %to0_in1, %to0_in2)
        %y3 : Tensor = fb::lengths_to_offsets(%y2, %d)
        return (%y3, %y0))IR";
  std::string fused_pattern = R"IR(
    graph(%a, %b, %c, %d, %to0_in0, %to0_in1, %to0_in2):
        %y0 : Tensor, %y1 : Tensor = fb::clip_ranges_gather_to_offsets(%a, %b, %c, %d, %to0_in0)
        return (%y1, %y0))IR";
  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, fused_pattern);
  fuse.runOnGraph(graph);

  std::string pattern2 = R"IR(
    graph(%a, %b, %c, %d, %to0_in0, %to0_in1):
        %y0 : Tensor, %y1 : Tensor = fb::clip_ranges_gather(%a, %b, %c)
        %y2 : Tensor = aten::to(%y1, %to0_in0, %to0_in1, %to0_in1)
        %y3 : Tensor = fb::lengths_to_offsets(%y2, %d)
        return (%y3, %y0))IR";
  std::string fused_pattern2 = R"IR(
    graph(%a, %b, %c, %d, %to0_in0, %to0_in1):
        %y0 : Tensor, %y1 : Tensor = fb::clip_ranges_gather_to_offsets(%a, %b, %c, %d, %to0_in0)
        return (%y1, %y0))IR";
  fuse.RegisterRewritePattern(pattern2, fused_pattern2);
  fuse.runOnGraph(graph);
}

C10_UNUSED void ToLengthsToOffsets(std::shared_ptr<torch::jit::Graph>& graph) {
  std::string pattern = R"IR(
    graph(%a, %includelastoffset, %dtype, %nonblocking, %copy, %memoryformat):
        %y0 : Tensor = aten::to(%a, %dtype, %nonblocking, %copy, %memoryformat)
        %y1 : Tensor = fb::lengths_to_offsets(%y0, %includelastoffset)
        return (%y1))IR";
  std::string fused_pattern = R"IR(
    graph(%a, %includelastoffset, %dtype, %nonblocking, %copy, %memoryformat):
        %y0 : Tensor = fb::to_lengths_to_offsets(%a, %includelastoffset, %dtype)
        return (%y0))IR";
  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, fused_pattern);
  fuse.runOnGraph(graph);

  std::string pattern2 = R"IR(
    graph(%a, %includelastoffset, %dtype, %nonblocking, %copy):
        %y0 : Tensor = aten::to(%a, %dtype, %nonblocking, %copy)
        %y1 : Tensor = fb::lengths_to_offsets(%y0, %includelastoffset)
        return (%y1))IR";
  std::string fused_pattern2 = R"IR(
    graph(%a, %includelastoffset, %dtype, %nonblocking, %copy):
        %y0 : Tensor = fb::to_lengths_to_offsets(%a, %includelastoffset, %dtype)
        return (%y0))IR";
  fuse.RegisterRewritePattern(pattern2, fused_pattern2);
  fuse.runOnGraph(graph);
}

C10_UNUSED
void ClipRangesGatherSigridHash(std::shared_ptr<torch::jit::Graph>& graph) {
  // TODO: check restrictions for inputs; outputs not used elsewhere
  std::string pattern = R"IR(
    graph(%a, %b, %c, %d, %e, %f, %g, %h):
        %y0 : Tensor, %y1 : Tensor = fb::clip_ranges_gather_lengths_to_offsets(%a, %b, %c, %d)
        %y2 : Tensor = fb::sigrid_hash_precompute(%y0, %e, %f, %g, %h)
        return (%y2, %y1))IR";
  std::string fused_pattern = R"IR(
    graph(%a, %b, %c, %d, %e, %f, %g, %h):
        %off : Tensor, %out : Tensor = fb::clip_ranges_gather_sigrid_hash_precompute_offsets(%b, %a, %c, %e, %f, %g, %h, %d)
        return (%out, %off))IR";
  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, fused_pattern);
  fuse.runOnGraph(graph);
}

C10_UNUSED void ClipRangesGatherRangesSigridHash(
    std::shared_ptr<torch::jit::Graph>& graph) {
  std::string pattern = R"IR(
    graph(%a, %b, %c, %d, %e, %f, %g):
        %y0 : Tensor = fb::clip_ranges(%b, %c)
        %y1 : Tensor, %y2 : Tensor = fb::gather_ranges(%a, %y0)
        %y3 : Tensor = fb::sigrid_hash_precompute(%y1, %d, %e, %f, %g)
        return (%y3, %y2))IR";
  std::string fused_pattern = R"IR(
    graph(%a, %b, %c, %d, %e, %f, %g):
        %off : Tensor, %out : Tensor = fb::clip_ranges_gather_sigrid_hash_precompute_v3(%b, %a, %c, %d, %e, %f, %g)
        return (%out, %off))IR";

  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, fused_pattern);
  fuse.runOnGraph(graph);
}

C10_UNUSED void ClipRangesGatherRangesX2SigridHashPrecompute(
    std::shared_ptr<torch::jit::Graph>& graph) {
  // fb::placeholder is a dummy op used to capture the first subgraph
  std::string pattern = R"IR(
    graph(%ranges, %values, %max_length, %salt, %max_value, %mul_shift, %hash_into_int32):
        %clipped : Tensor = fb::clip_ranges(%ranges, %max_length)
        %output : Tensor, %unused : Tensor = fb::gather_ranges(%values, %clipped)
        %sigrid_hash_out : Tensor = fb::sigrid_hash_precompute(%output, %salt, %max_value, %mul_shift, %hash_into_int32)
        return (%sigrid_hash_out, %clipped))IR";
  std::string fused_pattern = R"IR(
    graph(%ranges, %values, %max_length, %salt, %max_value, %mul_shift, %hash_into_int32):
        %sigrid_hash_out : Tensor, %clipped : Tensor = fb::placeholder(%ranges, %values, %max_length, %salt, %max_value, %mul_shift, %hash_into_int32)
        return (%sigrid_hash_out, %clipped))IR";

  // The second gather_ranges can be eliminated because the `lengths` it
  // produces is identical to the lengths produced by
  // clip_ranges_gather_sigrid_hash_v3 (caveat: the fused op makes some
  // simplifying assumptions about the ranges input).
  std::string pattern2 = R"IR(
    graph(%gather2_values, %ranges, %values, %max_length, %salt, %max_value, %mul_shift, %hash_into_int32):
        %sigrid_hash_out : Tensor, %clipped : Tensor = fb::placeholder(%ranges, %values, %max_length, %salt, %max_value, %mul_shift, %hash_into_int32)
        %unused : Tensor, %lengths : Tensor = fb::gather_ranges(%gather2_values, %clipped)
        return (%lengths, %sigrid_hash_out))IR";

  std::string fused_pattern2 = R"IR(
    graph(%gather2_values, %ranges, %values, %max_length, %salt, %max_value, %mul_shift, %hash_into_int32):
        %lengths : Tensor, %sigrid_hash_out : Tensor = fb::clip_ranges_gather_sigrid_hash_precompute_v3(%ranges, %values, %max_length, %salt, %max_value, %mul_shift, %hash_into_int32)
        return (%lengths, %sigrid_hash_out))IR";

  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, fused_pattern);
  fuse.runOnGraph(graph);

  fuse.RegisterRewritePattern(pattern2, fused_pattern2);
  fuse.runOnGraph(graph);

  // Reverse the ops that got fused in step 1 but not in step 2.
  fuse.RegisterRewritePattern(fused_pattern, pattern);
  fuse.runOnGraph(graph);
}

C10_UNUSED void SplitOutPrecomputeOpsForSparseNN(
    std::shared_ptr<torch::jit::Graph>& graph) {
#ifdef FBCODE_CAFFE2
  PrecomputeMultiplierShiftForSigridHash(graph);
  ConstantPropagation(graph);
  ConstantPooling(graph);
#endif
}
} // namespace

void FuseInferenceOpsForSparseNN(std::shared_ptr<torch::jit::Graph>& graph) {
#ifdef FBCODE_CAFFE2
  SplitOutPrecomputeOpsForSparseNN(graph);

  ConcatAddMulReplaceNaNClip(graph);
  CastedBatchOneHotLengths(graph);
  ConcatBatchMatMulBatchGather(graph);

  if (FLAGS_enable_clip_ranges_gather_fusions) {
    ClipRangesGatherRangesLengthsToOffsets(graph);
  }
  ClipRangesGatherSigridHash(graph);
  ClipRangesGatherRangesSigridHash(graph);

  ClipRangesGatherRangesX2SigridHashPrecompute(graph);

  if (FLAGS_enable_clip_ranges_gather_fusions) {
    // prioritize clip_ranges+gather_ranges+sigrid_hash fusion over
    // clip_ranges+gather_ranges
    ClipRangesGather(graph);

    ClipRangesToGatherToOffsets(graph);
  }

  ToLengthsToOffsets(graph);
#endif
}

TORCH_LIBRARY_FRAGMENT(static_runtime, m) {
  m.def(torch::schema(
      "static_runtime::permute_copy(Tensor self, int[] dims) -> Tensor",
      c10::AliasAnalysisKind::PURE_FUNCTION));
  m.def(torch::schema(
      "static_runtime::reshape_copy(Tensor self, int[] shape) -> Tensor",
      c10::AliasAnalysisKind::PURE_FUNCTION));
  m.def(torch::schema(
      "static_runtime::flatten_copy.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> Tensor",
      c10::AliasAnalysisKind::PURE_FUNCTION));
  m.def(torch::schema(
      "static_runtime::expand_dims_copy(Tensor input, int[] dims) -> Tensor",
      c10::AliasAnalysisKind::PURE_FUNCTION));
  m.def(torch::schema(
      "static_runtime::to_maybe_copy_out.prim_dtype(Tensor self, int? dtype=None, bool non_blocking=False, bool copy=False) -> (Tensor, bool)",
      c10::AliasAnalysisKind::PURE_FUNCTION));
  m.def(torch::schema(
      "static_runtime::to_maybe_copy_out.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> (Tensor, bool)",
      c10::AliasAnalysisKind::PURE_FUNCTION));
  m.def(torch::schema(
      "static_runtime::to_maybe_copy_out.other(Tensor self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> (Tensor, bool)",
      c10::AliasAnalysisKind::PURE_FUNCTION));
  m.def(torch::schema(
      "static_runtime::to_copy.prim_dtype(Tensor self, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor",
      c10::AliasAnalysisKind::PURE_FUNCTION));
  m.def(torch::schema(
      "static_runtime::to_copy.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor",
      c10::AliasAnalysisKind::PURE_FUNCTION));
  m.def(torch::schema(
      "static_runtime::to_copy.other(Tensor self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor",
      c10::AliasAnalysisKind::PURE_FUNCTION));
  m.def(torch::schema(
      "static_runtime::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> (Tensor, Tensor, Tensor)",
      c10::AliasAnalysisKind::PURE_FUNCTION));
  m.def("static_runtime::signed_log1p(Tensor input) -> Tensor");
  m.def(torch::schema(
      "static_runtime::dict_unpack(...) -> ...",
      c10::AliasAnalysisKind::CONSERVATIVE));
  m.def(torch::schema(
      "static_runtime::VarTupleUnpack(...) -> ...",
      c10::AliasAnalysisKind::CONSERVATIVE));
  m.def(torch::schema(
      "static_runtime::fused_equally_split(Tensor input, int num_split, int dim) -> ...",
      c10::AliasAnalysisKind::PURE_FUNCTION));
  m.def(torch::schema(
      "static_runtime::dequantize_copy.self(Tensor self) -> Tensor",
      c10::AliasAnalysisKind::PURE_FUNCTION));
  m.def(torch::schema(
      "static_runtime::select_tensor(Tensor(a) a, Tensor(b) b, bool use_b) -> Tensor(a|b)",
      c10::AliasAnalysisKind::FROM_SCHEMA));
  m.def(torch::schema(
      "static_runtime::create_owned_ref(...) -> ...",
      c10::AliasAnalysisKind::CONSERVATIVE));
  m.def(torch::schema(
      "static_runtime::embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor)",
      c10::AliasAnalysisKind::PURE_FUNCTION));
  m.def(torch::schema(
      "static_runtime::embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor)",
      c10::AliasAnalysisKind::PURE_FUNCTION));
  m.def(torch::schema(
      "static_runtime::clamp_nan_to_num(Tensor input, Scalar? min, Scalar? max, float? nan, float? posinf, float? neginf) -> Tensor",
      c10::AliasAnalysisKind::PURE_FUNCTION));
}

void FuseSignLog1P(std::shared_ptr<torch::jit::Graph>& graph) {
  std::string pattern = R"IR(
    graph(%input):
        %0 : Tensor = aten::sign(%input)
        %1 : Tensor = aten::abs(%input)
        %2 : Tensor = aten::log1p(%1)
        %res : Tensor = aten::mul(%0, %2)
        return (%res)
  )IR";

  std::string fused_pattern = R"IR(
    graph(%input):
        %res : Tensor = static_runtime::signed_log1p(%input)
        return (%res)
  )IR";

  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, fused_pattern);
  fuse.runOnGraph(graph);
}

namespace {

using TupleUnpackBlock = std::vector<Node*>;

std::vector<TupleUnpackBlock> CollectVariadicTupleUnpackFusionCandidates(
    const std::shared_ptr<Graph>& graph) {
  std::vector<TupleUnpackBlock> candidates;
  auto nodes = graph->nodes();
  std::vector<Node*> block;
  for (Node* cur_node : nodes) {
    if (cur_node->kind() == prim::TupleUnpack) {
      block.push_back(cur_node);
      continue;
    }
    if (block.size() > 1) {
      candidates.emplace_back(std::move(block));
    }
    block.clear();
  }
  TORCH_CHECK(block.empty());
  return candidates;
}

void FuseTupleUnpackBlock(const TupleUnpackBlock& nodes) {
  TORCH_CHECK(!nodes.empty());
  auto graph = nodes[0]->owningGraph();
  auto var_unpack = graph->create(
      fromQualString("static_runtime::VarTupleUnpack"),
      /* num_outputs */ 0);
  var_unpack->insertAfter(nodes[nodes.size() - 1]);
  for (Node* node : nodes) {
    TORCH_CHECK(
        node->kind() == prim::TupleUnpack && node->inputs().size() == 1);
    var_unpack->addInput(node->input());

    for (Value* output : node->outputs()) {
      auto new_output = var_unpack->addOutput();
      new_output->copyMetadata(output);
      output->replaceAllUsesWith(new_output);
    }
    node->destroy();
  }
}

} // namespace

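// Illustration (a sketch; value names are made up): two adjacent TupleUnpack
// nodes such as
//   %a, %b = prim::TupleUnpack(%t0)
//   %c, %d = prim::TupleUnpack(%t1)
// are collected as one candidate block above and fused into a single
// variadic op:
//   %a, %b, %c, %d = static_runtime::VarTupleUnpack(%t0, %t1)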
void UseVariadicTupleUnpack(const std::shared_ptr<Graph>& graph) {
  for (auto& c : CollectVariadicTupleUnpackFusionCandidates(graph)) {
    FuseTupleUnpackBlock(c);
  }
}

// This macro makes maps from c10::Symbol -> c10::Symbol a lot easier to read.
#define OP_PAIR(first, second) \
  { fromQualString(first), fromQualString(second) }

// Out variants of ops cannot participate in memory planning if they
// have outputs that alias inputs. For ops that either return their
// input directly or copy it (most notably aten::to), we adopt the
// following strategy instead of directly making them out variants so
// that they can participate in memory planning anyway. Let `a` denote
// the input Tensor to the op.
//
// 1) Pass `a` (and the other operator inputs) to a special
// `static_runtime::$OP_maybe_copy_out` variant of the op. This op
// returns a normal output Tensor (call it `b_out`) as well as a
// `did_copy` flag indicating whether the output should be used. If
// `did_copy` is false, the value of `b_out` is unspecified. Note that
// this operator is an ordinary out variant that is perfectly amenable
// to memory planning.
//
// 2) Pass `a`, `b_out`, and `did_copy` to a special
// `static_runtime::select_tensor` op, which returns `b_out` if
// `did_copy` is true and `a` otherwise. Note that this operator does
// not need to participate in memory planning because its output
// always aliases one of its inputs.
//
// Here is an illustration:
//
//                        |
// |----------------------+ a
// |                      v
// |    +------------------------------------+
// |    |                                    |
// |    | static_runtime::$OP_maybe_copy_out |
// |    |                                    |
// |    +------------------+--------+--------+
// |                       |        |
// +--------------+        | b_out  | did_copy
//                | a      |        |
//                v        v        v
//      +------------------------------------+
//      |                                    |
//      |    static_runtime::select_tensor   |
//      |                                    |
//      +------------------+-----------------+
//                         |
//                         |
//                         | either a or b_out
//                         |
//                         v

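// For example (a sketch using the prim_dtype overload registered above;
// value names are illustrative only), a node
//   %b : Tensor = aten::to(%a, %dtype, %non_blocking, %copy)
// would be rewritten into
//   %b_out : Tensor, %did_copy : bool =
//       static_runtime::to_maybe_copy_out(%a, %dtype, %non_blocking, %copy)
//   %b : Tensor = static_runtime::select_tensor(%a, %b_out, %did_copy)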
void ReplaceWithMaybeCopy(
    std::shared_ptr<Graph>& graph,
    bool outputs_are_immutable) {
  AliasDb db(graph);
  // for ops that have overloads, match the schema
  static const std::array<std::pair<c10::FunctionSchema, c10::Symbol>, 3> supported_schema =
      {{{torch::schema(
             "aten::to.prim_dtype(Tensor(a) self, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)"),
         fromQualString("static_runtime::to_maybe_copy_out")},
        {torch::schema(
             "aten::to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)"),
         fromQualString("static_runtime::to_maybe_copy_out")},
        {torch::schema(
             "aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)"),
         fromQualString("static_runtime::to_maybe_copy_out")}}};

  auto match_schema = [](const Node* node, c10::Symbol& out_matched_symbol) {
    for (auto& schema : supported_schema) {
      if (node->matches(schema.first)) {
        out_matched_symbol = schema.second;
        return true;
      }
    }
    return false;
  };

  // old node, new node, select_tensor node
  std::vector<std::tuple<Node*, Node*, Node*>> replacement;
  DepthFirstGraphNodeIterator graph_it(graph);
  for (auto n = graph_it.next(); n != nullptr; n = graph_it.next()) {
    c10::Symbol new_symbol;
    if (!match_schema(n, new_symbol)) {
      continue;
    }
    TORCH_CHECK(n->outputs().size() == 1);

    // Duplicate input writers guard from ReplaceWithCopy below.
    if (db.hasInputWriters(n)) {
      continue;
    }

    auto* out = n->output();
    if (!outputs_are_immutable && db.mayContainAlias(out, graph->outputs())) {
      continue;
    }

    // Add the did_copy flag to outputs.
    auto* new_node = graph->create(new_symbol, n->outputs().size() + 1);
    for (auto* input : n->inputs()) {
      new_node->addInput(input);
    }
    new_node->outputs().at(1)->setType(c10::BoolType::get());

    static const auto select_tensor_symbol =
        fromQualString("static_runtime::select_tensor");
    auto* select_tensor_node = graph->create(select_tensor_symbol, 1);
    TORCH_DCHECK_EQ(new_node->outputs().size(), 2);
    select_tensor_node->addInput(n->input(0));
    for (auto* output : new_node->outputs()) {
      select_tensor_node->addInput(output);
    }
    replacement.emplace_back(n, new_node, select_tensor_node);
  }

  for (const auto& tup : replacement) {
    auto* const old_node = std::get<0>(tup);
    auto* const new_node = std::get<1>(tup);
    auto* const select_tensor_node = std::get<2>(tup);

    new_node->insertBefore(old_node);
    select_tensor_node->insertBefore(old_node);
    new_node->outputs()[0]->copyMetadata(old_node->output());
    select_tensor_node->output()->copyMetadata(old_node->output());
    old_node->replaceAllUsesWith(select_tensor_node);
    old_node->destroy();
  }
#ifndef NDEBUG
  graph->lint();
  AliasDb db2(graph);
  torch::jit::Lint(&db2);
#endif
}

static void ReplaceWithCopyImpl(
    std::shared_ptr<Graph>& graph,
    const c10::FastMap<c10::Symbol, c10::Symbol>& supported,
    const std::vector<std::pair<c10::FunctionSchema, c10::Symbol>>&
        supported_schema,
    const std::function<bool(Node*)>& f_extra_checks,
    bool outputs_are_immutable) {
  AliasDb db(graph);

  auto match_schema = [&supported_schema](
                          const Node* node, c10::Symbol& out_matched_symbol) {
    for (auto& schema : supported_schema) {
      if (node->matches(schema.first)) {
        out_matched_symbol = schema.second;
        return true;
      }
    }
    return false;
  };

  std::vector<std::pair<Node*, Node*>> replacement;
  DepthFirstGraphNodeIterator graph_it(graph);
  for (auto n = graph_it.next(); n != nullptr; n = graph_it.next()) {
    c10::Symbol new_symbol;
    if (supported.count(n->kind()) && opIsRegistered(supported.at(n->kind()))) {
      new_symbol = supported.at(n->kind());
    } else if (!match_schema(n, new_symbol)) {
      continue;
    }
    TORCH_CHECK(n->outputs().size() == 1);

    // We do not want to replace operators with their copy variant when the
    // inputs to the operators have writers (can be updated). With an output
    // that aliases the input, updates to the input will be visible to the
    // operator's output as well. For example:
    //
    //     def forward(self, inp: Tensor, shape: List[int]):
    //         a = inp + inp
    //         b = a.reshape(shape)
    //         c = b.sigmoid_()
    //         d = c + c
    //         e = a + a
    //         f = b + b
    //         return (d, e, f)
    //
    // b and c are aliases of a, and sigmoid_ changes b, c, as well as a. e
    // should be equal to d in this case. If we replaced reshape with the copy
    // version, b and c would no longer be aliases of a, and the value of e
    // would change as a result. To keep static runtime consistent with the
    // JIT interpreter, we choose not to replace reshape with the copy version
    // here.
    if (db.hasInputWriters(n)) {
      continue;
    }

    auto* out = n->output();
    if (!outputs_are_immutable && db.mayContainAlias(out, graph->outputs())) {
      continue;
    }
    if (!f_extra_checks(n)) {
      continue;
    }
    auto* new_node = graph->create(new_symbol, n->outputs().size());
    for (auto* input : n->inputs()) {
      new_node->addInput(input);
    }
    replacement.emplace_back(n, new_node);
  }

  for (const auto& p : replacement) {
    auto* old_node = p.first;
    auto* new_node = p.second;
    new_node->insertBefore(old_node);
    new_node->output()->copyMetadata(old_node->output());
    old_node->replaceAllUsesWith(new_node);
    old_node->destroy();
  }
#ifndef NDEBUG
  graph->lint();
  AliasDb db2(graph);
  torch::jit::Lint(&db2);
#endif
}

// Replace aten::permute with its copy version only when it is followed by
// reshape/flatten. This pass is only enabled when ReplaceWithCopy is off.
void ReplacePermuteWithCopy(
    std::shared_ptr<Graph>& graph,
    bool outputs_are_immutable) {
  AliasDb db(graph);
  const c10::FastMap<c10::Symbol, c10::Symbol> supported = {
#ifdef FBCODE_CAFFE2
      OP_PAIR("aten::permute", "static_runtime::permute_copy"),
#endif
  };
  auto f_extra_checks = [](Node* n) {
    Value* out = n->output();
    if (out->uses().empty()) {
      return false;
    }
    Node* next_node = out->uses()[0].user;
    // Only allow the copy variant when the permute feeds a reshape/flatten,
    // as described in the comment above.
    if (next_node->kind() == aten::reshape ||
        next_node->kind() == aten::flatten) {
      return true;
    }
    return false;
  };
  ReplaceWithCopyImpl(
      graph, supported, {}, f_extra_checks, outputs_are_immutable);
}

void ReplaceWithCopy(
    std::shared_ptr<Graph>& graph,
    bool outputs_are_immutable) {
  AliasDb db(graph);
  const c10::FastMap<c10::Symbol, c10::Symbol> supported = {
#ifdef FBCODE_CAFFE2
      OP_PAIR("aten::permute", "static_runtime::permute_copy"),
      OP_PAIR("fb::expand_dims", "static_runtime::expand_dims_copy"),
#endif
      OP_PAIR("aten::narrow", "aten::narrow_copy"),
      OP_PAIR("aten::reshape", "static_runtime::reshape_copy"),
      OP_PAIR("aten::flatten", "static_runtime::flatten_copy")};

  static const std::vector<std::pair<c10::FunctionSchema, c10::Symbol>>
      supported_schema = {
          {{torch::schema("aten::dequantize.self(Tensor self) -> Tensor"),
            fromQualString("static_runtime::dequantize_copy")}}};

  ReplaceWithCopyImpl(
      graph,
      supported,
      supported_schema,
      [](Node* n) { return true; },
      outputs_are_immutable);
}

void EliminateTrivialEquallySplit(std::shared_ptr<torch::jit::Graph>& graph) {
  const auto equally_split = fromQualString("fb::equally_split");
  std::vector<Node*> to_remove;
  DepthFirstGraphNodeIterator graph_it(graph);
  for (auto node = graph_it.next(); node != nullptr; node = graph_it.next()) {
    if (node->kind() != equally_split) {
      continue;
    }

    const Value* value_out = node->outputs()[0];
    if (value_out->uses().size() != 1) {
      continue;
    }

    Node* list_unpack_node = value_out->uses()[0].user;
    if (list_unpack_node->kind() != prim::ListUnpack) {
      continue;
    }

    auto list_unpack_outputs = list_unpack_node->outputs();
    if (list_unpack_outputs.size() != 1) {
      continue;
    }

    list_unpack_node->output()->replaceAllUsesWith(node->input(0));
    to_remove.push_back(list_unpack_node);
    to_remove.push_back(node);
  }

  for (Node* node : to_remove) {
    node->destroy();
  }
}

namespace {

bool shouldNotFuseListUnpackSpecialCase(const Node* node) {
  const static std::array<c10::Symbol, 3> sigrid_transforms_symbols{
      c10::Symbol::fromQualString("fb::variadic_sigrid_transforms_torch_bind"),
      c10::Symbol::fromQualString("fb::sigrid_transforms_torch_bind"),
      c10::Symbol::fromQualString("fb::sigrid_transforms")};

  if (std::find(
          sigrid_transforms_symbols.begin(),
          sigrid_transforms_symbols.end(),
          node->kind()) == sigrid_transforms_symbols.end()) {
    return false;
  }

  // To fuse with sigrid transforms, we must be able to statically determine
  // `instance` and `use_offsets` - these two together let us statically
  // determine the types of the outputs. Rationale: it is a huge pain to write
  // fused sigrid transforms without static type information, and these two
  // arguments are indeed statically known in every model we've seen.
  // The reason why trying to fuse the outputs is annoying without static type
  // information is that, if one of the outputs is not managed, you need to
  // reset it to an empty tensor of the correct type each iteration. So, if we
  // can't collect types ahead of time, we would have to do it lazily on the
  // first iteration, which could be wasteful in terms of time/memory
  // - either each thread would have its own set of output types, or we would
  // need a lock to prevent data races.
  const auto num_inputs = node->inputs().size();
  return !toIValue(node->input(0)).has_value() ||
      !toIValue(node->input(num_inputs - 1)).has_value();
}

} // namespace

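// Fuses an op from the map below with the prim::ListUnpack that consumes its
// single list output. Illustration (a sketch, with fb::equally_split; value
// names are illustrative only):
//   %list : Tensor[] = fb::equally_split(%x, %num_split, %dim)
//   %s0 : Tensor, %s1 : Tensor = prim::ListUnpack(%list)
// becomes
//   %s0 : Tensor, %s1 : Tensor =
//       static_runtime::fused_equally_split(%x, %num_split, %dim)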
void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph) {
  const c10::FastMap<c10::Symbol, c10::Symbol> unfused_to_fused = {
      OP_PAIR(
          "torcharrow::inference_wrapper_run_flat",
          "static_runtime::fused_inference_wrapper_run_flat"),
      OP_PAIR(
          "torcharrow::variadic_inference_wrapper_run_flat",
          "static_runtime::fused_variadic_inference_wrapper_run_flat"),
      OP_PAIR("fb::equally_split", "static_runtime::fused_equally_split"),
      OP_PAIR(
          "fb::sigrid_transforms", "static_runtime::fused_sigrid_transforms"),
      OP_PAIR(
          "static_runtime::variadic_grouped_accessor_op_v2",
          "static_runtime::fused_variadic_grouped_accessor_op_v2"),
      OP_PAIR(
          "fb::sigrid_transforms_torch_bind",
          "static_runtime::fused_sigrid_transforms_torch_bind"),
      OP_PAIR(
          "fb::variadic_sigrid_transforms_torch_bind",
          "static_runtime::fused_variadic_sigrid_transforms_torch_bind"),
      OP_PAIR(
          "fb::gather_ranges_to_dense",
          "static_runtime::fused_gather_ranges_to_dense"),
      OP_PAIR(
          "fb::gather_ranges_to_dense_v2",
          "static_runtime::fused_gather_ranges_to_dense_v2"),
      OP_PAIR(
          "fb::split_and_squeeze",
          "static_runtime::fused_split_and_squeeze_copy")};

  // replacement contains (old_node, new_node, list_unpack_node)
  std::vector<std::tuple<Node*, Node*, Node*>> replacement;
  DepthFirstGraphNodeIterator graph_it(graph);
  for (auto node = graph_it.next(); node != nullptr; node = graph_it.next()) {
    auto unfused_to_fused_it = unfused_to_fused.find(node->kind());
    if (unfused_to_fused_it == unfused_to_fused.end()) {
      continue;
    }

    const Value* value_out = node->outputs()[0];
    if (value_out->uses().size() != 1) {
      continue;
    }

    Node* list_unpack_node = value_out->uses()[0].user;
    if (list_unpack_node->kind() != prim::ListUnpack) {
      continue;
    }

    auto list_unpack_outputs = list_unpack_node->outputs();
    if (list_unpack_outputs.empty()) {
      continue;
    }

    if (shouldNotFuseListUnpackSpecialCase(node)) {
      continue;
    }

    const auto& new_sym = unfused_to_fused_it->second;
    auto* new_node = graph->create(new_sym, 0);

    for (Value* in : node->inputs()) {
      new_node->addInput(in);
    }

    for (Value* out : list_unpack_outputs) {
      Value* new_out = new_node->addOutput();
      new_out->copyMetadata(out);
      out->replaceAllUsesWith(new_out);
    }
    replacement.emplace_back(node, new_node, list_unpack_node);
  }

  for (const auto& nodes : replacement) {
    auto* old_node = std::get<0>(nodes);
    auto* new_node = std::get<1>(nodes);
    auto* list_unpack_node = std::get<2>(nodes);

    new_node->insertAfter(old_node);
    list_unpack_node->destroy();
    old_node->destroy();
  }
}

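// Replaces aten::__getitem__ calls on an immutable input dict with a single
// static_runtime::dict_unpack per dict. Illustration (a sketch; value names
// are illustrative only):
//   %v0 = aten::__getitem__(%input_dict, %key0)
//   %v1 = aten::__getitem__(%input_dict, %key1)
// becomes
//   %v0, %v1 = static_runtime::dict_unpack(%input_dict, %key0, %key1)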
void RemoveImmutableInputDictLookups(
    std::shared_ptr<torch::jit::Graph>& graph) {
  auto nodes = graph->nodes();
  AliasDb db(graph);
  // Gather all dict -> getitems where dict is immutable and getitems use
  // constant keys.
  std::unordered_map<Value*, std::vector<Node*>> dict_to_getitems;
  std::unordered_set<Node*> keys;
  for (Node* node : nodes) {
    // Find aten::__getitem__(%dict, %constant_key).
    if (node->kind() != aten::__getitem__) {
      continue;
    }
    Node* getitem_node = node;
    Value* dict = getitem_node->input(0);
    if (db.hasWriters(dict)) {
      // Mutable dict. Skip this optimization.
      continue;
    }
    if (dict->type()->kind() != TypeKind::DictType ||
        dict->node() != graph->param_node()) {
      continue;
    }
    DCHECK(getitem_node->inputs().size() == 2);
    Node* key = getitem_node->input(1)->node();
    if (key->kind() != prim::Constant) {
      continue;
    }
    keys.insert(key);
    auto iter = dict_to_getitems.find(dict);
    if (iter == dict_to_getitems.end()) {
      dict_to_getitems.emplace(dict, std::vector<Node*>{getitem_node});
      continue;
    }
    iter->second.push_back(getitem_node);
  }
  if (keys.empty()) {
    return;
  }
  // Move all keys to the beginning of the graph and insert new dict_unpack
  // nodes after that.
  auto* marker = graph->create(prim::Constant);
  graph->prependNode(marker);
  graph->setInsertPoint(marker);
  for (Node* key : keys) {
    DCHECK(key->inputs().empty());
    key->moveBefore(marker);
  }
  const c10::Symbol static_runtime_dict_unpack_symbol =
      fromQualString("static_runtime::dict_unpack");
  for (auto& it : dict_to_getitems) {
    Value* dict = it.first;
    std::vector<Node*>& getitems = it.second;
    DCHECK(!getitems.empty());
    auto* dict_unpack =
        graph->create(static_runtime_dict_unpack_symbol, getitems.size());
    graph->insertNode(dict_unpack);
    dict_unpack->addInput(getitems[0]->input(0));
    for (size_t i = 0; i < getitems.size(); ++i) {
      Node* getitem_node = getitems[i];
      DCHECK(getitem_node->input(0) == dict);
      dict_unpack->addInput(getitem_node->input(1));
      dict_unpack->output(i)->copyMetadata(getitem_node->output());
      getitem_node->output(0)->replaceAllUsesWith(dict_unpack->output(i));
      getitem_node->destroy();
    }
  }
  graph->setInsertPoint(graph->block());
  marker->destroy();
}

void UseVariadicGroupedAccessor(const std::shared_ptr<Graph>& graph) {
  UseVariadicOp(
      graph,
      fromQualString("grouped_accessor::grouped_accessor_op_v2"),
      fromQualString("static_runtime::variadic_grouped_accessor_op_v2"));
  UseVariadicOp(
      graph,
      fromQualString("fb::grouped_accessor_op_async"),
      fromQualString("static_runtime::variadic_grouped_accessor_op_async"));
}

namespace {

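// Illustration (a sketch; not taken from a real model): when a (sub-)block
// returns one of its own inputs, a constant, or a value from an outer scope,
// e.g.
//   block0(%x):
//     -> (%x)
// the output is routed through an owned reference,
//   block0(%x):
//     %x_ref = static_runtime::create_owned_ref(%x)
//     -> (%x_ref)
// so the returned value outlives the cleanup of the block's inputs.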
void CreateOwnedRefsForSpecialValuesHelper(Graph& graph, Block* block) {
  for (auto* node : block->nodes()) {
    for (auto* sub_block : node->blocks()) {
      CreateOwnedRefsForSpecialValuesHelper(graph, sub_block);
    }
  }

  auto outputs = block->outputs();
  // Create owned refs for outputs that are also block inputs. Otherwise, the
  // input cleanup process will destroy our outputs before we return.
  c10::FastSet<Value*> inputs = {
      block->inputs().begin(), block->inputs().end()};

  for (const auto i : c10::irange(outputs.size())) {
    auto* output = outputs[i];

    if (output->type()->kind() == c10::TypeKind::NoneType) {
      // No need to create owned refs of NoneType since moving
      // from None will have no effect
      continue;
    }

    if ((inputs.find(output) != inputs.end()) || toIValue(output).has_value() ||
        // If the output's owning block is not this one, it's from an outer
        // scope
        output->node()->owningBlock() != block) {
      auto* create_owned_ref_node =
          graph.create(fromQualString("static_runtime::create_owned_ref"));
      create_owned_ref_node->addInput(output);
      create_owned_ref_node->output()->copyMetadata(output);

      block->appendNode(create_owned_ref_node);
      block->replaceOutput(i, create_owned_ref_node->output());
    }
  }
}

void ForceNonEmptyOutputsHelper(Value* none_value, Block* block) {
  for (auto* node : block->nodes()) {
    bool needs_output = false;
    for (auto* sub_block : node->blocks()) {
      if (sub_block->outputs().empty()) {
        sub_block->registerOutput(none_value);
        needs_output = true;
      }

      ForceNonEmptyOutputsHelper(none_value, sub_block);
    }

    if (needs_output) {
      // Loop sub-blocks should always return at least one output (the new
      // loop condition), so only prim::If nodes can reach this point.
      DCHECK(node->kind() == prim::If);
      auto* output = node->addOutput();
      output->setType(c10::NoneType::get());
    }
  }
}

Node* findOrCreateNoneConstant(Graph& graph) {
  // Only search the top-level block
  for (auto* node : graph.nodes()) {
    if (node->kind() != prim::Constant) {
      continue;
    }
    const auto ival_opt = toIValue(node->output());
    DCHECK(ival_opt.has_value());
    if (ival_opt->isNone()) {
      return node;
    }
  }

  auto* none_node = graph.create(prim::Constant);
  none_node->output()->setType(c10::NoneType::get());
  graph.prependNode(none_node);
  return none_node;
}

} // namespace

void CreateOwnedRefsForSpecialValues(Graph& graph) {
  CreateOwnedRefsForSpecialValuesHelper(graph, graph.block());
}

void ForceNonEmptyOutputs(Graph& graph) {
  auto* none_node = findOrCreateNoneConstant(graph);
  ForceNonEmptyOutputsHelper(none_node->output(), graph.block());
  if (!none_node->hasUses()) {
    none_node->destroy();
  }
}

namespace {

bool inputIsConstantList(
    Node* node,
    size_t input_idx,
    const c10::List<int64_t>& expected) {
  auto input_opt = toIValue(node->input(input_idx));
  if (!input_opt.has_value() || !input_opt->isIntList()) {
    return false;
  }
  return input_opt->toIntList() == expected;
}

bool inputIsConstantInt(Node* node, size_t input_idx, int64_t expected) {
  auto input_opt = toIValue(node->input(input_idx));
  if (!input_opt.has_value() || !input_opt->isInt()) {
    return false;
  }
  return input_opt->toInt() == expected;
}

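// Rationale for the rewrite below: for a 3-D tensor, permuting with dims
// (0, 2, 1) moves dimension 1 into the last position, so (with keepdim=False)
//   x.permute(0, 2, 1).sum(dim=-1) == x.sum(dim=1)
// and the permute can be dropped.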
void eliminatePermuteOpsSumPattern(std::shared_ptr<Graph>& graph) {
  // SubgraphRewriter can't pattern-match on constants, so we use this
  // extra filter to make sure the values of the `dim` arguments are
  // correct.
  auto dims_are_valid_constants =
      [](const Match& match,
         const std::unordered_map<std::string, Value*>& vmap) {
        // Get the nodes in the real graph from the nodes in the template
        // pattern graph
        const auto& node_map = match.nodes_map;
        auto* sum_node = node_map.at(vmap.at("c")->node());
        auto* permute_node = node_map.at(vmap.at("b")->node());
        return inputIsConstantList(sum_node, 1, c10::List<int64_t>{-1}) &&
            inputIsConstantList(permute_node, 1, c10::List<int64_t>{0, 2, 1});
      };

  const auto pattern = R"IR(
    graph(%a, %sum_dim, %permute_dim, %keepdim, %dtype):
        %b = aten::permute(%a, %permute_dim)
        %c = aten::sum(%b, %sum_dim, %keepdim, %dtype)
        return (%c))IR";

  const auto fused_pattern = R"IR(
    graph(%a, %sum_dim, %permute_dim, %keepdim, %dtype):
        %new_sum_dim: int[] = prim::Constant[value=[1]]()
        %d = aten::sum(%a, %new_sum_dim, %keepdim, %dtype)
        return (%d))IR";

  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, fused_pattern);
  fuse.runOnGraph(graph, dims_are_valid_constants);
}

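// Similarly, softmax over the last dimension of a (0, 2, 1)-permuted 3-D
// tensor, permuted back afterwards, is just softmax over dimension 1:
//   x.permute(0, 2, 1).softmax(dim=2).permute(0, 2, 1) == x.softmax(dim=1)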
void eliminatePermuteOpsSoftmaxPattern(std::shared_ptr<Graph>& graph) {
  const auto pattern = R"IR(
    graph(%a, %permute_dim_1, %permute_dim_2, %softmax_dim, %softmax_dtype):
        %b = aten::permute(%a, %permute_dim_1)
        %c = aten::softmax(%b, %softmax_dim, %softmax_dtype)
        %d = aten::permute(%c, %permute_dim_2)
        return (%d)
  )IR";

  const auto fused_pattern = R"IR(
    graph(%a, %permute_dim_1, %permute_dim_2, %softmax_dim, %softmax_dtype):
        %new_softmax_dim: int = prim::Constant[value=1]()
        %e = aten::softmax(%a, %new_softmax_dim, %softmax_dtype)
        return (%e)
  )IR";

  // Check that permute_dim is (0, 2, 1) and softmax_dim is 2
  auto dims_are_valid_constants =
      [](const Match& match,
         const std::unordered_map<std::string, Value*>& vmap) {
        const auto& node_map = match.nodes_map;
        auto* permute_node_1 = node_map.at(vmap.at("b")->node());
        auto* permute_node_2 = node_map.at(vmap.at("d")->node());
        auto* softmax_node = node_map.at(vmap.at("c")->node());
        return inputIsConstantInt(softmax_node, 1, 2) &&
            inputIsConstantList(
                permute_node_1, 1, c10::List<int64_t>{0, 2, 1}) &&
            inputIsConstantList(permute_node_2, 1, c10::List<int64_t>{0, 2, 1});
      };

  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, fused_pattern);
  fuse.runOnGraph(graph, dims_are_valid_constants);
}

} // namespace

void EliminateExtraPermuteOps(std::shared_ptr<Graph>& graph) {
  eliminatePermuteOpsSumPattern(graph);
  eliminatePermuteOpsSoftmaxPattern(graph);
}

namespace {

Node* maybeUserWithKind(Value* value, c10::Symbol kind) {
  auto& uses = value->uses();
  if (uses.size() != 1) {
    return nullptr;
  }
  auto* user = uses[0].user;
  if (user->kind() != kind) {
    return nullptr;
  }
  return user;
}

} // namespace

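// Replaces aten::split + prim::ListUnpack + per-chunk aten::squeeze (all along
// the same axis) with a single fused op. Illustration (a sketch; value names
// are illustrative only):
//   %chunks : Tensor[] = aten::split(%x, %split_size, %dim)
//   %c0, %c1 = prim::ListUnpack(%chunks)
//   %s0 = aten::squeeze(%c0, %dim)
//   %s1 = aten::squeeze(%c1, %dim)
// becomes
//   %s0, %s1 =
//       static_runtime::fused_split_and_squeeze_copy(%x, %split_size, %dim)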
void UseSplitAndSqueeze(std::shared_ptr<Graph>& graph) {
  std::vector<Node*> to_erase;
  for (auto* node : graph->nodes()) {
    if (node->kind() != aten::split) {
      continue;
    }
    auto axis_opt = toIValue(node->input(2));
    if (!axis_opt) {
      continue;
    }
    auto axis = *axis_opt;
    auto* split_node_output = node->output();
    auto* list_unpack_node =
        maybeUserWithKind(split_node_output, prim::ListUnpack);
    if (list_unpack_node == nullptr) {
      continue;
    }
    std::vector<Node*> squeeze_nodes;
    squeeze_nodes.reserve(list_unpack_node->outputs().size());
    for (auto* output : list_unpack_node->outputs()) {
      auto* squeeze_node = maybeUserWithKind(output, aten::squeeze);
      if (squeeze_node == nullptr) {
        break;
      }
      auto dim_opt = toIValue(squeeze_node->input(1));
      if (!dim_opt || *dim_opt != axis) {
        break;
      }
      squeeze_nodes.push_back(squeeze_node);
    }
    auto num_outputs = list_unpack_node->outputs().size();
    if (squeeze_nodes.size() != num_outputs) {
      continue;
    }
    auto* split_and_squeeze_node = graph->create(
        c10::Symbol::fromQualString(
            "static_runtime::fused_split_and_squeeze_copy"),
        num_outputs);
    split_and_squeeze_node->addInput(node->input(0));
    split_and_squeeze_node->addInput(node->input(1));
    split_and_squeeze_node->addInput(node->input(2));
    split_and_squeeze_node->insertBefore(node);
    for (const auto i : c10::irange(num_outputs)) {
      auto* squeeze_node = squeeze_nodes[i];
      split_and_squeeze_node->output(i)->copyMetadata(squeeze_node->output());
      squeeze_node->output()->replaceAllUsesWith(
          split_and_squeeze_node->output(i));
    }
    to_erase.insert(to_erase.end(), squeeze_nodes.begin(), squeeze_nodes.end());
    to_erase.push_back(list_unpack_node);
    to_erase.push_back(node);
  }
  for (auto* node : to_erase) {
    node->destroy();
  }
}

C10_UNUSED void RemoveUnnecessaryOutputs(
    std::shared_ptr<torch::jit::Graph>& graph) {
  RemoveUnnecessaryEmbeddingBagOutputs(graph);
}

C10_UNUSED void RemoveUnnecessaryEmbeddingBagOutputs(
    std::shared_ptr<torch::jit::Graph>& graph) {
  std::string pattern = R"IR(
    graph(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset):
        %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
        return (%y2, %y1, %y0))IR";
  std::string transformed_pattern = R"IR(
    graph(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset):
        %y0 : Tensor, %y1 : Tensor, %y2 : Tensor = static_runtime::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
        return (%y2, %y1, %y0))IR";
  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, transformed_pattern);
  fuse.runOnGraph(graph);

  std::string pattern2 = R"IR(
    graph(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset, %padding_idx):
        %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset, %padding_idx)
        return (%y2, %y1, %y0))IR";
  std::string transformed_pattern2 = R"IR(
    graph(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset, %padding_idx):
        %y0 : Tensor, %y1 : Tensor, %y2 : Tensor = static_runtime::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset, %padding_idx)
        return (%y2, %y1, %y0))IR";
  fuse.RegisterRewritePattern(pattern2, transformed_pattern2);
  fuse.runOnGraph(graph);
}

namespace {
bool isNoOpSlice(Node* node) {
  DCHECK(node->kind() == aten::slice);
  auto step = toIValue(node->input(3));
  if (!step.has_value() || step->toInt() != 1) {
    return false;
  }
  auto start = toIValue(node->input(1));
  if (!start.has_value() || (start->isInt() && start->toInt() != 0)) {
    return false;
  }
  auto end = toIValue(node->input(2));
  // Could also look at list length, but most models that have this pattern are
  // just doing list[0:], so it's not needed for now.
  return end.has_value() && end->isNone();
}
} // namespace

void EliminateNoOpSlice(std::shared_ptr<Graph>& graph) {
  DepthFirstGraphNodeIterator it(graph);
  auto schema = torch::schema(
      "aten::slice.t(t[] l, int? start=None, int? end=None, int step=1) -> t[]",
      /*allow_typevars*/ true);
  Node* node = nullptr;
  std::vector<Node*> to_delete;
  while ((node = it.next()) != nullptr) {
    if (!node->matches(schema) || !isNoOpSlice(node)) {
      continue;
    }

    node->output()->replaceAllUsesWith(node->input(0));
    to_delete.push_back(node);
  }
  for (auto* node : to_delete) {
    node->destroy();
  }
}

void UseInPlaceGetRealInputsFromOptionalInputsV2(
    std::shared_ptr<Graph>& graph) {
#ifdef FBCODE_CAFFE2
  const std::string original_pattern = R"IR(
    graph(%optional_input: (Tensor, Tensor?, Tensor?)?[], %include_last_offsets: bool[]):
        %x : (Tensor, Tensor?, Tensor?)[] = remote_collection::get_real_inputs_from_optional_inputs_v2(%optional_input, %include_last_offsets)
        return (%x))IR";

  const std::string new_pattern = R"IR(
    graph(%optional_input: (Tensor, Tensor?, Tensor?)?[], %include_last_offsets: bool[]):
        %x : (Tensor, Tensor?, Tensor?)[] = static_runtime::get_real_inputs_from_optional_inputs_v2_inplace(%optional_input, %include_last_offsets)
        return (%x))IR";

  auto isSingleUse = [](Value* value) { return value->uses().size() == 1; };

  auto filter = [&isSingleUse](
                    const Match& match,
                    const std::unordered_map<std::string, Value*>& vmap) {
    auto* real_node = match.nodes_map.at(vmap.at("x")->node());
    return isSingleUse(real_node->input(0));
  };

  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(original_pattern, new_pattern);
  fuse.runOnGraph(graph, filter);
#endif
}

void FuseClampNaNToNum(std::shared_ptr<Graph>& graph) {
#ifdef FBCODE_CAFFE2
  std::string pattern = R"IR(
    graph(%input, %clamp_min: Scalar?, %clamp_max: Scalar?, %nan, %posinf, %neginf):
        %x : Tensor = aten::clamp(%input, %clamp_min, %clamp_max)
        %y : Tensor = aten::nan_to_num(%x, %nan, %posinf, %neginf)
        return (%y))IR";

  std::string fused_pattern = R"IR(
    graph(%input, %clamp_min: Scalar?, %clamp_max: Scalar?, %nan, %posinf, %neginf):
        %x : Tensor = static_runtime::clamp_nan_to_num(%input, %clamp_min, %clamp_max, %nan, %posinf, %neginf)
        return (%x))IR";

  auto isConstantAndNotNone = [](Value* value) {
    auto ival_opt = toIValue(value);
    if (!ival_opt.has_value()) {
      return false;
    }
    auto scalar_opt = ival_opt->toOptional<at::Scalar>();
    return scalar_opt.has_value();
  };

  auto clampValuesAreConstant =
      [&isConstantAndNotNone](
          const Match& match,
          const std::unordered_map<std::string, Value*>& vmap) {
        // Get the nodes in the real graph from the nodes in the template
        // pattern graph
        const auto& node_map = match.nodes_map;
        auto* clamp_node = node_map.at(vmap.at("x")->node());
        return isConstantAndNotNone(clamp_node->input(1)) &&
            isConstantAndNotNone(clamp_node->input(2));
      };

  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, fused_pattern);
  fuse.runOnGraph(graph, clampValuesAreConstant);
#endif
}

void PrepackWeights(std::shared_ptr<Graph>& graph) {
  const auto pattern = R"IR(
    graph(%input: Tensor, %weight: Tensor, %bias: Tensor?, %scale: Tensor, %zero_point: Tensor):
        %result: Tensor = fb::quantized_linear_unpacked_weight_v2(%input, %weight, %bias, %scale, %zero_point)
        return (%result)
  )IR";

  const auto split_pattern = R"IR(
    graph(%input: Tensor, %weight: Tensor, %bias: Tensor?, %scale: Tensor, %zero_point: Tensor):
        %packed_params = quantized::linear_prepack(%weight, %bias)
        %scale_float: float = aten::item(%scale)
        %zero_point_int: int = aten::item(%zero_point)
        %result: Tensor = quantized::linear(%input, %packed_params, %scale_float, %zero_point_int)
        return (%result)
  )IR";

  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, split_pattern);
  fuse.runOnGraph(graph);
  // Constant propagation should be called after this pass + others.
}

} // namespace torch::jit