#include <torch/csrc/jit/runtime/static/passes.h>

#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/passes/variadic_ops.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <torch/csrc/jit/runtime/static/ops.h>

C10_DEFINE_bool(
    enable_clip_ranges_gather_fusions,
    true,
    "If on, static runtime or optimize_sparse_nn_model will fuse clip ranges gather ops.");

namespace torch::jit {

bool graphHasOp(std::shared_ptr<Graph>& graph, const char* op_name) {
  DepthFirstGraphNodeIterator graph_it(graph);
  for (auto node = graph_it.next(); node != nullptr; node = graph_it.next()) {
    const char* node_qual_string = node->kind().toQualString();
    if (strcmp(node_qual_string, op_name) == 0) {
      return true;
    }
  }
  return false;
}

bool forwardHasOp(
    const torch::jit::script::Module& module,
    const char* op_name) {
  using Method = ::torch::jit::Method;
  Method method = module.get_method("forward");
  auto graph = method.graph();
  return graphHasOp(graph, op_name);
}

namespace {
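// Note: the four patterns below differ only in whether nan_to_num and clamp
// appear as in-place (aten::nan_to_num_ / aten::clamp_) or out-of-place
// variants; all four are rewritten to the same fused
// fb::concat_add_mul_replacenan_clip op.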
C10_UNUSED
void ConcatAddMulReplaceNaNClip(std::shared_ptr<torch::jit::Graph>& graph) {
  // TODO: check restrictions for inputs; outputs not used elsewhere
  std::string pattern = R"IR(
    graph(%a, %b, %c, %d, %e, %f, %g, %h, %i, %j):
        %y0 = aten::cat(%a, %b)
        %y1 = aten::add(%y0, %c, %d)
        %y2 = aten::mul(%y1, %e)
        %y3 = aten::nan_to_num(%y2, %f, %g, %h)
        %res = aten::clamp(%y3, %i, %j)
        return (%res))IR";
  std::string pattern2 = R"IR(
    graph(%a, %b, %c, %d, %e, %f, %g, %h, %i, %j):
        %y0 = aten::cat(%a, %b)
        %y1 = aten::add(%y0, %c, %d)
        %y2 = aten::mul(%y1, %e)
        %y3 = aten::nan_to_num_(%y2, %f, %g, %h)
        %res = aten::clamp(%y3, %i, %j)
        return (%res))IR";
  std::string pattern3 = R"IR(
    graph(%a, %b, %c, %d, %e, %f, %g, %h, %i, %j):
        %y0 = aten::cat(%a, %b)
        %y1 = aten::add(%y0, %c, %d)
        %y2 = aten::mul(%y1, %e)
        %y3 = aten::nan_to_num_(%y2, %f, %g, %h)
        %res = aten::clamp_(%y3, %i, %j)
        return (%res))IR";
  std::string pattern4 = R"IR(
    graph(%a, %b, %c, %d, %e, %f, %g, %h, %i, %j):
        %y0 = aten::cat(%a, %b)
        %y1 = aten::add(%y0, %c, %d)
        %y2 = aten::mul(%y1, %e)
        %y3 = aten::nan_to_num(%y2, %f, %g, %h)
        %res = aten::clamp_(%y3, %i, %j)
        return (%res))IR";
  std::string fused_pattern = R"IR(
    graph(%a, %b, %c, %d, %e, %f, %g, %h, %i, %j):
        %res = fb::concat_add_mul_replacenan_clip(%c, %e, %a, %i, %j, %b)
        return (%res))IR";

  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, fused_pattern);
  fuse.runOnGraph(graph);

  fuse.RegisterRewritePattern(pattern2, fused_pattern);
  fuse.runOnGraph(graph);

  fuse.RegisterRewritePattern(pattern3, fused_pattern);
  fuse.runOnGraph(graph);

  fuse.RegisterRewritePattern(pattern4, fused_pattern);
  fuse.runOnGraph(graph);
}

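// The two patterns below cover what are likely the two aten::to overloads
// seen in practice (one with and one without the trailing memory_format
// argument); in both, the casts before and after fb::batch_one_hot_lengths
// are folded into fb::casted_batch_one_hot_lengths.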
C10_UNUSED
void CastedBatchOneHotLengths(std::shared_ptr<torch::jit::Graph>& graph) {
  // TODO: check restrictions for inputs; outputs not used elsewhere
  std::string pattern = R"IR(
    graph(%a, %b, %c, %d, %e, %f, %g):
        %y0 : Tensor = aten::to(%a, %b, %c, %c, %d)
        %y1 : Tensor = fb::batch_one_hot_lengths(%y0, %e, %f)
        %res : Tensor = aten::to(%y1, %g, %c, %c, %d)
        return (%res))IR";
  std::string fused_pattern = R"IR(
    graph(%a, %b, %c, %d, %e, %f, %g):
        %res : Tensor = fb::casted_batch_one_hot_lengths(%a, %e, %f)
        return (%res))IR";
  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, fused_pattern);
  fuse.runOnGraph(graph);

  std::string pattern2 = R"IR(
    graph(%a, %b, %c, %d, %e, %f):
        %y0 : Tensor = aten::to(%a, %b, %c, %c)
        %y1 : Tensor = fb::batch_one_hot_lengths(%y0, %d, %e)
        %res : Tensor = aten::to(%y1, %f, %c, %c)
        return (%res))IR";
  std::string fused_pattern2 = R"IR(
    graph(%a, %b, %c, %d, %e, %f):
        %res : Tensor = fb::casted_batch_one_hot_lengths(%a, %d, %e)
        return (%res))IR";
  fuse.RegisterRewritePattern(pattern2, fused_pattern2);
  fuse.runOnGraph(graph);
}

C10_UNUSED
void ConcatBatchMatMulBatchGather(std::shared_ptr<torch::jit::Graph>& graph) {
  std::string pattern = R"IR(
    graph(%a, %b, %c, %d, %e, %f):
        %y0 : Tensor = aten::stack(%a, %b)
        %y1 : Tensor = aten::transpose(%y0, %b, %c)
        %y2 : Tensor = aten::bmm(%y0, %y1)
        %y3 : Tensor = aten::flatten(%y2, %d, %e)
        %res : Tensor = aten::index_select(%y3, %b, %f)
        return (%res))IR";
  std::string fused_pattern = R"IR(
    graph(%a, %b, %c, %d, %e, %f):
        %res : Tensor = fb::concat_batch_matmul_batch_gather(%f, %a)
        return (%res))IR";
  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, fused_pattern);

  // this pattern, found in several models, has a redundant second `flatten`
  std::string pattern_broadcast = R"IR(
    graph(%a, %b, %c, %d, %e, %indices):
        %y0 : Tensor = fb::broadcast_stack(%a, %b)
        %y1 : Tensor = aten::transpose(%y0, %b, %c)
        %y2 : Tensor = aten::matmul(%y0, %y1)
        %y3 : Tensor = aten::flatten(%y2, %b, %e)
        %y4 : Tensor = aten::flatten(%y3, %d, %d)
        %res : Tensor = aten::index_select(%y4, %b, %indices)
        return (%res))IR";
  std::string fused_pattern_broadcast = R"IR(
    graph(%a, %b, %c, %d, %e, %indices):
        %res : Tensor = fb::broadcast_concat_batch_matmul_batch_gather(%indices, %a)
        return (%res))IR";
  fuse.RegisterRewritePattern(pattern_broadcast, fused_pattern_broadcast);

  std::string pattern_broadcast2 = R"IR(
    graph(%a, %b, %c, %d, %indices):
        %y0 : Tensor = fb::broadcast_stack(%a, %b)
        %y1 : Tensor = aten::transpose(%y0, %b, %c)
        %y2 : Tensor = aten::matmul(%y0, %y1)
        %y3 : Tensor = aten::flatten(%y2, %b, %d)
        %res : Tensor = aten::index_select(%y3, %b, %indices)
        return (%res))IR";
  std::string fused_pattern_broadcast2 = R"IR(
    graph(%a, %b, %c, %d, %indices):
        %res : Tensor = fb::broadcast_concat_batch_matmul_batch_gather(%indices, %a)
        return (%res))IR";
  fuse.RegisterRewritePattern(pattern_broadcast2, fused_pattern_broadcast2);
  fuse.runOnGraph(graph);
}

C10_UNUSED void ClipRangesGatherRangesLengthsToOffsets(
    std::shared_ptr<torch::jit::Graph>& graph) {
  // TODO: check restrictions for inputs; outputs not used elsewhere
  std::string pattern = R"IR(
    graph(%a, %b, %c, %d):
        %y0 : Tensor = fb::clip_ranges(%b, %c)
        %y1 : Tensor, %y2 : Tensor = fb::gather_ranges(%a, %y0)
        %y3 : Tensor = fb::lengths_to_offsets(%y2, %d)
        return (%y3, %y1))IR";
  std::string fused_pattern = R"IR(
    graph(%a, %b, %c, %d):
        %y0 : Tensor, %y1 : Tensor = fb::clip_ranges_gather_lengths_to_offsets(%a, %b, %c, %d)
        return (%y1, %y0))IR";
  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, fused_pattern);
  fuse.runOnGraph(graph);
}

C10_UNUSED void ClipRangesGather(std::shared_ptr<torch::jit::Graph>& graph) {
  // TODO: check restrictions for inputs; outputs not used elsewhere
  // fuse without lengths-to-offsets
  std::string pattern = R"IR(
    graph(%a, %b, %c):
        %y0 : Tensor = fb::clip_ranges(%b, %c)
        %y1 : Tensor, %y2 : Tensor = fb::gather_ranges(%a, %y0)
        return (%y2, %y1))IR";
  std::string fused_pattern = R"IR(
    graph(%a, %b, %c):
        %y0 : Tensor, %y1 : Tensor = fb::clip_ranges_gather(%a, %b, %c)
        return (%y1, %y0))IR";
  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, fused_pattern);
  fuse.runOnGraph(graph);
}

C10_UNUSED void PrecomputeMultiplierShiftForSigridHash(
    std::shared_ptr<torch::jit::Graph>& graph) {
  std::string pattern = R"IR(
    graph(%a, %b, %c, %d, %e):
        %y0 : Tensor = fb::sigrid_hash(%a, %b, %c, %d, %e)
        return (%y0)
  )IR";
  std::string split_pattern = R"IR(
    graph(%a, %b, %c, %d, %e):
        %y0 : Tensor = fb::sigrid_hash_compute_multipler_shift(%c)
        %y2 : Tensor = fb::sigrid_hash_precompute(%a, %b, %c, %y0, %d, %e)
        return (%y2)
  )IR";
  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, split_pattern);
  fuse.runOnGraph(graph);
}

C10_UNUSED void ClipRangesToGatherToOffsets(
    std::shared_ptr<torch::jit::Graph>& graph) {
  std::string pattern = R"IR(
    graph(%a, %b, %c, %d, %to0_in0, %to0_in1, %to0_in2):
        %y0 : Tensor, %y1 : Tensor = fb::clip_ranges_gather(%a, %b, %c)
        %y2 : Tensor = aten::to(%y1, %to0_in0, %to0_in1, %to0_in1, %to0_in2)
        %y3 : Tensor = fb::lengths_to_offsets(%y2, %d)
        return (%y3, %y0))IR";
  std::string fused_pattern = R"IR(
    graph(%a, %b, %c, %d, %to0_in0, %to0_in1, %to0_in2):
        %y0 : Tensor, %y1 : Tensor = fb::clip_ranges_gather_to_offsets(%a, %b, %c, %d, %to0_in0)
        return (%y1, %y0))IR";
  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, fused_pattern);
  fuse.runOnGraph(graph);

  std::string pattern2 = R"IR(
    graph(%a, %b, %c, %d, %to0_in0, %to0_in1):
        %y0 : Tensor, %y1 : Tensor = fb::clip_ranges_gather(%a, %b, %c)
        %y2 : Tensor = aten::to(%y1, %to0_in0, %to0_in1, %to0_in1)
        %y3 : Tensor = fb::lengths_to_offsets(%y2, %d)
        return (%y3, %y0))IR";
  std::string fused_pattern2 = R"IR(
    graph(%a, %b, %c, %d, %to0_in0, %to0_in1):
        %y0 : Tensor, %y1 : Tensor = fb::clip_ranges_gather_to_offsets(%a, %b, %c, %d, %to0_in0)
        return (%y1, %y0))IR";
  fuse.RegisterRewritePattern(pattern2, fused_pattern2);
  fuse.runOnGraph(graph);
}

C10_UNUSED void ToLengthsToOffsets(std::shared_ptr<torch::jit::Graph>& graph) {
  std::string pattern = R"IR(
    graph(%a, %includelastoffset, %dtype, %nonblocking, %copy, %memoryformat):
        %y0 : Tensor = aten::to(%a, %dtype, %nonblocking, %copy, %memoryformat)
        %y1 : Tensor = fb::lengths_to_offsets(%y0, %includelastoffset)
        return (%y1))IR";
  std::string fused_pattern = R"IR(
    graph(%a, %includelastoffset, %dtype, %nonblocking, %copy, %memoryformat):
        %y0 : Tensor = fb::to_lengths_to_offsets(%a, %includelastoffset, %dtype)
        return (%y0))IR";
  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, fused_pattern);
  fuse.runOnGraph(graph);

  std::string pattern2 = R"IR(
    graph(%a, %includelastoffset, %dtype, %nonblocking, %copy):
        %y0 : Tensor = aten::to(%a, %dtype, %nonblocking, %copy)
        %y1 : Tensor = fb::lengths_to_offsets(%y0, %includelastoffset)
        return (%y1))IR";
  std::string fused_pattern2 = R"IR(
    graph(%a, %includelastoffset, %dtype, %nonblocking, %copy):
        %y0 : Tensor = fb::to_lengths_to_offsets(%a, %includelastoffset, %dtype)
        return (%y0))IR";
  fuse.RegisterRewritePattern(pattern2, fused_pattern2);
  fuse.runOnGraph(graph);
}

C10_UNUSED
void ClipRangesGatherSigridHash(std::shared_ptr<torch::jit::Graph>& graph) {
  // TODO: check restrictions for inputs; outputs not used elsewhere
  std::string pattern = R"IR(
    graph(%a, %b, %c, %d, %e, %f, %g, %h):
        %y0 : Tensor, %y1 : Tensor = fb::clip_ranges_gather_lengths_to_offsets(%a, %b, %c, %d)
        %y2 : Tensor = fb::sigrid_hash_precompute(%y0, %e, %f, %g, %h)
        return (%y2, %y1))IR";
  std::string fused_pattern = R"IR(
    graph(%a, %b, %c, %d, %e, %f, %g, %h):
        %off : Tensor, %out : Tensor = fb::clip_ranges_gather_sigrid_hash_precompute_offsets(%b, %a, %c, %e, %f, %g, %h, %d)
        return (%out, %off))IR";
  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, fused_pattern);
  fuse.runOnGraph(graph);
}

C10_UNUSED void ClipRangesGatherRangesSigridHash(
    std::shared_ptr<torch::jit::Graph>& graph) {
  std::string pattern = R"IR(
    graph(%a, %b, %c, %d, %e, %f, %g):
        %y0 : Tensor = fb::clip_ranges(%b, %c)
        %y1 : Tensor, %y2 : Tensor = fb::gather_ranges(%a, %y0)
        %y3 : Tensor = fb::sigrid_hash_precompute(%y1, %d, %e, %f, %g)
        return (%y3, %y2))IR";
  std::string fused_pattern = R"IR(
    graph(%a, %b, %c, %d, %e, %f, %g):
        %off : Tensor, %out : Tensor = fb::clip_ranges_gather_sigrid_hash_precompute_v3(%b, %a, %c, %d, %e, %f, %g)
        return (%out, %off))IR";

  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, fused_pattern);
  fuse.runOnGraph(graph);
}

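// This pass runs in three steps: (1) fuse clip_ranges + gather_ranges +
// sigrid_hash_precompute into a temporary fb::placeholder op, (2) fuse the
// placeholder with a following gather_ranges into
// clip_ranges_gather_sigrid_hash_precompute_v3, and (3) rewrite any
// placeholders that did not find a matching second gather_ranges back into
// the original ops.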
C10_UNUSED void ClipRangesGatherRangesX2SigridHashPrecompute(
    std::shared_ptr<torch::jit::Graph>& graph) {
  // Placeholder is a dummy op used to capture the first subgraph
  std::string pattern = R"IR(
    graph(%ranges, %values, %max_length, %salt, %max_value, %mul_shift, %hash_into_int32):
        %clipped : Tensor = fb::clip_ranges(%ranges, %max_length)
        %output : Tensor, %unused : Tensor = fb::gather_ranges(%values, %clipped)
        %sigrid_hash_out : Tensor = fb::sigrid_hash_precompute(%output, %salt, %max_value, %mul_shift, %hash_into_int32)
        return (%sigrid_hash_out, %clipped))IR";
  std::string fused_pattern = R"IR(
    graph(%ranges, %values, %max_length, %salt, %max_value, %mul_shift, %hash_into_int32):
        %sigrid_hash_out : Tensor, %clipped : Tensor = fb::placeholder(%ranges, %values, %max_length, %salt, %max_value, %mul_shift, %hash_into_int32)
        return (%sigrid_hash_out, %clipped))IR";

  // the second gather_ranges can be eliminated because the `lengths` it
  // produces is identical to the lengths produced by
  // clip_ranges_gather_sigrid_hash_v3 (caveat: the fused op makes some
  // simplifying assumptions about the ranges input)
  std::string pattern2 = R"IR(
    graph(%gather2_values, %ranges, %values, %max_length, %salt, %max_value, %mul_shift, %hash_into_int32):
        %sigrid_hash_out : Tensor, %clipped : Tensor = fb::placeholder(%ranges, %values, %max_length, %salt, %max_value, %mul_shift, %hash_into_int32)
        %unused : Tensor, %lengths : Tensor = fb::gather_ranges(%gather2_values, %clipped)
        return (%lengths, %sigrid_hash_out))IR";

  std::string fused_pattern2 = R"IR(
    graph(%gather2_values, %ranges, %values, %max_length, %salt, %max_value, %mul_shift, %hash_into_int32):
        %lengths : Tensor, %sigrid_hash_out : Tensor = fb::clip_ranges_gather_sigrid_hash_precompute_v3(%ranges, %values, %max_length, %salt, %max_value, %mul_shift, %hash_into_int32)
        return (%lengths, %sigrid_hash_out))IR";

  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, fused_pattern);
  fuse.runOnGraph(graph);

  fuse.RegisterRewritePattern(pattern2, fused_pattern2);
  fuse.runOnGraph(graph);

  // reverse the ops that got fused in step 1 but not in step 2
  fuse.RegisterRewritePattern(fused_pattern, pattern);
  fuse.runOnGraph(graph);
}

C10_UNUSED void SplitOutPrecomputeOpsForSparseNN(
    std::shared_ptr<torch::jit::Graph>& graph) {
#ifdef FBCODE_CAFFE2
  PrecomputeMultiplierShiftForSigridHash(graph);
  ConstantPropagation(graph);
  ConstantPooling(graph);
#endif
}
} // namespace

void FuseInferenceOpsForSparseNN(std::shared_ptr<torch::jit::Graph>& graph) {
#ifdef FBCODE_CAFFE2
  SplitOutPrecomputeOpsForSparseNN(graph);

  ConcatAddMulReplaceNaNClip(graph);
  CastedBatchOneHotLengths(graph);
  ConcatBatchMatMulBatchGather(graph);

  if (FLAGS_enable_clip_ranges_gather_fusions) {
    ClipRangesGatherRangesLengthsToOffsets(graph);
  }
  ClipRangesGatherSigridHash(graph);
  ClipRangesGatherRangesSigridHash(graph);

  ClipRangesGatherRangesX2SigridHashPrecompute(graph);

  if (FLAGS_enable_clip_ranges_gather_fusions) {
    // prioritize clip_ranges+gather_ranges+sigrid_hash fusion over
    // clip_ranges+gather_ranges
    ClipRangesGather(graph);

    ClipRangesToGatherToOffsets(graph);
  }

  ToLengthsToOffsets(graph);
#endif
}

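// Register schemas for the static_runtime:: ops introduced by the passes in
// this file. The AliasAnalysisKind tells JIT alias analysis how to treat each
// op: PURE_FUNCTION ops are assumed to produce fresh, non-aliasing outputs,
// FROM_SCHEMA ops carry aliasing annotations in the schema itself (e.g.
// select_tensor), and CONSERVATIVE ops are treated as potentially aliasing
// and mutating anything.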
TORCH_LIBRARY_FRAGMENT(static_runtime, m) {
  m.def(torch::schema(
      "static_runtime::permute_copy(Tensor self, int[] dims) -> Tensor",
      c10::AliasAnalysisKind::PURE_FUNCTION));
  m.def(torch::schema(
      "static_runtime::reshape_copy(Tensor self, int[] shape) -> Tensor",
      c10::AliasAnalysisKind::PURE_FUNCTION));
  m.def(torch::schema(
      "static_runtime::flatten_copy.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> Tensor",
      c10::AliasAnalysisKind::PURE_FUNCTION));
  m.def(torch::schema(
      "static_runtime::expand_dims_copy(Tensor input, int[] dims) -> Tensor",
      c10::AliasAnalysisKind::PURE_FUNCTION));
  m.def(torch::schema(
      "static_runtime::to_maybe_copy_out.prim_dtype(Tensor self, int? dtype=None, bool non_blocking=False, bool copy=False) -> (Tensor, bool)",
      c10::AliasAnalysisKind::PURE_FUNCTION));
  m.def(torch::schema(
      "static_runtime::to_maybe_copy_out.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> (Tensor, bool)",
      c10::AliasAnalysisKind::PURE_FUNCTION));
  m.def(torch::schema(
      "static_runtime::to_maybe_copy_out.other(Tensor self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> (Tensor, bool)",
      c10::AliasAnalysisKind::PURE_FUNCTION));
  m.def(torch::schema(
      "static_runtime::to_copy.prim_dtype(Tensor self, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor",
      c10::AliasAnalysisKind::PURE_FUNCTION));
  m.def(torch::schema(
      "static_runtime::to_copy.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor",
      c10::AliasAnalysisKind::PURE_FUNCTION));
  m.def(torch::schema(
      "static_runtime::to_copy.other(Tensor self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor",
      c10::AliasAnalysisKind::PURE_FUNCTION));
  m.def(torch::schema(
      "static_runtime::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> (Tensor, Tensor, Tensor)",
      c10::AliasAnalysisKind::PURE_FUNCTION));
  m.def("static_runtime::signed_log1p(Tensor input) -> Tensor");
  m.def(torch::schema(
      "static_runtime::dict_unpack(...) -> ...",
      c10::AliasAnalysisKind::CONSERVATIVE));
  m.def(torch::schema(
      "static_runtime::VarTupleUnpack(...) -> ...",
      c10::AliasAnalysisKind::CONSERVATIVE));
  m.def(torch::schema(
      "static_runtime::fused_equally_split(Tensor input, int num_split, int dim) -> ...",
      c10::AliasAnalysisKind::PURE_FUNCTION));
  m.def(torch::schema(
      "static_runtime::dequantize_copy.self(Tensor self) -> Tensor",
      c10::AliasAnalysisKind::PURE_FUNCTION));
  m.def(torch::schema(
      "static_runtime::select_tensor(Tensor(a) a, Tensor(b) b, bool use_b) -> Tensor(a|b)",
      c10::AliasAnalysisKind::FROM_SCHEMA));
  m.def(torch::schema(
      "static_runtime::create_owned_ref(...) -> ...",
      c10::AliasAnalysisKind::CONSERVATIVE));
  m.def(torch::schema(
      "static_runtime::embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor)",
      c10::AliasAnalysisKind::PURE_FUNCTION));
  m.def(torch::schema(
      "static_runtime::embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor)",
      c10::AliasAnalysisKind::PURE_FUNCTION));
  m.def(torch::schema(
      "static_runtime::clamp_nan_to_num(Tensor input, Scalar? min, Scalar? max, float? nan, float? posinf, float? posinf) -> Tensor",
      c10::AliasAnalysisKind::PURE_FUNCTION));
}

void FuseSignLog1P(std::shared_ptr<torch::jit::Graph>& graph) {
  std::string pattern = R"IR(
    graph(%input):
        %0 : Tensor = aten::sign(%input)
        %1 : Tensor = aten::abs(%input)
        %2 : Tensor = aten::log1p(%1)
        %res : Tensor = aten::mul(%0, %2)
        return (%res)
  )IR";

  std::string fused_pattern = R"IR(
    graph(%input):
        %res : Tensor = static_runtime::signed_log1p(%input)
        return (%res)
    )IR";

  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, fused_pattern);
  fuse.runOnGraph(graph);
}

namespace {

using TupleUnpackBlock = std::vector<Node*>;

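// Collect maximal runs of consecutive prim::TupleUnpack nodes (of length > 1)
// so that each run can be fused into a single static_runtime::VarTupleUnpack
// by FuseTupleUnpackBlock below.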
std::vector<TupleUnpackBlock> CollectVariadicTupleUnpackFusionCandidates(
    const std::shared_ptr<Graph>& graph) {
  std::vector<TupleUnpackBlock> candidates;
  auto nodes = graph->nodes();
  std::vector<Node*> block;
  for (Node* cur_node : nodes) {
    if (cur_node->kind() == prim::TupleUnpack) {
      block.push_back(cur_node);
      continue;
    }
    if (block.size() > 1) {
      candidates.emplace_back(std::move(block));
    }
    block.clear();
  }
  TORCH_CHECK(block.empty());
  return candidates;
}

void FuseTupleUnpackBlock(const TupleUnpackBlock& nodes) {
  TORCH_CHECK(!nodes.empty());
  auto graph = nodes[0]->owningGraph();
  auto var_unpack = graph->create(
      fromQualString("static_runtime::VarTupleUnpack"),
      /* num_outputs */ 0);
  var_unpack->insertAfter(nodes[nodes.size() - 1]);
  for (Node* node : nodes) {
    TORCH_CHECK(
        node->kind() == prim::TupleUnpack && node->inputs().size() == 1);
    var_unpack->addInput(node->input());

    for (Value* output : node->outputs()) {
      auto new_output = var_unpack->addOutput();
      new_output->copyMetadata(output);
      output->replaceAllUsesWith(new_output);
    }
    node->destroy();
  }
}

} // namespace

void UseVariadicTupleUnpack(const std::shared_ptr<Graph>& graph) {
  for (auto& c : CollectVariadicTupleUnpackFusionCandidates(graph)) {
    FuseTupleUnpackBlock(c);
  }
}

// This macro makes maps from c10::Symbol -> c10::Symbol a lot easier to read.
#define OP_PAIR(first, second) \
  { fromQualString(first), fromQualString(second) }

// Out variants of ops cannot participate in memory planning if they
// have outputs that alias inputs. For ops that either return their
// input directly or copy it (most notably aten::to), we adopt the
// following strategy instead of directly making them out variants so
// that they can participate in memory planning anyway. Let `a` denote
// the input Tensor to the op.
//
// 1) Pass `a` (and the other operator inputs) to a special
// `static_runtime::$OP_maybe_copy_out` variant of the op. This op
// returns a normal output Tensor (call it `b_out`) as well as a
// `did_copy` flag indicating whether the output should be used. If
// `did_copy` is false, the value of `b_out` is unspecified. Note that
// this operator is an ordinary out variant that is perfectly amenable
// to memory planning.
//
// 2) Pass `a`, `b_out`, and `did_copy` to a special
// `static_runtime::select_tensor` op, which returns `b_out` if
// `did_copy` is true and `a` otherwise. Note that this operator does
// not need to participate in memory planning because its output
// always aliases one of its inputs.
//
// Here is an illustration:
//
//                        |
// |----------------------+ a
// |                      v
// |    +------------------------------------+
// |    |                                    |
// |    | static_runtime::$OP_maybe_copy_out |
// |    |                                    |
// |    +------------------+--------+--------+
// |                       |        |
// +--------------+        | b_out  | did_copy
//                | a      |        |
//                v        v        v
//      +------------------------------------+
//      |                                    |
//      |    static_runtime::select_tensor   |
//      |                                    |
//      +------------------+-----------------+
//                         |
//                         |
//                         | either a or b_out
//                         |
//                         v
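//
// Schematically (illustrative IR only), a node like
//
//   %b : Tensor = aten::to(%a, %dtype, %non_blocking, %copy)
//
// is rewritten to
//
//   %b_out : Tensor, %did_copy : bool =
//       static_runtime::to_maybe_copy_out(%a, %dtype, %non_blocking, %copy)
//   %b : Tensor = static_runtime::select_tensor(%a, %b_out, %did_copy)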

void ReplaceWithMaybeCopy(
    std::shared_ptr<Graph>& graph,
    bool outputs_are_immutable) {
  AliasDb db(graph);
  // for ops that have overloads, match the schema
  static const std::array<std::pair<c10::FunctionSchema, c10::Symbol>, 3> supported_schema =
      {{{torch::schema(
             "aten::to.prim_dtype(Tensor(a) self, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)"),
         fromQualString("static_runtime::to_maybe_copy_out")},
        {torch::schema(
             "aten::to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)"),
         fromQualString("static_runtime::to_maybe_copy_out")},
        {torch::schema(
             "aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)"),
         fromQualString("static_runtime::to_maybe_copy_out")}}};

  auto match_schema = [](const Node* node, c10::Symbol& out_matched_symbol) {
    for (auto& schema : supported_schema) {
      if (node->matches(schema.first)) {
        out_matched_symbol = schema.second;
        return true;
      }
    }
    return false;
  };

  // old node, new node, select_tensor node
  std::vector<std::tuple<Node*, Node*, Node*>> replacement;
  DepthFirstGraphNodeIterator graph_it(graph);
  for (auto n = graph_it.next(); n != nullptr; n = graph_it.next()) {
    c10::Symbol new_symbol;
    if (!match_schema(n, new_symbol)) {
      continue;
    }
    TORCH_CHECK(n->outputs().size() == 1);

    // Duplicate input writers guard from ReplaceWithCopy below.
    if (db.hasInputWriters(n)) {
      continue;
    }

    auto* out = n->output();
    if (!outputs_are_immutable && db.mayContainAlias(out, graph->outputs())) {
      continue;
    }

    // Add the did_copy flag to outputs.
    auto* new_node = graph->create(new_symbol, n->outputs().size() + 1);
    for (auto* input : n->inputs()) {
      new_node->addInput(input);
    }
    new_node->outputs().at(1)->setType(c10::BoolType::get());

    static const auto select_tensor_symbol =
        fromQualString("static_runtime::select_tensor");
    auto* select_tensor_node = graph->create(select_tensor_symbol, 1);
    TORCH_DCHECK_EQ(new_node->outputs().size(), 2);
    select_tensor_node->addInput(n->input(0));
    for (auto* output : new_node->outputs()) {
      select_tensor_node->addInput(output);
    }
    replacement.emplace_back(n, new_node, select_tensor_node);
  }

  for (const auto& tup : replacement) {
    auto* const old_node = std::get<0>(tup);
    auto* const new_node = std::get<1>(tup);
    auto* const select_tensor_node = std::get<2>(tup);

    new_node->insertBefore(old_node);
    select_tensor_node->insertBefore(old_node);
    new_node->outputs()[0]->copyMetadata(old_node->output());
    select_tensor_node->output()->copyMetadata(old_node->output());
    old_node->replaceAllUsesWith(select_tensor_node);
    old_node->destroy();
  }
#ifndef NDEBUG
  graph->lint();
  AliasDb db2(graph);
  torch::jit::Lint(&db2);
#endif
}

static void ReplaceWithCopyImpl(
    std::shared_ptr<Graph>& graph,
    const c10::FastMap<c10::Symbol, c10::Symbol>& supported,
    const std::vector<std::pair<c10::FunctionSchema, c10::Symbol>>&
        supported_schema,
    const std::function<bool(Node*)>& f_extra_checks,
    bool outputs_are_immutable) {
  AliasDb db(graph);

  auto match_schema = [&supported_schema](
                          const Node* node, c10::Symbol& out_matched_symbol) {
    for (auto& schema : supported_schema) {
      if (node->matches(schema.first)) {
        out_matched_symbol = schema.second;
        return true;
      }
    }
    return false;
  };

  std::vector<std::pair<Node*, Node*>> replacement;
  DepthFirstGraphNodeIterator graph_it(graph);
  for (auto n = graph_it.next(); n != nullptr; n = graph_it.next()) {
    c10::Symbol new_symbol;
    if (supported.count(n->kind()) && opIsRegistered(supported.at(n->kind()))) {
      new_symbol = supported.at(n->kind());
    } else if (!match_schema(n, new_symbol)) {
      continue;
    }
    TORCH_CHECK(n->outputs().size() == 1);

    // We do not want to replace operators with their copy variant when the
    // inputs to the operators have writers (can be updated). With an output
    // that aliases to the input, updates to the input will be visible to the
    // operator's output as well. For example:
    //
    // def forward(self, inp: Tensor, shape: List[int]):
    //   a = inp + inp
    //   b = a.reshape(shape)
    //   c = b.sigmoid_()
    //   d = c + c
    //   e = a + a
    //   f = b + b
    //   return (d, e, f)
    //
    // b and c are aliases of a, so sigmoid_ changes b, c, as well as a, and e
    // should be equal to d in this case. If we replace reshape with the copy
    // version, b and c are no longer aliases of a, and the value of e would
    // change as a result. To keep static runtime consistent with the jit
    // interpreter, here we choose not to replace reshape with the copy version.
    if (db.hasInputWriters(n)) {
      continue;
    }

    auto* out = n->output();
    if (!outputs_are_immutable && db.mayContainAlias(out, graph->outputs())) {
      continue;
    }
    if (!f_extra_checks(n)) {
      continue;
    }
    auto* new_node = graph->create(new_symbol, n->outputs().size());
    for (auto* input : n->inputs()) {
      new_node->addInput(input);
    }
    replacement.emplace_back(n, new_node);
  }

  for (const auto& p : replacement) {
    auto* old_node = p.first;
    auto* new_node = p.second;
    new_node->insertBefore(old_node);
    new_node->output()->copyMetadata(old_node->output());
    old_node->replaceAllUsesWith(new_node);
    old_node->destroy();
  }
#ifndef NDEBUG
  graph->lint();
  AliasDb db2(graph);
  torch::jit::Lint(&db2);
#endif
}

// replace aten::permute with copy version only when it's followed by
// reshape/flatten. It's only enabled when ReplaceWithCopy is off.
void ReplacePermuteWithCopy(
    std::shared_ptr<Graph>& graph,
    bool outputs_are_immutable) {
  AliasDb db(graph);
  const c10::FastMap<c10::Symbol, c10::Symbol> supported = {
#ifdef FBCODE_CAFFE2
      OP_PAIR("aten::permute", "static_runtime::permute_copy"),
#endif
  };
  auto f_extra_checks = [](Node* n) {
    Value* out = n->output();
    Node* next_node = out->uses()[0].user;
    if (next_node->kind() != aten::reshape ||
        next_node->kind() != aten::flatten) {
      return true;
    }
    return false;
  };
  ReplaceWithCopyImpl(
      graph, supported, {}, f_extra_checks, outputs_are_immutable);
}

void ReplaceWithCopy(
    std::shared_ptr<Graph>& graph,
    bool outputs_are_immutable) {
  AliasDb db(graph);
  const c10::FastMap<c10::Symbol, c10::Symbol> supported = {
#ifdef FBCODE_CAFFE2
      OP_PAIR("aten::permute", "static_runtime::permute_copy"),
      OP_PAIR("fb::expand_dims", "static_runtime::expand_dims_copy"),
#endif
      OP_PAIR("aten::narrow", "aten::narrow_copy"),
      OP_PAIR("aten::reshape", "static_runtime::reshape_copy"),
      OP_PAIR("aten::flatten", "static_runtime::flatten_copy")};

  static const std::vector<std::pair<c10::FunctionSchema, c10::Symbol>>
      supported_schema = {
          {{torch::schema("aten::dequantize.self(Tensor self) -> Tensor"),
            fromQualString("static_runtime::dequantize_copy")}}};

  ReplaceWithCopyImpl(
      graph,
      supported,
      supported_schema,
      [](Node* n) { return true; },
      outputs_are_immutable);
}

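// An fb::equally_split node whose list output is consumed by a single
// prim::ListUnpack with exactly one output is effectively a no-op (the split
// produces a single chunk equal to the input), so both nodes can be removed
// and the original input used directly.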
void EliminateTrivialEquallySplit(std::shared_ptr<torch::jit::Graph>& graph) {
  const auto equally_split = fromQualString("fb::equally_split");
  std::vector<Node*> to_remove;
  DepthFirstGraphNodeIterator graph_it(graph);
  for (auto node = graph_it.next(); node != nullptr; node = graph_it.next()) {
    if (node->kind() != equally_split) {
      continue;
    }

    const Value* value_out = node->outputs()[0];
    if (value_out->uses().size() != 1) {
      continue;
    }

    Node* list_unpack_node = value_out->uses()[0].user;
    if (list_unpack_node->kind() != prim::ListUnpack) {
      continue;
    }

    auto list_unpack_outputs = list_unpack_node->outputs();
    if (list_unpack_outputs.size() != 1) {
      continue;
    }

    list_unpack_node->output()->replaceAllUsesWith(node->input(0));
    to_remove.push_back(list_unpack_node);
    to_remove.push_back(node);
  }

  for (Node* node : to_remove) {
    node->destroy();
  }
}

namespace {

bool shouldNotFuseListUnpackSpecialCase(const Node* node) {
  const static std::array<c10::Symbol, 3> sigrid_transforms_symbols{
      c10::Symbol::fromQualString("fb::variadic_sigrid_transforms_torch_bind"),
      c10::Symbol::fromQualString("fb::sigrid_transforms_torch_bind"),
      c10::Symbol::fromQualString("fb::sigrid_transforms")};

  if (std::find(
          sigrid_transforms_symbols.begin(),
          sigrid_transforms_symbols.end(),
          node->kind()) == sigrid_transforms_symbols.end()) {
    return false;
  }

  // To fuse with sigrid transforms, we must be able to statically determine
  // `instance` and `use_offsets` - these two together let us statically
  // determine the types of the outputs. Rationale: it is a huge pain to write
  // fused sigrid transforms without static type information, and these two
  // arguments are indeed statically known in every model we've seen.
  // The reason why trying to fuse the outputs is annoying without static type
  // information is that, if one of the outputs is not managed, you need to
  // reset to an empty tensor of the correct type each iteration. So, if we
  // can't collect types ahead of time, we would have to do it lazily on the
  // first iteration, which could be wasteful in terms of time/memory
  // - either each thread would have its own set of output types, or we would
  // need a lock to prevent data races.
  const auto num_inputs = node->inputs().size();
  return !toIValue(node->input(0)).has_value() ||
      !toIValue(node->input(num_inputs - 1)).has_value();
}

} // namespace

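// Fuse ops whose list output is immediately consumed by a single
// prim::ListUnpack into a "fused" variant that produces the unpacked outputs
// directly. Schematically (illustrative IR only):
//
//   %list : Tensor[] = fb::equally_split(%input, %num_split, %dim)
//   %a : Tensor, %b : Tensor = prim::ListUnpack(%list)
//
// becomes
//
//   %a : Tensor, %b : Tensor =
//       static_runtime::fused_equally_split(%input, %num_split, %dim)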
void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph) {
  const c10::FastMap<c10::Symbol, c10::Symbol> unfused_to_fused = {
      OP_PAIR(
          "torcharrow::inference_wrapper_run_flat",
          "static_runtime::fused_inference_wrapper_run_flat"),
      OP_PAIR(
          "torcharrow::variadic_inference_wrapper_run_flat",
          "static_runtime::fused_variadic_inference_wrapper_run_flat"),
      OP_PAIR("fb::equally_split", "static_runtime::fused_equally_split"),
      OP_PAIR(
          "fb::sigrid_transforms", "static_runtime::fused_sigrid_transforms"),
      OP_PAIR(
          "static_runtime::variadic_grouped_accessor_op_v2",
          "static_runtime::fused_variadic_grouped_accessor_op_v2"),
      OP_PAIR(
          "fb::sigrid_transforms_torch_bind",
          "static_runtime::fused_sigrid_transforms_torch_bind"),
      OP_PAIR(
          "fb::variadic_sigrid_transforms_torch_bind",
          "static_runtime::fused_variadic_sigrid_transforms_torch_bind"),
      OP_PAIR(
          "fb::gather_ranges_to_dense",
          "static_runtime::fused_gather_ranges_to_dense"),
      OP_PAIR(
          "fb::gather_ranges_to_dense_v2",
          "static_runtime::fused_gather_ranges_to_dense_v2"),
      OP_PAIR(
          "fb::split_and_squeeze",
          "static_runtime::fused_split_and_squeeze_copy")};

  // replacement contains (old_node, new_node, list_unpack_node)
  std::vector<std::tuple<Node*, Node*, Node*>> replacement;
  DepthFirstGraphNodeIterator graph_it(graph);
  for (auto node = graph_it.next(); node != nullptr; node = graph_it.next()) {
    auto unfused_to_fused_it = unfused_to_fused.find(node->kind());
    if (unfused_to_fused_it == unfused_to_fused.end()) {
      continue;
    }

    const Value* value_out = node->outputs()[0];
    if (value_out->uses().size() != 1) {
      continue;
    }

    Node* list_unpack_node = value_out->uses()[0].user;
    if (list_unpack_node->kind() != prim::ListUnpack) {
      continue;
    }

    auto list_unpack_outputs = list_unpack_node->outputs();
    if (list_unpack_outputs.empty()) {
      continue;
    }

    if (shouldNotFuseListUnpackSpecialCase(node)) {
      continue;
    }

    const auto& new_sym = unfused_to_fused_it->second;
    auto* new_node = graph->create(new_sym, 0);

    for (Value* in : node->inputs()) {
      new_node->addInput(in);
    }

    for (Value* out : list_unpack_outputs) {
      Value* new_out = new_node->addOutput();
      new_out->copyMetadata(out);
      out->replaceAllUsesWith(new_out);
    }
    replacement.emplace_back(node, new_node, list_unpack_node);
  }

  for (const auto& nodes : replacement) {
    auto* old_node = std::get<0>(nodes);
    auto* new_node = std::get<1>(nodes);
    auto* list_unpack_node = std::get<2>(nodes);

    new_node->insertAfter(old_node);
    list_unpack_node->destroy();
    old_node->destroy();
  }
}

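// Replace aten::__getitem__ calls with constant keys on an immutable input
// dict by a single static_runtime::dict_unpack per dict. Schematically
// (illustrative IR only):
//
//   %v1 = aten::__getitem__(%input_dict, %key1)
//   %v2 = aten::__getitem__(%input_dict, %key2)
//
// becomes
//
//   %v1, %v2 = static_runtime::dict_unpack(%input_dict, %key1, %key2)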
void RemoveImmutableInputDictLookups(
    std::shared_ptr<torch::jit::Graph>& graph) {
  auto nodes = graph->nodes();
  AliasDb db(graph);
  // Gather all dict -> getitems where dict is immutable and getitems use
  // constant keys.
  std::unordered_map<Value*, std::vector<Node*>> dict_to_getitems;
  std::unordered_set<Node*> keys;
  for (Node* node : nodes) {
    // Find aten::__getitem__(%dict, %constant_key).
    if (node->kind() != aten::__getitem__) {
      continue;
    }
    Node* getitem_node = node;
    Value* dict = getitem_node->input(0);
    if (db.hasWriters(dict)) {
      // Mutable dict. Skip this optimization.
      continue;
    }
    if (dict->type()->kind() != TypeKind::DictType ||
        dict->node() != graph->param_node()) {
      continue;
    }
    DCHECK(getitem_node->inputs().size() == 2);
    Node* key = getitem_node->input(1)->node();
    if (key->kind() != prim::Constant) {
      continue;
    }
    keys.insert(key);
    auto iter = dict_to_getitems.find(dict);
    if (iter == dict_to_getitems.end()) {
      dict_to_getitems.emplace(dict, std::vector<Node*>{getitem_node});
      continue;
    }
    iter->second.push_back(getitem_node);
  }
  if (keys.empty()) {
    return;
  }
  // Move all keys to the beginning of the graph and insert new dict_unpack
  // nodes after that.
  auto* marker = graph->create(prim::Constant);
  graph->prependNode(marker);
  graph->setInsertPoint(marker);
  for (Node* key : keys) {
    DCHECK(key->inputs().empty());
    key->moveBefore(marker);
  }
  const c10::Symbol static_runtime_dict_unpack_symbol =
      fromQualString("static_runtime::dict_unpack");
  for (auto& it : dict_to_getitems) {
    Value* dict = it.first;
    std::vector<Node*>& getitems = it.second;
    DCHECK(!getitems.empty());
    auto* dict_unpack =
        graph->create(static_runtime_dict_unpack_symbol, getitems.size());
    graph->insertNode(dict_unpack);
    dict_unpack->addInput(getitems[0]->input(0));
    for (size_t i = 0; i < getitems.size(); ++i) {
      Node* getitem_node = getitems[i];
      DCHECK(getitem_node->input(0) == dict);
      dict_unpack->addInput(getitem_node->input(1));
      dict_unpack->output(i)->copyMetadata(getitem_node->output());
      getitem_node->output(0)->replaceAllUsesWith(dict_unpack->output(i));
      getitem_node->destroy();
    }
  }
  graph->setInsertPoint(graph->block());
  marker->destroy();
}

void UseVariadicGroupedAccessor(const std::shared_ptr<Graph>& graph) {
  UseVariadicOp(
      graph,
      fromQualString("grouped_accessor::grouped_accessor_op_v2"),
      fromQualString("static_runtime::variadic_grouped_accessor_op_v2"));
  UseVariadicOp(
      graph,
      fromQualString("fb::grouped_accessor_op_async"),
      fromQualString("static_runtime::variadic_grouped_accessor_op_async"));
}

namespace {

void CreateOwnedRefsForSpecialValuesHelper(Graph& graph, Block* block) {
  for (auto* node : block->nodes()) {
    for (auto* sub_block : node->blocks()) {
      CreateOwnedRefsForSpecialValuesHelper(graph, sub_block);
    }
  }

  auto outputs = block->outputs();
  // Create owned refs for inputs. Otherwise, the input cleanup process
  // will destroy our outputs before we return.
  c10::FastSet<Value*> inputs = {
      block->inputs().begin(), block->inputs().end()};

  for (const auto i : c10::irange(outputs.size())) {
    auto* output = outputs[i];

    if (output->type()->kind() == c10::TypeKind::NoneType) {
      // No need to create owned refs of NoneType since moving
      // from None will have no effect
      continue;
    }

    if ((inputs.find(output) != inputs.end()) || toIValue(output).has_value() ||
        // If the output's owning block is not this one, it's from an outer
        // scope
        output->node()->owningBlock() != block) {
      auto* create_owned_ref_node =
          graph.create(fromQualString("static_runtime::create_owned_ref"));
      create_owned_ref_node->addInput(output);
      create_owned_ref_node->output()->copyMetadata(output);

      block->appendNode(create_owned_ref_node);
      block->replaceOutput(i, create_owned_ref_node->output());
    }
  }
}

void ForceNonEmptyOutputsHelper(Value* none_value, Block* block) {
  for (auto* node : block->nodes()) {
    bool needs_output = false;
    for (auto* sub_block : node->blocks()) {
      if (sub_block->outputs().empty()) {
        sub_block->registerOutput(none_value);
        needs_output = true;
      }

      ForceNonEmptyOutputsHelper(none_value, sub_block);
    }

    if (needs_output) {
      // Loop sub-blocks should always return at least one output (the new loop
      // condition)
      DCHECK(node->kind() == prim::If);
      auto* output = node->addOutput();
      output->setType(c10::NoneType::get());
    }
  }
}

Node* findOrCreateNoneConstant(Graph& graph) {
  // Only search the top-level block
  for (auto* node : graph.nodes()) {
    if (node->kind() != prim::Constant) {
      continue;
    }
    const auto ival_opt = toIValue(node->output());
    DCHECK(ival_opt.has_value());
    if (ival_opt->isNone()) {
      return node;
    }
  }

  auto* none_node = graph.create(prim::Constant);
  none_node->output()->setType(c10::NoneType::get());
  graph.prependNode(none_node);
  return none_node;
}

} // namespace

void CreateOwnedRefsForSpecialValues(Graph& graph) {
  CreateOwnedRefsForSpecialValuesHelper(graph, graph.block());
}

void ForceNonEmptyOutputs(Graph& graph) {
  auto* none_node = findOrCreateNoneConstant(graph);
  ForceNonEmptyOutputsHelper(none_node->output(), graph.block());
  if (!none_node->hasUses()) {
    none_node->destroy();
  }
}

namespace {

bool inputIsConstantList(
    Node* node,
    size_t input_idx,
    const c10::List<int64_t>& expected) {
  auto input_opt = toIValue(node->input(input_idx));
  if (!input_opt.has_value() || !input_opt->isIntList()) {
    return false;
  }
  return input_opt->toIntList() == expected;
}

bool inputIsConstantInt(Node* node, size_t input_idx, int64_t expected) {
  auto input_opt = toIValue(node->input(input_idx));
  if (!input_opt.has_value() || !input_opt->isInt()) {
    return false;
  }
  return input_opt->toInt() == expected;
}

void eliminatePermuteOpsSumPattern(std::shared_ptr<Graph>& graph) {
  // SubgraphRewriter can't pattern-match on constants, so we use this
  // extra filter to make sure the values of the `dim` arguments are
  // correct.
  auto dims_are_valid_constants =
      [](const Match& match,
         const std::unordered_map<std::string, Value*>& vmap) {
        // Get the nodes in the real graph from the nodes in the template
        // pattern graph
        const auto& node_map = match.nodes_map;
        auto* sum_node = node_map.at(vmap.at("c")->node());
        auto* permute_node = node_map.at(vmap.at("b")->node());
        return inputIsConstantList(sum_node, 1, c10::List<int64_t>{-1}) &&
            inputIsConstantList(permute_node, 1, c10::List<int64_t>{0, 2, 1});
      };

  const auto pattern = R"IR(
    graph(%a, %sum_dim, %permute_dim, %keepdim, %dtype):
        %b = aten::permute(%a, %permute_dim)
        %c = aten::sum(%b, %sum_dim, %keepdim, %dtype)
        return (%c))IR";

  const auto fused_pattern = R"IR(
    graph(%a, %sum_dim, %permute_dim, %keepdim, %dtype):
        %new_sum_dim: int[] = prim::Constant[value=[1]]()
        %d = aten::sum(%a, %new_sum_dim, %keepdim, %dtype)
        return (%d))IR";

  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, fused_pattern);
  fuse.runOnGraph(graph, dims_are_valid_constants);
}

void eliminatePermuteOpsSoftmaxPattern(std::shared_ptr<Graph>& graph) {
  const auto pattern = R"IR(
    graph(%a, %permute_dim_1, %permute_dim_2, %softmax_dim, %softmax_dtype):
        %b = aten::permute(%a, %permute_dim_1)
        %c = aten::softmax(%b, %softmax_dim, %softmax_dtype)
        %d = aten::permute(%c, %permute_dim_2)
        return (%d)
  )IR";

  const auto fused_pattern = R"IR(
    graph(%a, %permute_dim_1, %permute_dim_2, %softmax_dim, %softmax_dtype):
        %new_softmax_dim: int = prim::Constant[value=1]()
        %e = aten::softmax(%a, %new_softmax_dim, %softmax_dtype)
        return (%e)
  )IR";

  // Check that permute_dim is (0, 2, 1) and softmax_dim is 2
  auto dims_are_valid_constants =
      [](const Match& match,
         const std::unordered_map<std::string, Value*>& vmap) {
        const auto& node_map = match.nodes_map;
        auto* permute_node_1 = node_map.at(vmap.at("b")->node());
        auto* permute_node_2 = node_map.at(vmap.at("d")->node());
        auto* softmax_node = node_map.at(vmap.at("c")->node());
        return inputIsConstantInt(softmax_node, 1, 2) &&
            inputIsConstantList(
                   permute_node_1, 1, c10::List<int64_t>{0, 2, 1}) &&
            inputIsConstantList(permute_node_2, 1, c10::List<int64_t>{0, 2, 1});
      };

  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, fused_pattern);
  fuse.runOnGraph(graph, dims_are_valid_constants);
}

} // namespace

void EliminateExtraPermuteOps(std::shared_ptr<Graph>& graph) {
  eliminatePermuteOpsSumPattern(graph);
  eliminatePermuteOpsSoftmaxPattern(graph);
}

namespace {

Node* maybeUserWithKind(Value* value, c10::Symbol kind) {
  auto& uses = value->uses();
  if (uses.size() != 1) {
    return nullptr;
  }
  auto* user = uses[0].user;
  if (user->kind() != kind) {
    return nullptr;
  }
  return user;
}

} // namespace

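// Fuse aten::split + prim::ListUnpack + per-element aten::squeeze into
// static_runtime::fused_split_and_squeeze_copy. The fusion only fires when
// every ListUnpack output is squeezed along the same axis that the split was
// performed on.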
void UseSplitAndSqueeze(std::shared_ptr<Graph>& graph) {
  std::vector<Node*> to_erase;
  for (auto* node : graph->nodes()) {
    if (node->kind() != aten::split) {
      continue;
    }
    auto axis_opt = toIValue(node->input(2));
    if (!axis_opt) {
      continue;
    }
    auto axis = *axis_opt;
    auto* split_node_output = node->output();
    auto* list_unpack_node =
        maybeUserWithKind(split_node_output, prim::ListUnpack);
    if (list_unpack_node == nullptr) {
      continue;
    }
    std::vector<Node*> squeeze_nodes;
    squeeze_nodes.reserve(list_unpack_node->outputs().size());
    for (auto* output : list_unpack_node->outputs()) {
      auto* squeeze_node = maybeUserWithKind(output, aten::squeeze);
      if (squeeze_node == nullptr) {
        break;
      }
      auto dim_opt = toIValue(squeeze_node->input(1));
      if (!dim_opt || *dim_opt != axis) {
        break;
      }
      squeeze_nodes.push_back(squeeze_node);
    }
    auto num_outputs = list_unpack_node->outputs().size();
    if (squeeze_nodes.size() != num_outputs) {
      continue;
    }
    auto* split_and_squeeze_node = graph->create(
        c10::Symbol::fromQualString(
            "static_runtime::fused_split_and_squeeze_copy"),
        num_outputs);
    split_and_squeeze_node->addInput(node->input(0));
    split_and_squeeze_node->addInput(node->input(1));
    split_and_squeeze_node->addInput(node->input(2));
    split_and_squeeze_node->insertBefore(node);
    for (const auto i : c10::irange(num_outputs)) {
      auto* squeeze_node = squeeze_nodes[i];
      split_and_squeeze_node->output(i)->copyMetadata(squeeze_node->output());
      squeeze_node->output()->replaceAllUsesWith(
          split_and_squeeze_node->output(i));
    }
    to_erase.insert(to_erase.end(), squeeze_nodes.begin(), squeeze_nodes.end());
    to_erase.push_back(list_unpack_node);
    to_erase.push_back(node);
  }
  for (auto* node : to_erase) {
    node->destroy();
  }
}

RemoveUnnecessaryOutputs(std::shared_ptr<torch::jit::Graph> & graph)1298 C10_UNUSED void RemoveUnnecessaryOutputs(
1299     std::shared_ptr<torch::jit::Graph>& graph) {
1300   RemoveUnnecessaryEmbeddingBagOutputs(graph);
1301 }
1302 
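// Rewrites aten::embedding_bag (both the 8- and 9-argument overloads) to
// static_runtime::embedding_bag when the fourth output (max_indices) is
// unused, so only the three needed outputs are produced.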
C10_UNUSED void RemoveUnnecessaryEmbeddingBagOutputs(
    std::shared_ptr<torch::jit::Graph>& graph) {
  std::string pattern = R"IR(
    graph(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset):
        %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
        return (%y2, %y1, %y0))IR";
  std::string transformed_pattern = R"IR(
    graph(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset):
        %y0 : Tensor, %y1 : Tensor, %y2 : Tensor = static_runtime::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
        return (%y2, %y1, %y0))IR";
  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, transformed_pattern);
  fuse.runOnGraph(graph);

  std::string pattern2 = R"IR(
    graph(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset, %padding_idx):
        %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset, %padding_idx)
        return (%y2, %y1, %y0))IR";
  std::string transformed_pattern2 = R"IR(
    graph(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset, %padding_idx):
        %y0 : Tensor, %y1 : Tensor, %y2 : Tensor = static_runtime::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset, %padding_idx)
        return (%y2, %y1, %y0))IR";
  fuse.RegisterRewritePattern(pattern2, transformed_pattern2);
  fuse.runOnGraph(graph);
}

namespace {
// True if an aten::slice over a list is a no-op: start is a constant 0 or
// None, end is None, and step is 1.
bool isNoOpSlice(Node* node) {
  DCHECK(node->kind() == aten::slice);
  auto step = toIValue(node->input(3));
  if (!step.has_value() || step->toInt() != 1) {
    return false;
  }
  auto start = toIValue(node->input(1));
  if (!start.has_value() || (start->isInt() && start->toInt() != 0)) {
    return false;
  }
  auto end = toIValue(node->input(2));
  // Could also look at list length, but most models that have this pattern are
  // just doing list[0:], so it's not needed for now.
  return end.has_value() && end->isNone();
}
} // namespace

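// Removes list slices that copy the whole list (e.g. `lst[0:]`): all uses of
// the slice output are redirected to the input list and the slice node is
// destroyed.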
void EliminateNoOpSlice(std::shared_ptr<Graph>& graph) {
  DepthFirstGraphNodeIterator it(graph);
  auto schema = torch::schema(
      "aten::slice.t(t[] l, int? start=None, int? end=None, int step=1) -> t[]",
      /*allow_typevars*/ true);
  Node* node = nullptr;
  std::vector<Node*> to_delete;
  while ((node = it.next()) != nullptr) {
    if (!node->matches(schema) || !isNoOpSlice(node)) {
      continue;
    }

    node->output()->replaceAllUsesWith(node->input(0));
    to_delete.push_back(node);
  }
  for (auto* node : to_delete) {
    node->destroy();
  }
}

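// fbcode-only: swaps remote_collection::get_real_inputs_from_optional_inputs_v2
// for an in-place static_runtime variant when its list input has exactly one
// use, so mutating that input is safe.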
void UseInPlaceGetRealInputsFromOptionalInputsV2(
    std::shared_ptr<Graph>& graph) {
#ifdef FBCODE_CAFFE2
  const std::string original_pattern = R"IR(
    graph(%optional_input: (Tensor, Tensor?, Tensor?)?[], %include_last_offsets: bool[]):
        %x : (Tensor, Tensor?, Tensor?)[] = remote_collection::get_real_inputs_from_optional_inputs_v2(%optional_input, %include_last_offsets)
        return (%x))IR";

  const std::string new_pattern = R"IR(
    graph(%optional_input: (Tensor, Tensor?, Tensor?)?[], %include_last_offsets: bool[]):
        %x : (Tensor, Tensor?, Tensor?)[] = static_runtime::get_real_inputs_from_optional_inputs_v2_inplace(%optional_input, %include_last_offsets)
        return (%x))IR";

  auto isSingleUse = [](Value* value) { return value->uses().size() == 1; };

  auto filter = [&isSingleUse](
                    const Match& match,
                    const std::unordered_map<std::string, Value*>& vmap) {
    auto* real_node = match.nodes_map.at(vmap.at("x")->node());
    return isSingleUse(real_node->input(0));
  };

  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(original_pattern, new_pattern);
  fuse.runOnGraph(graph, filter);
#endif
}

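// fbcode-only: fuses aten::clamp followed by aten::nan_to_num into
// static_runtime::clamp_nan_to_num, provided the clamp min/max arguments are
// constant, non-None scalars.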
void FuseClampNaNToNum(std::shared_ptr<Graph>& graph) {
#ifdef FBCODE_CAFFE2
  std::string pattern = R"IR(
    graph(%input, %clamp_min: Scalar?, %clamp_max: Scalar?, %nan, %posinf, %neginf):
        %x : Tensor = aten::clamp(%input, %clamp_min, %clamp_max)
        %y : Tensor = aten::nan_to_num(%x, %nan, %posinf, %neginf)
        return (%y))IR";

  std::string fused_pattern = R"IR(
    graph(%input, %clamp_min: Scalar?, %clamp_max: Scalar?, %nan, %posinf, %neginf):
        %x : Tensor = static_runtime::clamp_nan_to_num(%input, %clamp_min, %clamp_max, %nan, %posinf, %neginf)
        return (%x))IR";

  auto isConstantAndNotNone = [](Value* value) {
    auto ival_opt = toIValue(value);
    if (!ival_opt.has_value()) {
      return false;
    }
    auto scalar_opt = ival_opt->toOptional<at::Scalar>();
    return scalar_opt.has_value();
  };

  auto clampValuesAreConstant =
      [&isConstantAndNotNone](
          const Match& match,
          const std::unordered_map<std::string, Value*>& vmap) {
        // Get the nodes in the real graph from the nodes in the template
        // pattern graph
        const auto& node_map = match.nodes_map;
        auto* clamp_node = node_map.at(vmap.at("x")->node());
        return isConstantAndNotNone(clamp_node->input(1)) &&
            isConstantAndNotNone(clamp_node->input(2));
      };

  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, fused_pattern);
  fuse.runOnGraph(graph, clampValuesAreConstant);
#endif
}

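// Splits fb::quantized_linear_unpacked_weight_v2 into quantized::linear_prepack
// plus quantized::linear so that a later constant-propagation pass can fold the
// weight packing away.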
void PrepackWeights(std::shared_ptr<Graph>& graph) {
  const auto pattern = R"IR(
    graph(%input: Tensor, %weight: Tensor, %bias: Tensor?, %scale: Tensor, %zero_point: Tensor):
        %result: Tensor = fb::quantized_linear_unpacked_weight_v2(%input, %weight, %bias, %scale, %zero_point)
        return (%result)
  )IR";

  const auto split_pattern = R"IR(
    graph(%input: Tensor, %weight: Tensor, %bias: Tensor?, %scale: Tensor, %zero_point: Tensor):
        %packed_params = quantized::linear_prepack(%weight, %bias)
        %scale_float: float = aten::item(%scale)
        %zero_point_int: int = aten::item(%zero_point)
        %result: Tensor = quantized::linear(%input, %packed_params, %scale_float, %zero_point_int)
        return (%result)
  )IR";

  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, split_pattern);
  fuse.runOnGraph(graph);
  // Constant propagation should be called after this pass + others.
}

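// Illustrative only (a hypothetical driver, not part of this file's API):
// these passes are typically applied to a TorchScript forward graph before
// the static runtime executes it, e.g.
//
//   auto graph = module.get_method("forward").graph();
//   EliminateExtraPermuteOps(graph);
//   UseSplitAndSqueeze(graph);
//   EliminateNoOpSlice(graph);
//   PrepackWeights(graph);
//   ConstantPropagation(graph); // after PrepackWeights, per the note above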
} // namespace torch::jit