xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/static/impl.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/runtime/static/impl.h>
2 
3 #include <ATen/MemoryOverlap.h>
4 #include <ATen/core/symbol.h>
5 #include <ATen/record_function.h>
6 #include <c10/core/CPUAllocator.h>
7 #include <c10/core/InferenceMode.h>
8 #include <c10/macros/Macros.h>
9 #include <c10/util/MaybeOwned.h>
10 #include <c10/util/irange.h>
11 #include <caffe2/core/timer.h>
12 #include <torch/csrc/jit/ir/alias_analysis.h>
13 #include <torch/csrc/jit/jit_log.h>
14 #include <torch/csrc/jit/passes/add_if_then_else.h>
15 #include <torch/csrc/jit/passes/canonicalize.h>
16 #include <torch/csrc/jit/passes/dead_code_elimination.h>
17 #include <torch/csrc/jit/passes/eliminate_no_ops.h>
18 #include <torch/csrc/jit/passes/freeze_module.h>
19 #include <torch/csrc/jit/passes/remove_mutation.h>
20 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
21 #include <torch/csrc/jit/passes/variadic_ops.h>
22 #include <torch/csrc/jit/runtime/graph_iterator.h>
23 #include <torch/csrc/jit/runtime/static/fusion.h>
24 #include <torch/csrc/jit/runtime/static/memory_planner.h>
25 #include <torch/csrc/jit/runtime/static/ops.h>
26 #include <torch/csrc/jit/runtime/static/passes.h>
27 #include <torch/csrc/jit/runtime/vararg_functions.h>
28 #include <algorithm>
29 #include <cstdint>
30 #include <iostream>
31 
32 #ifndef AT_PER_OPERATOR_HEADERS
33 #include <ATen/NativeFunctions.h>
34 #else
35 #include <ATen/ops/clone_native.h>
36 #endif
37 
38 #include <iterator>
39 #include <limits>
40 #include <sstream>
41 #include <stdexcept>
42 
43 #ifdef FBCODE_CAFFE2
44 #include <common/logging/logging.h>
45 #include <folly/dynamic.h>
46 #include <folly/json.h>
47 #endif
48 
49 // used in test only
50 C10_DEFINE_bool(
51     static_runtime_disable_debug_memory_overlap_check,
52     false,
53     "If true, disable the memory overlap check in debug mode in ProcessedNode::run()");
54 
55 namespace torch::jit {
56 
57 namespace {
58 #ifndef STRIP_ERROR_MESSAGES
59 std::string iValueToString(const c10::IValue& val) {
60   std::ostringstream oss;
61   oss << val;
62   return oss.str();
63 }
64 #endif
65 
66 bool allArgsAreTensors(const Node* node) {
67   const auto& inputs = node->inputs();
68   return std::all_of(inputs.begin(), inputs.end(), [](const Value* value) {
69     return value->type()->kind() == TypeKind::TensorType;
70   });
71 }
72 
73 } // namespace
74 
75 // A manually curated set of ops that are disallowed in static runtime.
76 // These are rarely-used ops. Disallowing them typically eliminates
77 // corner cases in graph optimizations, allowing for more aggressive
78 // optimizations and better performance.
79 static bool isUnsupportedOp(const Node* node) {
80   auto kind = node->kind();
81   if (kind != aten::__is__ && kind != aten::__isnot__) {
82     return false;
83   }
84 
85   // We can't support aten::__is__ (and __isnot__) with tensor arguments.
86   // Consider the following graph:
87   // def forward(x):
88   //     y = x.detach()
89   //     return x is y
90   // We have a graph optimization that removes the `detach` node since it is
91   // a no-op during inference. But this affects the result - we get true
92   // instead of false! There are many other graph passes affected by this
93   // issue.
94   return allArgsAreTensors(node);
95 }
96 
97 namespace {
98 
99 bool canEnableStaticRuntimeImpl(const Block* block) {
100   if (block == nullptr) {
101     return false;
102   }
103 
104   bool can_support = true;
105   for (auto* node : block->nodes()) {
106     for (auto* subblock : node->blocks()) {
107       // The ordering prevents && from short-circuiting; we want to visit
108       // every sub-block so that *all* the unsupported ops get logged.
109       can_support = canEnableStaticRuntimeImpl(subblock) && can_support;
110     }
111 
112     const auto kind = node->kind();
113     if (kind == prim::Constant) {
114       continue;
115     }
116     // Check whether an Operator can be obtained from this Node.
117     const Operator* op = node->maybeOperator();
118     if (isUnsupportedOp(node) || (!op && !nativeOpIsRegistered(kind))) {
119       can_support = false;
120       LOG(WARNING) << "Found unsupported op: " << kind.toQualString();
121     }
122   }
123   return can_support;
124 }
125 
126 } // namespace
127 
128 // Graph must be frozen. canEnableStaticRuntime will return false
129 // if there's any prim::CallMethod ops left in the graph.
130 bool canEnableStaticRuntime(const std::shared_ptr<torch::jit::Graph>& graph) {
131   return canEnableStaticRuntimeImpl(graph->block());
132 }
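
// A minimal usage sketch for canEnableStaticRuntime (illustrative only; the
// model path and variable names are hypothetical, and default
// StaticModuleOptions are assumed):
//
//   torch::jit::Module module = torch::jit::load("model.pt");
//   module.eval();
//   torch::jit::Module frozen = torch::jit::freeze_module(module);
//   auto graph = frozen.get_method("forward").graph();
//   if (torch::jit::canEnableStaticRuntime(graph)) {
//     torch::jit::StaticModule smod(frozen, /*is_frozen=*/true);
//     // smod(args, kwargs) now runs the model through static runtime.
//   }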
133 
134 namespace {
135 
136 auto sr_metadata_registerer = torch::class_<StaticRuntimeMetadata>(
137     "StaticRuntime",
138     "StaticRuntimeMetadata");
139 
140 } // namespace
141 
142 std::string dumpValueSet(
143     const c10::FastSet<const Value*>& value_set,
144     const char* set_name) {
145   std::ostringstream oss;
146   oss << set_name << ": {";
147   for (const auto* val : value_set) {
148     oss << "%" << val->debugName() << ", ";
149   }
150   oss << "}";
151   return oss.str();
152 }
153 
154 namespace {
155 
156 void OptimizeGraph(
157     std::shared_ptr<torch::jit::Graph>& graph,
158     const StaticModuleOptions& opts,
159     std::vector<IValue> sample_inputs) {
160   GRAPH_DUMP("Before optimizations: ", graph);
161   if (opts.enable_tensorexpr_fusion) {
162     if (sample_inputs.empty()) {
163       VLOG(1) << "Cannot perform TensorExpr fusion - sample_inputs is empty";
164     } else {
165       VLOG(1) << "Performing TensorExpr fusion";
166       performTensorExprFusion(graph, std::move(sample_inputs));
167     }
168   }
169   Inline(*graph);
170   ConstantPropagation(graph);
171   Canonicalize(graph);
172   ConstantPropagation(graph);
173   RemoveTensorMutation(graph);
174   ConstantPropagation(graph);
175   EliminateNoOpSlice(graph);
176   EliminateDeadCode(graph);
177   FuseInferenceOpsForSparseNN(graph);
178   UseVariadicCat(graph);
179   UseVariadicStack(graph);
180   EliminateTrivialEquallySplit(graph);
181   EliminateExtraPermuteOps(graph);
182 
183   if (opts.enable_out_variant) {
184     UseVariadicOp(
185         graph,
186         fromQualString("fb::sigrid_transforms_torch_bind"),
187         fromQualString("fb::variadic_sigrid_transforms_torch_bind"));
188     UseVariadicOp(
189         graph,
190         fromQualString("torcharrow::inference_wrapper_run_flat"),
191         fromQualString("torcharrow::variadic_inference_wrapper_run_flat"));
192     // These fused ops only have out variants - we can't do the fusion when
193     // out variants are disabled.
194     FuseSignLog1P(graph);
195     FuseClampNaNToNum(graph);
196 
197 #ifdef FBCODE_CAFFE2
198     if (opts.use_copy_variants && !opts.enable_tensorexpr_fusion) {
199       ReplaceWithCopy(graph);
200     } else {
201       ReplacePermuteWithCopy(graph);
202     }
203     if (opts.use_maybe_copy_variants && !opts.enable_tensorexpr_fusion) {
204       ReplaceWithMaybeCopy(graph);
205     }
206     FuseListUnpack(graph);
207     RemoveUnnecessaryOutputs(graph);
208     PrepackWeights(graph);
209 #endif
210   }
211 
212   ConstantPropagation(graph);
213   RemoveImmutableInputDictLookups(graph);
214   UseVariadicTupleUnpack(graph);
215   UseVariadicGroupedAccessor(graph);
216   EliminateNoOps(
217       graph, /* custom_ops */ {fromQualString("fb::scale_gradient")});
218   AddIfThenElseOp(graph);
219   UseSplitAndSqueeze(graph);
220   UseInPlaceGetRealInputsFromOptionalInputsV2(graph);
221   GRAPH_DUMP("Final graph after optimizations: ", graph);
222 }
223 
224 bool IsSelfInGraphInput(std::shared_ptr<torch::jit::Graph>& graph) {
225   return !graph->inputs().empty() && graph->inputs().at(0)->type()->is_module();
226 }
227 
228 // Remove the unused module `self` input (input 0) from the graph.
229 bool removeSelfFromGraphInput(std::shared_ptr<torch::jit::Graph>& graph) {
230   if (graph->inputs().at(0)->type()->is_module()) {
231     if (graph->inputs().at(0)->hasUses()) {
232       return false;
233     }
234     graph->eraseInput(0);
235   }
236   return true;
237 }
238 
239 std::vector<Value*> valueVecFromFastSet(const c10::FastSet<const Value*>& s) {
240   std::vector<Value*> result;
241   result.reserve(s.size());
242   for (auto* v : s) {
243     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
244     result.emplace_back(const_cast<Value*>(v));
245   }
246   return result;
247 }
248 
249 bool mayContainAlias(const AliasDb& db, const Value* v1, const Value* v2) {
250   // AliasDb is not const-correct here, so we have to const_cast
251   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
252   return db.mayContainAlias(const_cast<Value*>(v1), const_cast<Value*>(v2));
253 }
254 
255 bool mayContainAlias(
256     const AliasDb& db,
257     const Value* a,
258     const c10::FastSet<const Value*>& b) {
259   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
260   return db.mayContainAlias(const_cast<Value*>(a), valueVecFromFastSet(b));
261 }
262 
263 bool escapesScope(const AliasDb& db, const Value* a) {
264   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
265   return db.escapesScope({const_cast<Value*>(a)});
266 }
267 
268 void PrepareGraphForStaticModule(
269     std::shared_ptr<torch::jit::Graph> graph,
270     const StaticModuleOptions& opts,
271     std::vector<IValue> sample_inputs) {
272   TORCH_CHECK(canEnableStaticRuntime(graph));
273   OptimizeGraph(graph, opts, std::move(sample_inputs));
274 
275   // Static runtime moves its outputs out of the runtime
276   // by default. In some rare cases, this is not actually safe to
277   // do - for example, if the value is a constant, static runtime
278   // needs to hold onto a copy. Rather than adding special logic
279   // to handle this rare case, we use this pass to detect it and
280   // create an owned reference that can be safely moved out of the
281   // runtime.
282   CreateOwnedRefsForSpecialValues(*graph);
283 
284   // We assume that each sub-block has at least one output. If we
285   // detect any that have 0, force the sub-block to return None.
286   ForceNonEmptyOutputs(*graph);
287 }
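
// Illustrative case for CreateOwnedRefsForSpecialValues (hypothetical
// TorchScript):
//
//   def forward(self, x):
//       return 42   # the returned value comes straight from a prim::Constant
//
// Moving such an output out of the runtime would invalidate the stored copy
// that subsequent runs rely on, so the pass gives the output an owned
// reference that can be moved out safely.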
288 
289 std::pair<std::shared_ptr<Graph>, std::optional<Module>> PrepareForStaticModule(
290     const torch::jit::Module& m,
291     bool is_frozen,
292     const StaticModuleOptions& opts,
293     std::vector<IValue> sample_inputs) {
294   LOG(INFO) << "StaticModuleOptions: enable_out_variant "
295             << opts.enable_out_variant << ", optimize_memory "
296             << opts.optimize_memory << ", manage_output_tensors "
297             << opts.manage_output_tensors << ", use_copy_variants "
298             << opts.use_copy_variants << ", use_maybe_copy_variants "
299             << opts.use_maybe_copy_variants << ", enable_tensorexpr_fusion "
300             << opts.enable_tensorexpr_fusion;
301 
302   Module module = m.copy();
303   if (!is_frozen) {
304     module.eval();
305     module = freeze_module(module);
306   }
307 
308   Method method = module.get_method("forward");
309   auto graph = module.get_method("forward").graph();
310 
311   if (!sample_inputs.empty() && IsSelfInGraphInput(graph)) {
312     sample_inputs.insert(sample_inputs.begin(), m._ivalue());
313   }
314   PrepareGraphForStaticModule(graph, opts, std::move(sample_inputs));
315 
316   return std::make_pair(graph, module);
317 }
318 
319 std::pair<std::shared_ptr<Graph>, std::optional<Module>> PrepareForStaticModule(
320     const std::shared_ptr<torch::jit::Graph>& graph,
321     const StaticModuleOptions& opts,
322     std::vector<IValue> sample_inputs) {
323   PrepareGraphForStaticModule(graph, opts, std::move(sample_inputs));
324   return std::make_pair(graph, std::nullopt);
325 }
326 
327 } // namespace
328 
329 void ValueGroup::init(const Block& block, const AliasDb& db) {
330   external_aliases_.clear();
331   output_aliases_.clear();
332   // Build `external_aliases` by walking the nodes forward from the block's
333   // inputs (and constants) and adding any node outputs that escape the
334   // scope or may alias values already in the set.
335   external_aliases_.insert(block.inputs().begin(), block.inputs().end());
336   for (const auto* node : block.nodes()) {
337     if (node->kind() == prim::Constant) {
338       for (const auto* output : node->outputs()) {
339         external_aliases_.insert(output);
340       }
341     }
342   }
343   for (const auto* node : block.nodes()) {
344     if (node->kind() == prim::Constant) {
345       // Constants are already in `external_aliases`.
346       continue;
347     }
348     for (const auto* v : node->outputs()) {
349       if (escapesScope(db, v) || mayContainAlias(db, v, external_aliases_)) {
350         external_aliases_.insert(v);
351       }
352     }
353   }
354 
355   // Build `output_aliases` by walking the nodes in reverse, starting from
356   // the block's outputs and following the data flow backward from there.
357   output_aliases_.insert(block.outputs().begin(), block.outputs().end());
358   for (const auto* node : block.nodes().reverse()) {
359     if (node->kind() == prim::Constant) {
360       // Constants cannot create any aliases.
361       continue;
362     }
363     for (const auto* v : node->outputs()) {
364       if (mayContainAlias(db, v, output_aliases_)) {
365         output_aliases_.insert(v);
366       }
367     }
368   }
369 }
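
// Worked example for ValueGroup::init (illustrative IR):
//
//   %y = aten::add(%x, %x)
//   %z = aten::slice(%y, ...)
//   return (%z)
//
// external_aliases_ starts as {%x} plus every constant; %y and %z neither
// escape the scope nor may alias anything in that set, so they stay out of it.
// output_aliases_ starts as {%z} and, walking the nodes in reverse, also picks
// up %y because %z may alias %y's storage.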
370 
371 namespace {
372 
373 bool isTensorList(const Value* value) {
374   auto* type = value->type()->castRaw<ListType>();
375   if (!type) {
376     return false;
377   }
378   return type->getElementType()->kind() == c10::TypeKind::TensorType;
379 }
380 
381 bool containTensorsOnly(at::ArrayRef<Value*> values) {
382   // return true only if all outputs are tensors
383   return std::all_of(values.begin(), values.end(), [](const Value* value) {
384     return value->type()->kind() == c10::TypeKind::TensorType ||
385         isTensorList(value);
386   });
387 }
388 
389 bool isPureFunction(const Node* node) {
390   auto* schema = node->maybeSchema();
391   return schema &&
392       schema->aliasAnalysis() == c10::AliasAnalysisKind::PURE_FUNCTION;
393 }
394 
395 } // namespace
396 
397 ManagedTensorRanges::ManagedTensorRanges(
398     Block& block,
399     const AliasDb& alias_db,
400     const c10::FastSet<const Value*>& managed_tensor_values) {
401   const std::vector<Node*> nodes(block.nodes().begin(), block.nodes().end());
402   const c10::FastSet<const Value*> graph_inputs(
403       block.inputs().begin(), block.inputs().end());
404 
405   const auto num_nodes = static_cast<uint32_t>(nodes.size());
406   for (const auto i : c10::irange(num_nodes)) {
407     auto* node = nodes[i];
408     for (auto* input : node->inputs()) {
409       auto* lifetime = getLifetime(input);
410       if (!lifetime) {
411         continue;
412       }
413       DCHECK(lifetime->end <= i);
414       lifetime->end = i;
415     }
416     for (auto* output : node->outputs()) {
417       if (!alias_db.isMutableType(output)) {
418         continue;
419       }
420       value_lifetimes_.emplace(output, Lifetime(i, i));
421     }
422   }
423   for (auto* graph_output : block.outputs()) {
424     auto* lifetime = getLifetime(graph_output);
425     if (!lifetime) {
426       continue;
427     }
428     lifetime->end = num_nodes;
429   }
430 
431   // Handle aliases. Aliases may extend a Value*'s lifetime. If a node
432   // has an input and output that may alias each other, set the input's
433   // lifetime end to max(input.lifetime_end, output.lifetime_end). Iterate
434   // backwards to handle chains of aliases.
435   for (const auto* node : block.nodes().reverse()) {
436     if (isPureFunction(node)) {
437       // If the node is a pure function, it doesn't create any aliases,
438       // so we can safely skip it.
439       continue;
440     }
441 
442     auto inputs = collectValuesWithTrackedLifetimes(node->inputs());
443     auto outputs = collectValuesWithTrackedLifetimes(node->outputs());
444     for (auto* input : inputs) {
445       auto* input_lifetime = getLifetime(input);
446       DCHECK(input_lifetime != nullptr);
447       for (auto* output : outputs) {
448         if (mayContainAlias(alias_db, input, output)) {
449           auto* output_lifetime = getLifetime(output);
450           DCHECK(output_lifetime != nullptr);
451           input_lifetime->end =
452               std::max(output_lifetime->end, input_lifetime->end);
453         }
454       }
455     }
456   }
457   for (auto* managed_tensor : managed_tensor_values) {
458     auto* lifetime = getLifetime(managed_tensor);
459     DCHECK(lifetime && lifetime->end <= num_nodes);
460     Node* freeing_node = nullptr;
461     if (lifetime->end == num_nodes) {
462       freeing_node = block.return_node();
463     } else {
464       freeing_node = nodes[lifetime->end];
465     }
466     node_to_newly_free_tensors_[freeing_node].emplace_back(managed_tensor);
467   }
468 }
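
// Worked example for ManagedTensorRanges (illustrative IR, nodes indexed 0-2):
//
//   %a = aten::add(%x, %x)    # %a gets Lifetime(0, 0)
//   %b = aten::mul(%a, %a)    # %a's end extends to 1, %b gets Lifetime(1, 1)
//   %c = aten::relu(%b)       # %b's end extends to 2, %c gets Lifetime(2, 2)
//   return (%c)               # %c's end extends to num_nodes
//
// If %a and %b are managed tensors, %a becomes newly free after the mul node
// (its last use) and %b after the relu node. Their lifetimes [0, 1] and [1, 2]
// meet at node 1, so lifetimesOverlap(%a, %b) is true.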
469 
470 bool ManagedTensorRanges::nodeFreesManagedTensors(Node* node) const {
471   auto it = node_to_newly_free_tensors_.find(node);
472   return it != node_to_newly_free_tensors_.end() && !it->second.empty();
473 }
474 
475 const std::vector<const Value*>& ManagedTensorRanges::
476     availableTensorValuesAfterNode(Node* node) const {
477   return node_to_newly_free_tensors_.at(node);
478 }
479 
480 bool ManagedTensorRanges::lifetimesOverlap(const Value* v1, const Value* v2)
481     const {
482   const auto* v1_lifetime = getLifetime(v1);
483   const auto* v2_lifetime = getLifetime(v2);
484   if (!v1_lifetime || !v2_lifetime) {
485     return false;
486   }
487 
488   if (v1_lifetime->start < v2_lifetime->start) {
489     return v1_lifetime->end >= v2_lifetime->start;
490   }
491   return v2_lifetime->end >= v1_lifetime->start;
492 }
493 
494 const ManagedTensorRanges::Lifetime* ManagedTensorRanges::getLifetime(
495     const Value* value) const {
496   auto it = value_lifetimes_.find(value);
497   if (it != value_lifetimes_.end()) {
498     return &it->second;
499   }
500   return nullptr;
501 }
502 
503 ManagedTensorRanges::Lifetime* ManagedTensorRanges::getLifetime(
504     const Value* value) {
505   // const_cast is safe here, this is just a way to avoid code duplication
506   // between the const/non-const versions of getLifetime.
507 
508   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
509   const auto* const_this = const_cast<const ManagedTensorRanges*>(this);
510 
511   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
512   return const_cast<ManagedTensorRanges::Lifetime*>(
513       const_this->getLifetime(value));
514 }
515 
516 std::vector<const Value*> ManagedTensorRanges::
517     collectValuesWithTrackedLifetimes(at::ArrayRef<const Value*> values) {
518   std::vector<const Value*> mutable_values;
519   mutable_values.reserve(values.size());
520   std::copy_if(
521       values.begin(),
522       values.end(),
523       std::back_inserter(mutable_values),
524       [this](const Value* value) { return getLifetime(value) != nullptr; });
525   return mutable_values;
526 }
527 
528 StaticModule::StaticModule(
529     const std::shared_ptr<torch::jit::Graph>& g,
530     const StaticModuleOptions& opts,
531     std::vector<IValue> sample_inputs)
532     : StaticModule(
533           PrepareForStaticModule(g->copy(), opts, std::move(sample_inputs)),
534           opts) {}
535 
536 StaticModule::StaticModule(
537     const torch::jit::Module& m,
538     bool is_frozen,
539     const StaticModuleOptions& opts,
540     std::vector<IValue> sample_inputs)
541     : StaticModule(
542           PrepareForStaticModule(m, is_frozen, opts, std::move(sample_inputs)),
543           opts) {}
544 
545 StaticModule::StaticModule(
546     std::pair<std::shared_ptr<torch::jit::Graph>, std::optional<Module>>
547         graph_and_module,
548     const StaticModuleOptions& opts)
549     : opts_(opts),
550       graph_(std::move(graph_and_module.first)),
551       module_(std::move(graph_and_module.second)),
552       num_inputs_(graph_->inputs().size()) {
553   sr_metadata_ = c10::make_intrusive<jit::StaticRuntimeMetadata>(opts_);
554   // recursively attach metadata to prim::fork nodes
555   attachNodeMetadata(graph_->block());
556 
557   // check opt flags
558   if (opts.manage_output_tensors) {
559     TORCH_CHECK(
560         opts_.enable_out_variant,
561         "When manage_output_tensors is true, enable_out_variant must be set to true");
562   }
563   if (opts_.optimize_memory) {
564     TORCH_CHECK(
565         opts_.enable_out_variant,
566         "When optimize_memory is true, enable_out_variant must be set to true");
567   }
568 
569   // handle schema
570   if (module_.has_value()) {
571     Method method = module_->get_method("forward");
572     schema_ = method.function().getSchema();
573     const auto num_schema_args = schema_->arguments().size();
574     DCHECK(num_schema_args > 0);
575     if (removeSelfFromGraphInput(graph_)) {
576       module_ = std::nullopt;
577       num_inputs_ = num_schema_args - 1;
578     }
579   }
580 
581   {
582     size_t nodes_size = 0, constants_size = 0;
583     for (Node* node : graph_->nodes()) {
584       ++(node->kind() == prim::Constant ? constants_size : nodes_size);
585     }
586 
587     constants_.reserve(constants_size);
588     functions_.reserve(nodes_size);
589   }
590 
591   // Create the ProcessedFunction instances first so that their addresses are
592   // stable by the time they are passed to ProcessedNode.
593   AliasDb alias_db(graph_, /*isFrozen=*/false);
594   GRAPH_DEBUG("AliasDb: ", alias_db.toString());
595 
596   // Maps each Value* in the graph to its index in the values_ array that will
597   // eventually be created by StaticRuntime.
598   c10::FastMap<const Value*, uint32_t> value_to_index;
599   prepareFunctionsAndConstants(graph_->block(), alias_db, value_to_index);
600 
601   const auto constants_index_offset = 0;
602   const auto values_index_offset = constants_index_offset + constants().size();
603   value_buffer_size_ = values_index_offset;
604 
605   value_buffer_size_ +=
606       prepareBlockInfo(graph_->block(), values_index_offset, value_to_index);
607 
608   prepareStaticNodeInfos(graph_->block(), value_to_index, alias_db);
609 
610   for (auto& block_and_info : block_infos_) {
611     auto& block_info = block_and_info.second;
612     block_info.prepare_for_memory_planner(alias_db, opts);
613   }
614 }
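
// Layout of the values table assembled above (sketch):
//
//   [0, constants().size())      constants, indices assigned by
//                                prepareFunctionsAndConstants()
//   [constants().size(), end)    block inputs followed by node outputs,
//                                indices assigned per block by prepareBlockInfo()
//
// value_buffer_size_ is the total number of IValue slots that a StaticRuntime
// built from this module will allocate.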
615 
616 size_t StaticModule::prepareBlockInfo(
617     Block* block,
618     const size_t start_idx,
619     c10::FastMap<const Value*, uint32_t>& value_to_index) {
620   block_infos_.emplace(block, BlockInfo(start_idx, *block));
621 
622   const auto num_inputs = static_cast<uint32_t>(block->inputs().size());
623   for (const auto i : c10::irange(num_inputs)) {
624     value_to_index.emplace(block->inputs()[i], start_idx + i);
625   }
626   auto cur_idx = start_idx + num_inputs;
627 
628   for (auto* node : block->nodes()) {
629     for (auto* sub_block : node->blocks()) {
630       cur_idx += prepareBlockInfo(sub_block, cur_idx, value_to_index);
631     }
632 
633     if (node->kind() == prim::Constant) {
634       continue;
635     }
636 
637     TORCH_CHECK(
638         cur_idx < (1 << 16),
639         "outputs offset in values table",
640         cur_idx,
641         " would overflow 2-byte index storage");
642 
643     const auto num_outputs = static_cast<uint32_t>(node->outputs().size());
644     for (const auto i : c10::irange(num_outputs)) {
645       value_to_index.emplace(node->outputs()[i], cur_idx + i);
646     }
647     cur_idx += num_outputs;
648   }
649 
650   std::vector<uint16_t> output_indices;
651   output_indices.reserve(block->outputs().size());
652   for (auto* output : block->outputs()) {
653     const auto output_idx = value_to_index.at(output);
654     TORCH_CHECK(
655         output_idx < (1 << 16),
656         "outputs offset in values table",
657         output_idx,
658         " would overflow 2-byte index storage");
659     output_indices.push_back(output_idx);
660   }
661 
662   block_infos_.at(block).set_output_indices(std::move(output_indices));
663   return cur_idx - start_idx;
664 }
665 
666 void StaticModule::attachNodeMetadata(Block* block) {
667   for (auto* node : block->nodes()) {
668     if (node->kind() == prim::fork) {
669       node->ival_(getStaticRuntimeMetadataSymbol(), IValue(sr_metadata_));
670     }
671     for (auto* sub_block : node->blocks()) {
672       attachNodeMetadata(sub_block);
673     }
674   }
675 }
676 
677 void StaticModule::prepareFunctionsAndConstants(
678     Block* block,
679     const AliasDb& alias_db,
680     c10::FastMap<const Value*, uint32_t>& value_to_index) {
681   for (auto* node : block->nodes()) {
682     for (auto* sub_block : node->blocks()) {
683       prepareFunctionsAndConstants(sub_block, alias_db, value_to_index);
684     }
685 
686     if (node->kind() == prim::Constant) {
687       auto* v = node->output();
688       TORCH_CHECK(
689           v->type()->kind() != FunctionType::Kind,
690           "got ",
691           typeKindToString(v->type()->kind()),
692           " instead of ",
693           typeKindToString(FunctionType::Kind));
694       value_to_index.emplace(v, constants_.size());
695       constants_.emplace_back(toIValue(v).value());
696       continue;
697     }
698 
699     // see [Check and correct bad schema alias info at runtime]
700     bool check_outputs_for_overlap =
701         !alias_db.mayContainAlias(node->inputs(), node->outputs()) &&
702         containTensorsOnly(node->outputs());
703     // new ProcessedFunction
704     functions_.emplace_back(
705         node, opts_.enable_out_variant, check_outputs_for_overlap);
706   }
707 }
708 
709 size_t StaticModule::prepareStaticNodeInfos(
710     Block* block,
711     const c10::FastMap<const Value*, uint32_t>& value_to_index,
712     const AliasDb& alias_db,
713     size_t node_idx) {
714   const auto node_start = node_idx;
715 
716   auto& block_info = block_infos_.at(block);
717   std::vector<StaticNodeInfo> nodes;
718   c10::FastMap<Node*, bool> node_has_out_variant;
719 
720   for (auto* node : block->nodes()) {
721     if (node->kind() == prim::Constant) {
722       continue;
723     }
724 
725     for (auto* sub_block : node->blocks()) {
726       node_idx +=
727           prepareStaticNodeInfos(sub_block, value_to_index, alias_db, node_idx);
728     }
729     const auto num_inputs = static_cast<uint32_t>(node->inputs().size());
730     ProcessedNodeInputs input_indices(num_inputs);
731     for (const auto input_idx : c10::irange<uint32_t>(num_inputs)) {
732       auto* input = node->inputs()[input_idx];
733       auto input_ivalue_idx = value_to_index.at(input);
734       TORCH_CHECK(
735           input_ivalue_idx < (1 << 16),
736           "input index in values table ",
737           input_ivalue_idx,
738           " would overflow 2-byte index storage");
739       input_indices[input_idx] = input_ivalue_idx;
740     }
741 
742     ProcessedFunction* fn = &functions_[node_idx];
743 
744     // create a new ProcessedNode
745     const auto node_output_idx = node->outputs().empty()
746         // The index is unused if there are no outputs, so just create a
747         // placeholder value.
748         ? std::numeric_limits<uint16_t>::max()
749         : value_to_index.at(node->output(0));
750     nodes.emplace_back(node, fn, std::move(input_indices), node_output_idx);
751 
752     node_has_out_variant.emplace(node, nodes.back().has_out_variant());
753     ++node_idx;
754   }
755 
756   block_info.set_nodes(std::move(nodes), node_has_out_variant);
757   block_info.init_value_group(alias_db);
758 
759   return node_idx - node_start;
760 }
761 
762 #ifdef FBCODE_CAFFE2
763 thread_local SROperatorObserver* tlsOpObserver = nullptr;
764 
765 void SROperatorObserver::setCurrentThreadObserver(
766     SROperatorObserver* observer) {
767   tlsOpObserver = observer;
768 }
769 
770 SROperatorObserver* SROperatorObserver::getCurrentThreadObserver() {
771   return tlsOpObserver;
772 }
773 
774 void SROperatorObserver::onStart(const Node* node) {
775   if (tlsOpObserver != nullptr && tlsOpObserver->startCb != nullptr) {
776     tlsOpObserver->startCb(node);
777   }
778 }
779 
780 void SROperatorObserver::onEnd(const Node* node) {
781   if (tlsOpObserver != nullptr && tlsOpObserver->endCb != nullptr) {
782     tlsOpObserver->endCb(node);
783   }
784 }
785 #endif // FBCODE_CAFFE2
786 
787 BlockInfo::BlockInfo(uint32_t input_idx, Block& block)
788     : input_idx_(input_idx), block_(block) {}
789 
790 void BlockInfo::set_nodes(
791     std::vector<StaticNodeInfo> nodes,
792     const c10::FastMap<Node*, bool>& node_has_out_variant) {
793   nodes_ = std::move(nodes);
794 
795   for (auto& node : nodes_) {
796     if (node.num_outputs() == 1 &&
797         isOptimizableContainerType(node.node(), node_has_out_variant)) {
798       node_is_optimizable_container_type_.emplace(node.node());
799     }
800   }
801 }
802 void BlockInfo::prepare_for_memory_planner(
803     const AliasDb& alias_db,
804     const StaticModuleOptions& opts) {
805   if (!opts.enable_out_variant) {
806     return;
807   }
808 
809   // Never manage graph outputs so that we can do std::move(output_ivalue).
810   // This does not affect performance if the graph returns a collection object.
811   c10::FastSet<const Value*> graph_output_values(
812       block_.outputs().begin(), block_.outputs().end());
813 
814   // collect register indices of outputs of ops with out variant
815   for (StaticNodeInfo& pnode : nodes_) {
816     if (!pnode.has_out_variant()) {
817       continue;
818     }
819     auto outputs = pnode.node()->outputs();
820     const auto num_outputs = static_cast<uint32_t>(outputs.size());
821     for (const auto i : c10::irange(num_outputs)) {
822       const Value* out_v = outputs[i];
823       // Types are stored in the underlying TorchScript IR
824       bool is_tensor_type = out_v->type()->castRaw<TensorType>();
825       if (opts.manage_output_tensors && is_tensor_type &&
826           graph_output_values.find(out_v) == graph_output_values.end() &&
827           value_group_.isOutputAlias(out_v)) {
828         managed_output_tensor_values_.insert(out_v);
829         continue;
830       }
831       if (value_group_.isAlwaysAlive(out_v)) {
832         continue;
833       }
834       if (is_tensor_type) {
835         managed_tensor_values_.insert(out_v);
836       } else if (node_is_optimizable_container_type(pnode.node())) {
837         // We "leak" certain container types because their allocations
838         // take a long time
839         leaked_values_.insert(out_v);
840       }
841     }
842   }
843 
844   for (const Value* output : block_.outputs()) {
845     managed_tensor_values_.erase(output);
846   }
847   GRAPH_DEBUG("managed_tensor_values: ", dumpValueSet(managed_tensor_values_));
848   GRAPH_DEBUG(
849       "managed_output_tensor_values_: ",
850       dumpValueSet(managed_output_tensor_values_));
851 
852   managed_tensor_ranges_ =
853       ManagedTensorRanges(block_, alias_db, managed_tensor_values_);
854 }
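
// Summary of the classification performed above for each output of a node
// with an out variant:
//   - Tensor that is not a graph output, is an output alias, and
//     manage_output_tensors is enabled   -> managed_output_tensor_values_
//   - value that is always alive         -> left unmanaged
//   - any other Tensor                   -> managed_tensor_values_
//   - optimizable container type         -> leaked_values_
// Graph outputs are then erased from managed_tensor_values_.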
855 
856 const StaticModuleOptions& StaticModule::opts() const {
857   return opts_;
858 }
859 
860 size_t StaticModule::num_outputs() const {
861   return graph_->outputs().size();
862 }
863 
864 size_t StaticModule::num_inputs() const {
865   return num_inputs_;
866 }
867 
868 StaticRuntime& StaticModule::runtime() {
869   if (!cached_runtime_) {
870     cached_runtime_ = std::make_unique<StaticRuntime>(*this);
871   }
872   return *cached_runtime_;
873 }
874 
875 Node* StaticModule::findNodeWithKindForTesting(const std::string& kind) const {
876   for (auto& block_and_info : block_infos_) {
877     auto& block_info = block_and_info.second;
878     for (auto& pnode : block_info.nodes()) {
879       if (pnode.node()->kind().toQualString() == kind) {
880         return pnode.node();
881       }
882     }
883   }
884   return nullptr;
885 }
886 
887 c10::IValue StaticModule::operator()(
888     const std::vector<c10::IValue>& args,
889     const KeywordArgs& kwargs) {
890   return runtime()(args, kwargs);
891 }
892 
893 c10::IValue StaticModule::operator()(
894     std::vector<c10::IValue>&& args,
895     const KeywordArgs& kwargs) {
896   return runtime()(std::move(args), kwargs);
897 }
898 
899 BlockRunner::BlockRunner(
900     const StaticModule& sm,
901     IValue* values,
902     Block* block,
903     torch::jit::TaskLauncher* launcher,
904     bool is_root_block)
905     : static_module_(sm),
906       block_info_(static_module_.block_info(block)),
907       is_root_block_(is_root_block),
908       first_input_is_self_(
909           is_root_block_ && static_module_.first_input_is_self()),
910       inputs_begin_(block_info_.block_inputs_idx()),
911       // TODO(T108633124): Turn on manage output tensors for sub-blocks.
912       manage_output_tensors_enabled_(
913           is_root_block_ && sm.opts().manage_output_tensors),
914       values_(values) {
915   nodes_.reserve(block_info_.nodes().size());
916   for (auto& pre_pnode : block_info_.nodes()) {
917     nodes_.emplace_back(pre_pnode, values_);
918   }
919 
920   for (auto index : block_info_.block_output_indices()) {
921     outputs_.emplace_back(&values_[index]);
922   }
923 
924   for (auto& pnode : nodes_) {
925     auto* node = pnode.node();
926 
927     // attach the async taskLauncher to processedNodes
928     pnode.set_metadata(launcher);
929     auto blocks = node->blocks();
930     const auto num_blocks = blocks.size();
931     if (num_blocks == 0) {
932       continue;
933     }
934     DCHECK(node->kind() == prim::If || node->kind() == prim::Loop);
935     std::vector<BlockRunner> block_runners;
936     block_runners.reserve(num_blocks);
937 
938     for (auto* b : blocks) {
939       block_runners.emplace_back(sm, values_, b, launcher);
940     }
941     pnode.set_metadata(std::move(block_runners));
942   }
943 }
944 
945 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
946 BlockRunner::BlockRunner(BlockRunner&&) noexcept = default;
947 
948 BlockRunner::~BlockRunner() = default;
949 
950 void BlockRunner::set_arg(const size_t idx, std::vector<IValue>&& args) {
951   DCHECK(idx < args.size());
952   Input(idx + first_input_is_self_) = std::move(args[idx]);
953 }
954 
955 void BlockRunner::set_arg(const size_t idx, const std::vector<IValue>& args) {
956   DCHECK(idx < args.size());
957   Input(idx + first_input_is_self_) = args[idx];
958 }
959 
960 void BlockRunner::set_arg(const size_t idx, const IValue& arg) {
961   Input(idx + first_input_is_self_) = arg;
962 }
963 
964 namespace {
965 void check_type(const Argument& schema_arg, const IValue& arg) {
966   // Fast path for most common case
967   if (arg.isTensor() &&
968       schema_arg.type()->kind() == c10::TypeKind::TensorType) {
969     return;
970   }
971   TORCH_CHECK(
972       arg.type()->isSubtypeOf(schema_arg.type()),
973       arg.type()->annotation_str(),
974       " is not a subtype of ",
975       schema_arg.type()->annotation_str(),
976       "; schema arg name: '",
977       schema_arg.name(),
978       "', ivalue: ",
979       iValueToString(arg));
980 }
981 } // namespace
982 
983 template <typename IValueList>
984 void BlockRunner::set_inputs(IValueList&& args, const KeywordArgs& kwargs) {
985   const auto& schema = static_module_.schema();
986   if (first_input_is_self_) {
987     Input(0) = static_module_.module()._ivalue();
988   }
989 
990   if (!is_root_block_ || C10_UNLIKELY(!schema)) {
991     TORCH_CHECK(
992         kwargs.empty(),
993         "BlockRunner got kwargs; is_root_block: ",
994         std::to_string(is_root_block_),
995         "schema: ",
996         schema ? schema->name() : "(not available)");
997 
998     const auto total_num_inputs = args.size() + first_input_is_self_;
999     TORCH_CHECK(
1000         total_num_inputs == block_info_.num_inputs(),
1001         "Block runner got ",
1002         std::to_string(total_num_inputs),
1003         " inputs; ",
1004         " first_input_is_self: ",
1005         std::to_string(first_input_is_self_),
1006         "; SR block expects ",
1007         std::to_string(block_info_.num_inputs()),
1008         " inputs for schema ",
1009         schema ? schema->name() : "(not available)");
1010 
1011     for (const auto i_arg : c10::irange(args.size())) {
1012       set_arg(i_arg, std::forward<IValueList>(args));
1013     }
1014     return;
1015   }
1016 
1017   const auto& schema_args = schema->arguments();
1018   size_t consumed_kwargs = 0;
1019   DCHECK(!schema_args.empty());
1020   TORCH_CHECK(
1021       args.size() < schema_args.size(),
1022       "Static runtime got ",
1023       std::to_string(args.size()),
1024       " arguments, expects ",
1025       std::to_string(schema_args.size() - 1),
1026       " for schema ",
1027       schema->name());
1028 
1029   for (const auto i_arg : c10::irange(1, schema_args.size())) {
1030     // Start at 1 since the schema always contains `self`.
1031     const auto& schema_arg = schema_args[i_arg];
1032 
1033     if (i_arg - 1 < args.size()) {
1034       check_type(schema_arg, std::forward<IValueList>(args)[i_arg - 1]);
1035       set_arg(i_arg - 1, std::forward<IValueList>(args));
1036       continue;
1037     }
1038 
1039     auto it = kwargs.find(schema_arg.name());
1040     if (it != kwargs.end()) {
1041       check_type(schema_arg, it->second);
1042       set_arg(i_arg - 1, it->second);
1043       ++consumed_kwargs;
1044       continue;
1045     }
1046 
1047     auto maybe_default_val = schema_arg.default_value();
1048     if (maybe_default_val) {
1049       set_arg(i_arg - 1, *maybe_default_val);
1050       continue;
1051     }
1052 
1053     TORCH_CHECK(
1054         false,
1055         "Static runtime is missing required kwarg ",
1056         schema_arg.name(),
1057         " i_arg: ",
1058         std::to_string(i_arg),
1059         " for schema ",
1060         schema->name());
1061   }
1062   TORCH_CHECK(
1063       consumed_kwargs == kwargs.size(),
1064       "kwargs size mismatch (consumed ",
1065       std::to_string(consumed_kwargs),
1066       ", expected ",
1067       std::to_string(kwargs.size()),
1068       " for schema ",
1069       schema->name());
1070 }
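
// Illustrative mapping for set_inputs (hypothetical schema):
//
//   forward(self, x, y=1)
//
// Called with args = {x_val} and kwargs = {{"y", 2}}, the loop above skips the
// `self` slot, fills the first graph input from args (x_val), fills `y` from
// kwargs (2), and would fall back to the schema default (1) if the kwarg were
// absent. Any kwarg that is never consumed trips the final size check.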
1071 
1072 void BlockRunner::create_memory_planner() {
1073   if (!planner_) {
1074     planner_ = std::make_unique<StandardMemoryPlanner>(
1075         this,
1076         block_info_,
1077         static_module_.opts().enable_out_variant,
1078         manage_output_tensors_enabled_,
1079         static_module_.opts().optimize_memory);
1080   }
1081 }
1082 
1083 namespace {
1084 
1085 void destroyNodeOutputs(ProcessedNode& p_node) {
1086   const auto borrows_outputs = borrowsOutputs(p_node.node()->kind());
1087   const auto num_outputs = static_cast<uint32_t>(p_node.num_outputs());
1088   for (const auto i : c10::irange<uint32_t>(num_outputs)) {
1089     auto& output = p_node.Output(i);
1090     if (doesNotHeapAllocateWhenStoredInIValue(*output.type())) {
1091       continue;
1092     }
1093 
1094     if (borrows_outputs) {
1095       // NB: No need to incref here. This codepath is only hit if the run didn't
1096       // finish, so we shouldn't be returning anything to the client.
1097       c10::MaybeOwnedTraits<IValue>::destroyBorrow(output);
1098     } else {
1099       output = IValue();
1100     }
1101   }
1102 }
1103 
1104 } // namespace
1105 
1106 void BlockRunner::clean_up_intermediate_ivalues() noexcept {
1107   // We have to iterate in reverse order here due to borrowed
1108   // IValues - we don't want to destroy a value until all of its
1109   // borrows are cleaned up!
1110   for (auto it = nodes_.rbegin(); it != nodes_.rend(); ++it) {
1111     destroyNodeOutputs(*it);
1112   }
1113 }
1114 
1115 void BlockRunner::resetMemory() noexcept {
1116   planner_.reset();
1117   // We must clean up intermediate values before inputs in case
1118   // there are borrowed inputs and static runtime owns the only
1119   // reference (e.g. the inputs were std::move'd into the runtime)
1120   clean_up_intermediate_ivalues();
1121   clean_up_input_ivalues();
1122 }
1123 
1124 c10::IValue BlockRunner::move_outputs_to_tuple(uint32_t num_outputs) {
1125   switch (num_outputs) {
1126     case 1:
1127       return c10::ivalue::Tuple::create(IValue(std::move(*outputs_[0])));
1128     case 2:
1129       return c10::ivalue::Tuple::create(
1130           IValue(std::move(*outputs_[0])), IValue(std::move(*outputs_[1])));
1131     case 3:
1132       return c10::ivalue::Tuple::create(
1133           IValue(std::move(*outputs_[0])),
1134           IValue(std::move(*outputs_[1])),
1135           IValue(std::move(*outputs_[2])));
1136     default: {
1137       std::vector<c10::IValue> outputs;
1138       outputs.reserve(num_outputs);
1139       for (const auto i : c10::irange(num_outputs)) {
1140         // use move here. Otherwise, clean up outputs_[i] explicitly
1141         outputs.emplace_back(std::move(*outputs_[i]));
1142       }
1143       return c10::ivalue::Tuple::create(std::move(outputs));
1144     }
1145   }
1146 }
1147 
1148 /// [Check and correct bad schema alias info at runtime]
1149 /// Static runtime relies on the operator schema's alias info to be correct for
1150 /// memory planning. Because it's hard to enforce the alias info to be correct,
1151 /// we need to do runtime detection for accidental aliases that do not comply
1152 /// with the schema. Only aliases of managed tensors are problematic. To avoid
1153 /// runtime crashes, we can add runtime detection and force the op to comply
1154 /// with its schema by cloning the alias. Because all managed tensors' data_ptrs
1155 /// are part of the internal buffer that the MemoryPlanner allocates, we can
1156 /// check aliases by checking the memory overlap with this internal buffer. But
1157 /// a tensor's storage can be resized during inference, so we need another way to
1158 /// handle the resized case.
1159 ///
1160 /// There are two ways for incorrect schema to break memory planning. Let's look
1161 /// at two examples:
1162 ///
1163 /// Example 1:
1164 /// @code
1165 ///   def forward(x):
1166 ///     a = x + x
1167 ///     b = bad_op(a)  # b ends up aliasing a incorrectly
1168 ///     return (b)
1169 /// @endcode
1170 /// bad_op: its schema says it returns a new Tensor, but it actually returns an
1171 /// alias. In this case, the memory planner would recognize `a` as a managed
1172 /// tensor and clean up its memory before returning `b`. But `b` is actually an
1173 /// alias of `a`, when `a`'s data_ptr get reset, `b`'s data_ptr gets reset too.
1174 ///
1175 /// Example 2:
1176 /// @code
1177 ///   def forward(x):
1178 ///     a = x + x
1179 ///     a2 = bad_op(a) # a2 ends up aliasing a incorrectly
1180 ///     b = a + a
1181 ///     c = b * b # c shares storage with a
1182 ///     d = c + 2 # d shares storage with b
1183 ///     e = a2 * a2
1184 ///     return (d, e)
1185 /// @endcode
1186 /// With the memory reuse algorithm, `c` could end up sharing storage with `a`,
1187 /// but because of bad_op, `a2` now aliases `a`. `c` overwrites `a` and
1188 /// therefore `a2`, leading to the wrong results. We solve this problem with two
1189 /// steps. Note this doesn't happen with the current memory reuse algorithm
1190 /// because of the way it's implemented. Things could change with a different
1191 /// implementation.
1192 ///
1193 /// Step 1, annotate the ProcessedNodes with a flag `check_memory_overlap_` set
1194 /// to true if its outputs do not alias its inputs as indicated by the AliasDb
1195 /// and all of its outputs are Tensors. Then at runtime, we check that the
1196 /// nodes' output tensors do not overlap with the internal buffer that the
1197 /// MemoryPlanner allocates. For latency concerns, we only run this check for
1198 /// fallback ops. The schemas of native ops and out variants are vetted and
1199 /// enforced with static runtime unit tests. For the first iteration, we do a
1200 /// full memory overlap check with
1201 /// ProcessedNode::verify_and_correct_memory_overlap() because the internal
1202 /// buffer doesn't exist yet.
1203 ///
1204 /// Step 2, if a managed tensor gets resized during inference, it gets a new
1205 /// data_ptr which is not from the buffer. We can tackle this corner case by
1206 /// delaying the deallocation of the managed tensors to after the outputs are no
1207 /// longer used (essentially merging the internal/output buffers into one).
1208 /// Before the merging is implemented, we add another flag `overlap_detected_`
1209 /// to flag any node with overlap detected in Step 1 and do a full memory
1210 /// overlap check if the fast check (by checking memory overlap with internal
1211 /// buffer) fails. There is still a corner case that fails with the added flag.
1212 /// If a resize is triggered at the same time as the op creates an alias, the
1213 /// current checks would fail to detect the alias.
1214 void BlockRunner::verify_and_correct_memory_overlap(ProcessedNode& n) {
1215   // The slow check can be removed once the internal/output buffers are merged
1216   if (C10_UNLIKELY(n.check_outputs_for_memory_overlap())) {
1217     if (C10_UNLIKELY(!planner_)) {
1218       // slow check, for first iter only
1219       n.verify_and_correct_memory_overlap();
1220     } else {
1221       bool overlap_detected_with_fast_check = false;
1222       const auto n_outputs = static_cast<uint32_t>(n.outputs().size());
1223       for (auto i : c10::irange(n_outputs)) {
1224         auto& output = n.Output(i);
1225         if (output.isTensor()) {
1226           overlap_detected_with_fast_check |=
1227               fast_check_and_correct_overlap_with(n, output);
1228         } else if (output.isTensorList()) {
1229           auto tensor_list = output.toListRef();
1230           for (auto& ival : tensor_list) {
1231             overlap_detected_with_fast_check |=
1232                 fast_check_and_correct_overlap_with(
1233                     n,
1234                     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
1235                     const_cast<c10::IValue&>(ival));
1236           }
1237         }
1238       }
1239       if (n.outputs_memory_overlap_detected() &&
1240           !overlap_detected_with_fast_check) {
1241         // slow check. Only run when the fast check fails.
1242         n.verify_and_correct_memory_overlap();
1243       }
1244     }
1245   }
1246 }
1247 
1248 bool BlockRunner::fast_check_and_correct_overlap_with(
1249     ProcessedNode& n,
1250     c10::IValue& tensor_ival) {
1251   auto& tensor = tensor_ival.toTensor();
1252   if (planner_->overlapWithInternalBuffer(tensor.data_ptr())) {
1253     DLOG(INFO) << "Detected alias for node: " << PrintNode(n.node());
1254     tensor_ival = at::native::clone(tensor, std::nullopt);
1255     n.set_outputs_memory_overlap_detected();
1256     return true;
1257   }
1258   return false;
1259 }
1260 
1261 BlockRunner::Deallocator::~Deallocator() {
1262   // Assume cleanup cannot throw.
1263   cleanupImpl();
1264 #ifndef NDEBUG
1265   block_runner_.check_for_memory_leak(/*output_returned*/ false);
1266 #endif
1267 }
1268 
1269 void BlockRunner::Deallocator::cleanupImpl() {
1270   // MemoryPlanner is created after the first invocation of `run()`. This
1271   // is done intentionally because MemoryPlanner uses `Tensor` sizes of
1272   // the previous `run()` for memory planning of subsequent runs
1273   if (C10_LIKELY(finished_)) {
1274     block_runner_.create_memory_planner();
1275   }
1276 
1277   if (C10_LIKELY(block_runner_.planner_)) {
1278     block_runner_.planner_->deallocate();
1279   } else {
1280     // This is the first run, and it didn't finish, so we can't use a
1281     // `MemoryPlanner` to deallocate stuff. Just reset everything manually.
1282     block_runner_.resetMemory();
1283   }
1284   // clean up owning refs of input tensors
1285   block_runner_.clean_up_input_ivalues();
1286   if (C10_UNLIKELY(!finished_)) {
1287     block_runner_.deallocateOutputTensors();
1288   }
1289 }
1290 
1291 template <typename IValueList>
1292 c10::IValue BlockRunner::run_impl(
1293     IValueList&& args,
1294     const KeywordArgs& kwargs) {
1295   // We assume inference workloads, so we do not need
1296   // autograd. Enabling this is a significant win on dispatcher
1297   // overhead because it saves a round of dispatch for at least some
1298   // functions, such as resize_ and resize_as_.
1299   c10::InferenceMode mode;
1300 
1301   {
1302     auto on_exit = Deallocator(*this);
1303 
1304     if (planner_) {
1305       DCHECK(!manage_output_tensors_enabled_ || checkOutputTensorMemoryLeaks());
1306       planner_->allocate();
1307     }
1308 
1309     set_inputs(std::forward<IValueList>(args), kwargs);
1310 
1311     for (auto& n : nodes_) {
1312       // LOG(INFO) << "Running node: " << PrintNode(n.node());
1313       n.run();
1314       // Check for incorrect schema alias info.
1315       verify_and_correct_memory_overlap(n);
1316     }
1317     on_exit.setFinished();
1318   }
1319 
1320   // no need to keep references of outputs in static runtime anymore
1321   if (block_info_.num_outputs() > 1) {
1322     return move_outputs_to_tuple(block_info_.num_outputs());
1323   }
1324 
1325   DCHECK(check_for_memory_leak(/*output_returned*/ false));
1326 
1327   // use move here. Otherwise, clean up outputs_[0] explicitly
1328   return std::move(*outputs_[0]);
1329 }
1330 
1331 template <typename IValueList>
1332 c10::IValue BlockRunner::run_impl_record_functions(
1333     IValueList&& args,
1334     const KeywordArgs& kwargs) {
1335   auto step_callbacks =
1336       at::getStepCallbacksUnlessEmpty(at::RecordScope::STATIC_RUNTIME_MODEL);
1337   if (C10_UNLIKELY(step_callbacks.has_value())) {
1338     at::RecordFunction guard(std::move(*step_callbacks));
1339     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(guard.isActive());
1340     guard.needsInputs()
1341         ? guard.before(
1342               "forward", c10::ArrayRef<const IValue>(args.data(), args.size()))
1343         : guard.before("forward");
1344 
1345     return run_impl(std::forward<IValueList>(args), kwargs);
1346   }
1347   return run_impl(std::forward<IValueList>(args), kwargs);
1348 }
1349 
1350 template <typename IValueList>
1351 c10::intrusive_ptr<c10::ivalue::Future> BlockRunner::run_impl_async(
1352     IValueList&& args,
1353     const KeywordArgs& kwargs) {
1354   // run the graph inline in the caller thread. Async ops will be
1355   // executed on taskLauncher attached to the metadata of ProcessedNodes
1356   c10::IValue output = run_impl(std::forward<IValueList>(args), kwargs);
1357 
1358   // If the output is of type future, return it
1359   if (output.isFuture()) {
1360     return output.toFuture();
1361   }
1362 
1363   // wrap the output into future, mark completed and return it
1364   TypePtr return_type;
1365   if (block_info_.num_outputs() > 1) {
1366     return_type = TupleType::create(
1367         fmap(outputs(), [](const IValue* v) { return v->type(); }));
1368   } else {
1369     return_type = outputs().at(0)->type();
1370   }
1371   c10::intrusive_ptr<Future> future = c10::make_intrusive<Future>(return_type);
1372   future->markCompleted(output);
1373   return future;
1374 }
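
// Note on run_impl_async: the graph itself runs inline on the caller thread;
// only async ops use the TaskLauncher attached to the ProcessedNode metadata.
// If the graph's output is not already a Future, it is wrapped in a completed
// Future so that callers of the async API can treat both cases uniformly.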
1375 
1376 template <typename IValueList>
1377 c10::intrusive_ptr<c10::ivalue::Future> BlockRunner::
1378     run_impl_record_functions_async(
1379         IValueList&& args,
1380         const KeywordArgs& kwargs) {
1381   auto step_callbacks =
1382       at::getStepCallbacksUnlessEmpty(at::RecordScope::STATIC_RUNTIME_MODEL);
1383   if (C10_UNLIKELY(step_callbacks.has_value())) {
1384     at::RecordFunction guard(std::move(*step_callbacks));
1385     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(guard.isActive());
1386     guard.needsInputs()
1387         ? guard.before(
1388               "forward", c10::ArrayRef<const IValue>(args.data(), args.size()))
1389         : guard.before("forward");
1390 
1391     return run_impl_async(std::forward<IValueList>(args), kwargs);
1392   }
1393   return run_impl_async(std::forward<IValueList>(args), kwargs);
1394 }
1395 
1396 c10::IValue BlockRunner::operator()(
1397     const std::vector<c10::IValue>& args,
1398     const KeywordArgs& kwargs) {
1399 #ifdef PYTORCH_DISABLE_NET_PROFILING
1400   return run_impl(args, kwargs);
1401 #else
1402   return run_impl_record_functions(args, kwargs);
1403 #endif
1404 }
1405 
1406 c10::IValue BlockRunner::operator()(
1407     std::vector<c10::IValue>&& args,
1408     const KeywordArgs& kwargs) {
1409 #ifdef PYTORCH_DISABLE_NET_PROFILING
1410   return run_impl(std::move(args), kwargs);
1411 #else
1412   return run_impl_record_functions(std::move(args), kwargs);
1413 #endif
1414 }
1415 
1416 c10::intrusive_ptr<c10::ivalue::Future> BlockRunner::runAsync(
1417     const std::vector<c10::IValue>& args,
1418     const KeywordArgs& kwargs) {
1419 #ifdef PYTORCH_DISABLE_NET_PROFILING
1420   return run_impl_async(args, kwargs);
1421 #else
1422   return run_impl_record_functions_async(args, kwargs);
1423 #endif
1424 }
1425 
1426 c10::intrusive_ptr<c10::ivalue::Future> BlockRunner::runAsync(
1427     std::vector<c10::IValue>&& args,
1428     const KeywordArgs& kwargs) {
1429 #ifdef PYTORCH_DISABLE_NET_PROFILING
1430   return run_impl_async(std::move(args), kwargs);
1431 #else
1432   return run_impl_record_functions_async(std::move(args), kwargs);
1433 #endif
1434 }
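// Both operator() and runAsync above select between the plain and the
// RecordFunction-instrumented implementations at compile time: when
// PYTORCH_DISABLE_NET_PROFILING is defined, the profiling wrappers are
// compiled out entirely.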
1435 
1436 namespace {
1437 
1438 std::string generate_latency_json(const std::string& label, double millis) {
1439 #ifdef FBCODE_CAFFE2
1440   folly::dynamic json = folly::dynamic::object();
1441   json["type"] = label;
1442   json["metric"] = "latency";
1443   json["unit"] = "ms";
1444   json["value"] = millis;
1445   return "PyTorchObserver " + folly::toJson(json);
1446 #else
1447   (void)label;
1448   (void)millis;
1449   return "";
1450 #endif
1451 }
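// Example of the line emitted in fbcode builds (a sketch; the key order
// produced by folly::toJson is not guaranteed):
//   PyTorchObserver {"type":"aten::linear","metric":"latency","unit":"ms","value":0.42}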
1452 
1453 } // namespace
1454 
1455 void BlockRunner::benchmark(
1456     const std::vector<std::vector<c10::IValue>>& args_list,
1457     const std::vector<KeywordArgs>& kwargs_list,
1458     const uint32_t warmup_runs,
1459     const uint32_t main_runs,
1460     bool print_per_node_time,
1461     bool generate_ai_pep_output) {
1462   TORCH_CHECK(kwargs_list.empty() || args_list.size() == kwargs_list.size());
1463   std::cout << "Input size: " << args_list.size() << '\n';
1464   float time_per_iter =
1465       benchmark_model(args_list, kwargs_list, warmup_runs, main_runs);
1466   std::cout << "Static runtime ms per iter: " << time_per_iter
1467             << ". Iters per second: " << 1000.0 / time_per_iter << '\n';
1468 
1469   IndividualMetrics results =
1470       benchmark_individual_ops(args_list, kwargs_list, warmup_runs, main_runs);
1471 
1472   if (print_per_node_time) {
1473     const auto num_nodes = static_cast<uint32_t>(nodes_.size());
1474     for (const auto i : c10::irange(num_nodes)) {
1475       const Node* node = nodes_[i].node();
1476       std::cout << "Node #" << i << ": " << results.time_per_node[i]
1477                 << " ms/iter, ";
1478       node->print(std::cout, 0, nullptr, false);
1479     }
1480   }
1481 
1482   std::vector<std::pair<std::string, double>> time_per_node_type_vec{
1483       results.time_per_node_type.begin(), results.time_per_node_type.end()};
1484   if (args_list.empty()) {
1485     std::sort(
1486         time_per_node_type_vec.begin(),
1487         time_per_node_type_vec.end(),
1488         [&results](auto& left, auto& right) {
1489           return results.instances_per_node_type[left.first] >
1490               results.instances_per_node_type[right.first];
1491         });
1492   } else {
1493     std::sort(
1494         time_per_node_type_vec.begin(),
1495         time_per_node_type_vec.end(),
1496         [](auto& left, auto& right) { return left.second > right.second; });
1497   }
1498   std::cout << "Time per node type:" << '\n';
1499   for (const auto& p : time_per_node_type_vec) {
1500     const std::string& kind = p.first;
1501     const double ms = p.second;
1502     std::cout << std::setw(15) << ms << " ms. " << std::setw(10)
1503               << results.percent_per_node_type[kind] << "%. " << kind << " ("
1504               << results.instances_per_node_type[kind] << " nodes";
1505     if (results.out_nodes.count(kind)) {
1506       std::cout << ", out variant)" << '\n';
1507     } else if (results.native_nodes.count(kind)) {
1508       std::cout << ", native)" << '\n';
1509     } else {
1510       std::cout << ")" << '\n';
1511     }
1512 
1513     if (generate_ai_pep_output) {
1514       LOG(INFO) << generate_latency_json(kind, ms);
1515     }
1516   }
1517   if (generate_ai_pep_output) {
1518     LOG(INFO) << generate_latency_json(
1519         "static_runtime_first_iter", results.first_iter_time);
1520   }
1521   std::cout << std::setw(15) << results.total_time << " ms. in Total" << '\n';
1522   std::cout << "BlockRunner setup time: " << results.setup_time << " ms"
1523             << '\n';
1524   std::cout << "Memory allocation time: " << results.memory_alloc_time
1525             << " ms\n";
1526   std::cout << "Memory deallocation time: " << results.memory_dealloc_time
1527             << " ms" << '\n';
1528   std::cout << "Outputs deallocation time: " << results.output_dealloc_time
1529             << " ms" << '\n';
1530   std::cout << "First iter time: " << results.first_iter_time << " ms" << '\n';
1531   std::cout << "Number of operators: " << nodes_.size() << '\n';
1532 
1533   if (planner_) {
1534     std::cout << "Total number of managed tensors: "
1535               << planner_->total_num_managed_tensors() << '\n';
1536     std::cout << "Total number of managed output tensors: "
1537               << planner_->total_num_managed_output_tensors() << '\n';
1538     std::cout << "Total number of unmanaged values: "
1539               << planner_->total_num_unmanaged() << '\n';
1540     std::cout << "Number of unmanaged values requiring cleanup: "
1541               << planner_->num_unmanaged_non_scalars() << '\n';
1542     std::cout << "Number of unmanaged values not requiring cleanup: "
1543               << planner_->num_unmanaged_scalars() << '\n';
1544     std::cout << "Total memory managed: " << planner_->total_managed()
1545               << " bytes" << '\n';
1546     if (static_module_.opts().optimize_memory) {
1547       std::cout << "Total number of reused tensors: "
1548                 << planner_->total_reused_tensors() << '\n';
1549     }
1550   }
1551 
1552   auto unsupported_nodes_count = results.total_nodes_count -
1553       results.out_nodes_count - results.native_nodes.size();
1554   std::cout << "Total number of 'out' variant nodes/total number of nodes: "
1555             << results.out_nodes_count << "/" << results.total_nodes_count
1556             << " ("
1557             << 100.0 * static_cast<float>(results.out_nodes_count) /
1558           static_cast<float>(results.total_nodes_count)
1559             << "%)" << '\n';
1560   std::cout << "Total number of nodes not covered by SR/total number of nodes: "
1561             << unsupported_nodes_count << "/" << results.total_nodes_count
1562             << " ("
1563             << 100.0 * static_cast<float>(unsupported_nodes_count) /
1564           static_cast<float>(results.total_nodes_count)
1565             << "%)" << '\n';
1566 
1567   check_for_memory_leak();
1568 
1569 #ifndef NDEBUG
1570   KeywordArgs empty_kwargs;
1571   display_nodes(
1572       args_list[0], kwargs_list.size() > 0 ? kwargs_list[0] : empty_kwargs);
1573 #endif
1574 }
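// Usage sketch with hypothetical names; each inner vector of args_list is one
// complete set of positional inputs, and kwargs_list is either empty or the
// same length as args_list:
//
//   std::vector<std::vector<c10::IValue>> args_list{{at::randn({8, 16})}};
//   std::vector<KeywordArgs> kwargs_list;  // empty => no kwargs
//   block_runner.benchmark(
//       args_list,
//       kwargs_list,
//       /*warmup_runs=*/10,
//       /*main_runs=*/100,
//       /*print_per_node_time=*/true,
//       /*generate_ai_pep_output=*/false);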
1575 
1576 float BlockRunner::benchmark_model(
1577     const std::vector<std::vector<c10::IValue>>& args_list,
1578     const std::vector<KeywordArgs>& kwargs_list,
1579     const unsigned int warmup_runs,
1580     const unsigned int main_runs) {
1581   TORCH_CHECK(main_runs >= 1);
1582   TORCH_CHECK(kwargs_list.empty() || args_list.size() == kwargs_list.size());
1583 
1584   const bool is_kwargs_empty = kwargs_list.empty();
1585   const KeywordArgs empty_kwargs;
1586   for (const auto _n_run : c10::irange(warmup_runs)) {
1587     (void)_n_run; // Suppress unused variable warning
1588     const auto num_args = static_cast<uint32_t>(args_list.size());
1589     for (const auto j : c10::irange(num_args)) {
1590       operator()(args_list[j], is_kwargs_empty ? empty_kwargs : kwargs_list[j]);
1591       if (manage_output_tensors_enabled_) {
1592         deallocateOutputTensors();
1593       }
1594     }
1595   }
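  // Only the main runs below are timed; the warmup iterations above are
  // excluded from the reported per-iteration latency.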
1596   caffe2::Timer timer;
1597   for (const auto _n_run : c10::irange(main_runs)) {
1598     (void)_n_run; // Suppress unused variable warning
1599     const auto num_args = static_cast<uint32_t>(args_list.size());
1600     for (const auto j : c10::irange(num_args)) {
1601       operator()(args_list[j], is_kwargs_empty ? empty_kwargs : kwargs_list[j]);
1602       if (manage_output_tensors_enabled_) {
1603         deallocateOutputTensors();
1604       }
1605     }
1606   }
1607   float millis = timer.MilliSeconds();
1608   return millis /
1609       (static_cast<float>(main_runs) * static_cast<float>(args_list.size()));
1610 }
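// For reference, the value returned above is the average wall-clock time per
// input set: e.g. 200 ms of total main-run time with main_runs = 10 and
// args_list.size() = 4 yields 200 / (10 * 4) = 5 ms per iteration.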
1611 
1612 static bool display_ivalue(const IValue& iv) {
1613   if (iv.isTensor()) {
1614     std::cout << "Tensor " << iv.toTensor().toString() << " {";
1615     const auto dims = iv.toTensor().sizes();
1616     const auto n_dims = static_cast<uint32_t>(dims.size());
1617     for (const auto i : c10::irange(n_dims)) {
1618       std::cout << iv.toTensor().sizes()[i];
1619       if (n_dims > i + 1) {
1620         std::cout << ", ";
1621       }
1622     }
1623     std::cout << "}\n";
1624     return true;
1625   } else if (iv.isTensorList()) {
1626     std::cout << "TensorList {" << iv.toTensorList().size() << "}\n";
1627     return true;
1628   } else if (iv.isGenericDict()) {
1629     std::cout << "Dict {" << iv.toGenericDict().size() << "}\n";
1630     return true;
1631   } else if (iv.isTuple()) {
1632     std::cout << "Tuple {" << iv.toTupleRef().elements().size() << "}\n";
1633     return true;
1634   } else if (iv.isInt()) {
1635     std::cout << "int {" << iv.toInt() << "}\n";
1636     return true;
1637   } else if (iv.isBool()) {
1638     std::cout << "bool {" << iv.toBool() << "}\n";
1639     return true;
1640   } else if (iv.isDouble()) {
1641     std::cout << "double {" << iv.toDouble() << "}\n";
1642     return true;
1643   }
1644   return false;
1645 }
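// Example output (a sketch; the tensor type string comes from
// at::Tensor::toString(), e.g. "CPUFloatType"):
//   Tensor CPUFloatType {8, 16}
//   TensorList {2}
//   int {3}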
1646 
1647 static void display_pnode_info(const ProcessedNode& pnode) {
1648   pnode.node()->print(std::cout, 0, nullptr, false);
1649   const auto num_inputs = static_cast<uint32_t>(pnode.num_inputs());
1650   for (const auto i : c10::irange(num_inputs)) {
1651     std::cout << "\ti" << i << ": ";
1652     if (!display_ivalue(pnode.Input(i))) {
1653       std::cout << *(pnode.node()->inputs()[i]->type()) << '\n';
1654     }
1655   }
1656   const auto outputs = pnode.outputs();
1657   const auto num_outputs = static_cast<uint32_t>(outputs.size());
1658   for (const auto i : c10::irange(num_outputs)) {
1659     std::cout << "\to" << i << ": ";
1660     if (!display_ivalue(outputs[i])) {
1661       std::cout << *(pnode.node()->outputs()[i]->type()) << '\n';
1662     }
1663   }
1664 }
1665 
1666 void BlockRunner::display_nodes(
1667     const std::vector<c10::IValue>& args,
1668     const KeywordArgs& kwargs) {
1669   c10::InferenceMode mode;
1670 
1671   auto on_exit = Deallocator(*this);
1672 
1673   if (planner_) {
1674     planner_->allocate();
1675   }
1676   set_inputs(args, kwargs);
1677 
1678   for (auto& node : nodes_) {
1679     node.run();
1680     display_pnode_info(node);
1681   }
1682   on_exit.setFinished();
1683 }
1684 
1685 BlockRunner::IndividualMetrics BlockRunner::benchmark_individual_ops(
1686     const std::vector<std::vector<c10::IValue>>& args_list,
1687     const std::vector<KeywordArgs>& kwargs_list,
1688     const uint32_t warmup_runs,
1689     const uint32_t main_runs) {
1690   TORCH_CHECK(kwargs_list.empty() || args_list.size() == kwargs_list.size());
1691   TORCH_CHECK(warmup_runs >= 1 && main_runs >= 1);
1692 
1693   IndividualMetrics results;
1694   results.time_per_node.resize(nodes_.size(), 0);
1695   if (args_list.empty()) {
1696     // When the given input is empty, compute the op statistics from the given
1697     // graph without executing it.
1698     const auto num_nodes = static_cast<uint32_t>(nodes_.size());
1699     for (const auto i : c10::irange(num_nodes)) {
1700       const Node* node = nodes_[i].node();
1701       std::string kind(node->kind().toQualString());
1702       // TODO: Collect op statistics from sub-blocks here.
1703       results.time_per_node[i] = 0;
1704       results.time_per_node_type[kind] = 0;
1705       results.instances_per_node_type[kind]++;
1706       if (nodes_[i].has_out_variant()) {
1707         results.out_nodes.insert(kind);
1708         results.out_nodes_count++;
1709       } else if (nodes_[i].has_native()) {
1710         results.native_nodes.insert(kind);
1711       }
1712       results.total_time += results.time_per_node[i];
1713     }
1714     results.total_nodes_count = nodes_.size();
1715     results.memory_alloc_time = 0;
1716     results.memory_dealloc_time = 0;
1717     results.output_dealloc_time = 0;
1718     for (const auto& p : results.time_per_node_type) {
1719       const std::string& kind = p.first;
1720       results.percent_per_node_type[kind] = 0;
1721     }
1722     return results;
1723   }
1724 
1725   const bool is_kwargs_empty = kwargs_list.empty();
1726   const KeywordArgs empty_kwargs;
1727   bool manage_output_tensors = static_module_.opts().manage_output_tensors;
1728   // See comment on above use of InferenceMode for
1729   // explanation.
1730   c10::InferenceMode mode;
1731 
1732   // setup time
1733   caffe2::Timer timer;
1734 
1735   set_inputs(args_list[0], is_kwargs_empty ? empty_kwargs : kwargs_list[0]);
1736 
1737   results.setup_time = timer.MilliSeconds();
1738 
1739   // The first iteration profiles each node's output Tensors' sizes and
1740   // initializes the memory planner with the profile information. Following
1741   // iterations just use the already established memory planning.
1742   timer.Start();
1743   operator()(args_list[0], is_kwargs_empty ? empty_kwargs : kwargs_list[0]);
1744   if (manage_output_tensors) {
1745     deallocateOutputTensors();
1746   }
1747   results.first_iter_time = timer.MilliSeconds();
1748 
1749   // warmup runs
1750   for (const auto _n_run : c10::irange(warmup_runs)) {
1751     (void)_n_run; // Suppress unused variable warning
1752     const auto num_args = static_cast<uint32_t>(args_list.size());
1753     for (const auto j : c10::irange(num_args)) {
1754       operator()(args_list[j], is_kwargs_empty ? empty_kwargs : kwargs_list[j]);
1755       if (manage_output_tensors) {
1756         deallocateOutputTensors();
1757       }
1758     }
1759   }
1760 
1761   // main runs
1762   for (const auto i : c10::irange(main_runs)) {
1763     (void)i; // Suppress unused variable warning
1764     const auto num_args = static_cast<uint32_t>(args_list.size());
1765     for (const auto j : c10::irange(num_args)) {
1766       set_inputs(args_list[j], is_kwargs_empty ? empty_kwargs : kwargs_list[j]);
1767 
1768       timer.Start();
1769       if (planner_) {
1770         planner_->allocate();
1771       }
1772       float millis = timer.MilliSeconds();
1773       results.memory_alloc_time += millis;
1774       const auto num_nodes = static_cast<uint32_t>(nodes_.size());
1775       for (const auto k : c10::irange<uint32_t>(num_nodes)) {
1776         timer.Start();
1777         nodes_[k].run();
1778         millis = timer.MilliSeconds();
1779         results.time_per_node[k] += millis;
1780         verify_and_correct_memory_overlap(nodes_[k]);
1781       }
1782       timer.Start();
1783       create_memory_planner();
1784       planner_->deallocate();
1785       // clean up owning refs of input tensors
1786       clean_up_input_ivalues();
1787       if (manage_output_tensors) {
1788         deallocateOutputTensors();
1789       }
1790       millis = timer.MilliSeconds();
1791       results.memory_dealloc_time += millis;
1792 
1793       timer.Start();
1794       // The static runtime no longer needs to keep references to the outputs.
1795       c10::IValue output;
1796       if (static_module_.num_outputs() > 1) {
1797         output = move_outputs_to_tuple(static_module_.num_outputs());
1798       }
1799 
1800       DCHECK(check_for_memory_leak(/*output_returned*/ false));
1801 
1802       // Use std::move here; otherwise, outputs_[0] would have to be cleaned up explicitly.
1803       output = std::move(*outputs_[0]);
1804       // release outputs explicitly to measure the time it takes
1805       output = IValue();
1806       millis = timer.MilliSeconds();
1807       results.output_dealloc_time += millis;
1808     }
1809   }
1810 
1811   // post processing
1812   const float num_total_iters =
1813       (static_cast<float>(main_runs) * static_cast<float>(args_list.size()));
1814   const auto num_nodes = static_cast<uint32_t>(nodes_.size());
1815   for (const auto i : c10::irange(num_nodes)) {
1816     const Node* node = nodes_[i].node();
1817     std::string kind = std::string(node->kind().toQualString());
1818     results.time_per_node[i] /= num_total_iters;
1819     results.time_per_node_type[kind] += results.time_per_node[i];
1820     results.instances_per_node_type[kind]++;
1821     if (nodes_[i].has_out_variant()) {
1822       results.out_nodes.insert(kind);
1823       results.out_nodes_count++;
1824     } else if (nodes_[i].has_native()) {
1825       results.native_nodes.insert(kind);
1826     }
1827     results.total_time += results.time_per_node[i];
1828   }
1829   results.total_nodes_count = nodes_.size();
1830   results.memory_alloc_time /= num_total_iters;
1831   results.memory_dealloc_time /= num_total_iters;
1832   results.output_dealloc_time /= num_total_iters;
1833   for (const auto& p : results.time_per_node_type) {
1834     const std::string& kind = p.first;
1835     results.percent_per_node_type[kind] = p.second / results.total_time * 100;
1836   }
1837   return results;
1838 }
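// Sketch of consuming the returned metrics directly (hypothetical `metrics`
// and `block_runner` names; BlockRunner::benchmark above prints the same data
// with more formatting):
//
//   auto metrics = block_runner.benchmark_individual_ops(
//       args_list, kwargs_list, /*warmup_runs=*/5, /*main_runs=*/20);
//   for (const auto& [kind, ms] : metrics.time_per_node_type) {
//     std::cout << kind << ": " << ms << " ms ("
//               << metrics.percent_per_node_type.at(kind) << "%)\n";
//   }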
1839 
1840 bool BlockRunner::check_for_memory_leak(
1841     bool output_returned,
1842     bool recurse_on_sub_blocks) {
1843   // check for inputs
1844   const auto num_inputs = static_cast<uint32_t>(block_info_.num_inputs());
1845   for (const auto i : c10::irange(num_inputs)) {
1846     TORCH_CHECK(
1847         values_[i + block_info_.block_inputs_idx()].isNone(),
1848         "Input ",
1849         i,
1850         " was not cleaned up");
1851   }
1852   c10::FastSet<const IValue*> output_ivalues(outputs_.begin(), outputs_.end());
1853   const auto num_nodes = static_cast<uint32_t>(nodes_.size());
1854   for (const auto n : c10::irange(num_nodes)) {
1855     auto& pnode = nodes_[n];
1856     const auto num_outputs = static_cast<uint32_t>(pnode.num_outputs());
1857     for (const auto i : c10::irange(num_outputs)) {
1858       const IValue* ival = &pnode.Output(i);
1859       const Value* val = pnode.node()->output(i);
1860       // subtlety: isManagedOutputTensorValue may give a false
1861       // negative here if an output is an alias of this value, so
1862       // check the actual tensor!
1863       if (planner_ &&
1864           (isManagedOutputTensor(*ival) || isManagedOutputTensorValue(val))) {
1865         // `ival` contains a managed output tensor that the runtime doesn't
1866         // reclaim at the end of an iteration, but the client does so
1867         // by explicitly calling
1868         // `BlockRunner::deallocateOutputTensors`.
1869         continue;
1870       }
1871       const std::string error_msg = "Output " + std::to_string(i) + ", %" +
1872           val->debugName() + " of node " + std::to_string(n) +
1873           " which has kind " + pnode.node()->kind().toQualString() +
1874           " was not cleaned up";
1875       if (output_ivalues.count(ival) == 0) {
1876         // check for intermediates
1877         if (!ival->isNone()) {
1878           TORCH_CHECK(
1879               ival->isTensor() ||
1880                   block_info_.node_is_optimizable_container_type(
1881                       pnode.node()) ||
1882                   doesNotHeapAllocateWhenStoredInIValue(*val->type()),
1883               error_msg);
1884           if (ival->isTensor()) {
1885             const auto& t = ival->toTensor();
1886             if (t.defined()) {
1887               auto* storage_impl = t.storage().unsafeGetStorageImpl();
1888               TORCH_CHECK(
1889                   storage_impl->data() == nullptr ||
1890                       (planner_ &&
1891                        planner_->isManagedStorageImpl(storage_impl)),
1892                   error_msg);
1893             }
1894           }
1895         }
1896       } else {
1897         // check for outputs
1898         if (output_returned) {
1899           TORCH_CHECK(ival->isNone(), error_msg);
1900         }
1901       }
1902     }
1903     auto* metadata = pnode.metadata();
1904     if (recurse_on_sub_blocks && metadata) {
1905       auto& block_runners = metadata->block_runners();
1906       for (auto& block_runner : block_runners) {
1907         block_runner.check_for_memory_leak(
1908             output_returned, recurse_on_sub_blocks);
1909       }
1910     }
1911   }
1912   VLOG(1) << "Finished checking for memory leak";
1913   return true;
1914 }
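// In debug builds this check typically runs as
// DCHECK(check_for_memory_leak(/*output_returned*/ false)) right after an
// iteration (see run_impl and benchmark_individual_ops above), so a value
// that was not cleaned up trips one of the TORCH_CHECKs here immediately.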
1915 
1916 void BlockRunner::deallocateOutputTensors() {
1917   if (!static_module_.opts().manage_output_tensors) {
1918     TORCH_CHECK(
1919         !planner_ || planner_->numOutputBufferBytes() == 0,
1920         "manage_output_tensors is disabled, but output tensor buffer is not empty.");
1921     return;
1922   }
1923   if (planner_) {
1924     planner_->deallocateOutputTensors();
1925     DCHECK(checkOutputTensorMemoryLeaks());
1926   }
1927 }
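// Client-side sketch when manage_output_tensors is enabled (hypothetical
// `runtime`, `args`, and `consume` names): outputs live in a runtime-managed
// buffer, so the caller must be done with them before handing the buffer back.
//
//   c10::IValue out = runtime(args, /*kwargs=*/{});
//   consume(out);                        // use or clone the output tensors
//   runtime.deallocateOutputTensors();   // release the managed output buffer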
1928 
1929 bool BlockRunner::checkOutputTensorMemoryLeaks() {
1930   if (!static_module_.opts().manage_output_tensors || !planner_) {
1931     return true;
1932   }
1933   const auto num_nodes = static_cast<uint32_t>(nodes_.size());
1934   for (const auto n : c10::irange(num_nodes)) {
1935     auto& pnode = nodes_[n];
1936     const auto num_outputs = static_cast<uint32_t>(pnode.num_outputs());
1937     for (const auto i : c10::irange(num_outputs)) {
1938       const IValue* ival = &pnode.Output(i);
1939       const Value* val = pnode.node()->output(i);
1940       if (!isManagedOutputTensorValue(val) || !ival->isTensor()) {
1941         // ival may not hold a tensor if it is managed by ops like
1942         // to_maybe_copy_out; see ReplaceWithMaybeCopy for details.
1943         continue;
1944       }
1945       const auto& t = ival->toTensor();
1946       if (t.defined()) {
1947         auto* storage_impl = t.storage().unsafeGetStorageImpl();
1948         const std::string error_msg = "Output " + std::to_string(i) + ", %" +
1949             val->debugName() + " of node " + std::to_string(n) +
1950             " was not cleaned up";
1951         TORCH_CHECK(storage_impl->data() == nullptr, error_msg);
1952       }
1953     }
1954   }
1955   VLOG(1) << "Finished checking for memory leak from output tensors";
1956   return true;
1957 }
1958 
1959 bool BlockRunner::isManagedOutputTensor(const IValue& ivalue) const {
1960   return planner_ && planner_->isManagedOutputTensor(ivalue);
1961 }
1962 
1963 bool BlockRunner::isManagedOutputTensorValue(const Value* value) const {
1964   // It's possible that manage_output_tensors_ was disabled after initializing
1965   // managed_output_tensor_values, so we have to check that flag here.
1966   if (!planner_ || !manage_output_tensors_enabled_) {
1967     return false;
1968   }
1969   const auto& managed_outputs = block_info_.managed_output_tensor_values();
1970   return managed_outputs.find(value) != managed_outputs.end();
1971 }
1972 
1973 void BlockRunner::disableManageOutputTensors() {
1974   if (!manage_output_tensors_enabled_) {
1975     return;
1976   }
1977   manage_output_tensors_enabled_ = false;
1978   if (!planner_) {
1979     return;
1980   }
1981   // Reset all IValues and destruct planner_ so that it can be reconstructed in
1982   // the next run.
1983   for (auto& n : nodes_) {
1984     const auto num_outputs = static_cast<uint32_t>(n.outputs().size());
1985     for (const auto i : c10::irange(num_outputs)) {
1986       n.Output(i) = IValue();
1987     }
1988   }
1989   planner_.reset();
1990 }
1991 
1992 ProcessedFunction::ProcessedFunction(
1993     Node* node,
1994     bool enable_out_variant,
1995     bool check_memory_overlap)
1996     : check_memory_overlap_(check_memory_overlap),
1997       num_outputs_(node->outputs().size()) {
1998   if (enable_out_variant) {
1999     f_ = getOutOfPlaceOperation(node);
2000     if (f_) {
2001       kind_ = ProcessedFunction::Kind::kOutVariant;
2002       // do not check memory overlap for out variants
2003       check_memory_overlap_ = false;
2004       VLOG(1) << "Switch to out variant for node: " << PrintNode(node);
2005       return;
2006     }
2007   }
2008   {
2009     f_ = getNativeOperation(node);
2010     if (f_) {
2011       kind_ = ProcessedFunction::Kind::kNativeFunction;
2012 #ifdef NDEBUG
2013       // skip this check in opt mode because these ops are better vetted
2014       check_memory_overlap_ = false;
2015 #endif
2016       VLOG(1) << "Switch to native impl for node: " << PrintNode(node);
2017       return;
2018     }
2019   }
2020   {
2021     const Operator& op = node->getOperator();
2022     f_ = [node_op = op.getOperation(node),
2023           has_var_args = hasVarArgs(node)](ProcessedNode* pnode) mutable {
2024       std::vector<IValue> stack;
2025       const auto size = static_cast<uint32_t>(pnode->num_inputs());
2026       stack.reserve(size + has_var_args);
2027       for (const auto i : c10::irange(size)) {
2028         stack.emplace_back(pnode->Input(i));
2029       }
2030       // Need to store the number of inputs in stack for variadic ops.
2031       if (has_var_args) {
2032         stack.emplace_back(static_cast<int>(size));
2033       }
2034       node_op(stack);
2035       const auto num_outputs = static_cast<uint32_t>(pnode->num_outputs());
2036       TORCH_DCHECK_EQ(stack.size(), num_outputs);
2037       for (const auto i : c10::irange(num_outputs)) {
2038         pnode->Output(i) = std::move(stack[i]);
2039       }
2040     };
2041     kind_ = ProcessedFunction::Kind::kInterpreterFallback;
2042     VLOG(1) << "Fallback interpreter for node: " << PrintNode(node);
2043   }
2044 }
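// To summarize the selection above: a ProcessedFunction binds each node, in
// order of preference, to (1) an out variant that reuses preallocated output
// tensors, (2) a native implementation that bypasses the dispatcher, or
// (3) a generic interpreter fallback that shuttles inputs and outputs through
// a temporary stack.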
2045 
2046 StaticNodeInfo::StaticNodeInfo(
2047     Node* node,
2048     ProcessedFunction* fn,
2049     ProcessedNodeInputs inputs,
2050     uint16_t outputs_offset)
2051     : node_(node),
2052       fn_(fn),
2053       inputs_(std::move(inputs)),
2054       outputs_offset_(outputs_offset) {
2055   TORCH_CHECK(
2056       num_outputs() == node->outputs().size(),
2057       "Node ",
2058       node->kind().toQualString(),
2059       " has ",
2060       std::to_string(num_outputs()),
2061       " outputs, expected ",
2062       std::to_string(node->outputs().size()));
2063 }
2064 
2065 std::vector<IValue> ProcessedNode::inputs_ivalue_vec() const {
2066   std::vector<IValue> result;
2067   const auto num_inputs = static_cast<uint32_t>(inputs_.size());
2068   result.reserve(num_inputs);
2069 
2070   for (const auto idx : c10::irange(num_inputs)) {
2071     result.emplace_back(Input(idx));
2072   }
2073   return result;
2074 }
2075 
2076 void ProcessedNode::run() {
2077 #ifdef FBCODE_CAFFE2
2078   SROperatorObserver::onStart(node());
2079 #endif
2080 #ifndef PYTORCH_DISABLE_PER_OP_PROFILING
2081   auto step_callbacks =
2082       at::getStepCallbacksUnlessEmpty(at::RecordScope::STATIC_RUNTIME_OP);
2083   if (C10_UNLIKELY(step_callbacks.has_value())) {
2084     at::RecordFunction guard(std::move(*step_callbacks));
2085     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(guard.isActive());
2086     if (guard.needsInputs()) {
2087       const auto inputs = inputs_ivalue_vec();
2088       guard.before(
2089           get_op_name(),
2090           c10::ArrayRef<const IValue>(inputs.data(), inputs.size()));
2091     } else {
2092       guard.before(get_op_name());
2093     }
2094     if (has_out_variant()) {
2095       guard._setStaticRuntimeOutVariant();
2096     }
2097 
2098     fn_->run(this);
2099   } else {
2100     fn_->run(this);
2101   }
2102 #else
2103   fn_->run(this);
2104 #endif
2105 #ifndef NDEBUG
2106   if (FLAGS_static_runtime_disable_debug_memory_overlap_check) {
2107     // run check but do not enforce
2108     verify_no_memory_overlap();
2109   } else {
2110     DCHECK(verify_no_memory_overlap());
2111   }
2112 #endif
2113 #ifdef FBCODE_CAFFE2
2114   SROperatorObserver::onEnd(node());
2115 #endif
2116 }
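// Note: per-operator events use at::RecordScope::STATIC_RUNTIME_OP, while the
// BlockRunner entry points above use at::RecordScope::STATIC_RUNTIME_MODEL,
// which lets profilers separate model-level from operator-level timings.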
2117 
2118 static bool checkNoMemoryOverlap(const at::Tensor& a, const at::Tensor& b) {
2119   at::MemOverlapStatus status = at::get_overlap_status(a, b);
2120   if (status == at::MemOverlapStatus::Full ||
2121       status == at::MemOverlapStatus::Partial) {
2122     return false;
2123   }
2124   if (status == at::MemOverlapStatus::TooHard) {
2125     VLOG(1) << "Detected TOO_HARD memory overlap status";
2126   }
2127   return true;
2128 }
2129 
2130 bool ProcessedNode::verify_no_memory_overlap(bool force_check) const {
2131   const static std::array<c10::Symbol, 7> special_case_ops = {
2132       fromQualString("prim::TypeCheck"),
2133       fromQualString("prim::IfThenElse"),
2134       fromQualString("static_runtime::select_tensor"),
2135       fromQualString("static_runtime::VarTupleUnpack"),
2136       fromQualString("static_runtime::dict_unpack"),
2137       fromQualString("static_runtime::fused_split_and_squeeze"),
2138       fromQualString("static_runtime::create_owned_ref")};
2139   if (!force_check &&
2140       std::find(
2141           begin(special_case_ops), end(special_case_ops), node()->kind()) !=
2142           end(special_case_ops)) {
2143     return true;
2144   }
2145 
2146   return verify_outputs_dont_overlap_each_other() &&
2147       verify_inputs_dont_overlap_outputs(force_check);
2148 }
2149 
2150 bool ProcessedNode::verify_outputs_dont_overlap_each_other() const {
2151   const auto n_outputs = static_cast<uint32_t>(num_outputs());
2152   for (const auto i : c10::irange(n_outputs)) {
2153     if (!Output(i).isTensor()) {
2154       continue;
2155     }
2156     const auto& out0_t = Output(i).toTensor();
2157     for (const auto j : c10::irange(i + 1, n_outputs)) {
2158       if (!Output(j).isTensor()) {
2159         continue;
2160       }
2161       const auto& out1_t = Output(j).toTensor();
2162       if (!checkNoMemoryOverlap(out0_t, out1_t)) {
2163         LOG(INFO) << "Node output " << i << " overlaps with output " << j
2164                   << ", " << PrintNode(node_);
2165         return false;
2166       }
2167     }
2168   }
2169   return true;
2170 }
2171 
2172 bool ProcessedNode::verify_inputs_dont_overlap_outputs(bool force_check) const {
2173   auto schema = node()->maybeSchema();
2174   // skip memory overlap check for mutable or view ops with only one output
2175   bool skip_check = !schema ||
2176       ((schema->is_mutable() || !fn_->checkMemoryOverlap()) &&
2177        num_outputs() == 1);
2178   if (!schema || (!force_check && skip_check)) {
2179     if (!schema) {
2180       VLOG(2) << "Detected that op schema is null";
2181       return true;
2182     }
2183     VLOG(2) << "schema->is_mutable: " << schema->is_mutable()
2184             << ", fn_->checkMemoryOverlap: " << fn_->checkMemoryOverlap()
2185             << ", num_outputs_: " << num_outputs();
2186     return true;
2187   }
2188   const auto n_inputs = static_cast<uint32_t>(inputs_.size());
2189   const auto n_outputs = static_cast<uint32_t>(num_outputs());
2190   for (const auto i : c10::irange<uint32_t>(n_inputs)) {
2191     const IValue* in = &Input(i);
2192     if (!in->isTensor()) {
2193       continue;
2194     }
2195     const auto& in_t = in->toTensor();
2196     for (const auto j : c10::irange(n_outputs)) {
2197       const IValue& out = Output(j);
2198       if (!out.isTensor()) {
2199         continue;
2200       }
2201       const auto& out_t = out.toTensor();
2202       if (!checkNoMemoryOverlap(in_t, out_t)) {
2203         LOG(INFO) << "Node input " << i << " overlaps with output " << j << ", "
2204                   << PrintNode(node_);
2205         LOG(INFO) << *schema;
2206         return false;
2207       }
2208     }
2209   }
2210   return true;
2211 }
2212 
2213 bool ProcessedNode::check_and_correct_overlap_with(
2214     const at::Tensor& input,
2215     c10::IValue& output_ival) {
2216   auto& tensor = output_ival.toTensor();
2217   if (!checkNoMemoryOverlap(input, tensor)) {
2218     DLOG(INFO) << "Detected alias for node: " << PrintNode(node());
2219     output_ival = at::native::clone(tensor, std::nullopt);
2220     set_outputs_memory_overlap_detected();
2221     return true;
2222   }
2223   return false;
2224 }
2225 
2226 void ProcessedNode::verify_and_correct_memory_overlap() {
2227   const auto n_inputs = static_cast<uint32_t>(inputs_.size());
2228   const auto n_outputs = static_cast<uint32_t>(num_outputs());
2229   for (const auto i : c10::irange(n_inputs)) {
2230     const IValue& in = Input(i);
2231     if (!in.isTensor()) {
2232       continue;
2233     }
2234     const auto& in_t = in.toTensor();
2235     for (const auto j : c10::irange(n_outputs)) {
2236       auto& output = Output(j);
2237       if (output.isTensor()) {
2238         check_and_correct_overlap_with(in_t, output);
2239       } else if (output.isTensorList()) {
2240         auto tensors = output.toListRef();
2241         for (const auto& ival : tensors) {
2242           // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
2243           check_and_correct_overlap_with(in_t, const_cast<c10::IValue&>(ival));
2244         }
2245 #ifdef FBCODE_CAFFE2
2246         if (outputs_memory_overlap_detected()) {
2247           LOG_EVERY_MS(WARNING, 60000)
2248               << "Detected alias for node: " << PrintNode(node());
2249         }
2250 #endif
2251       }
2252     }
2253   }
2254 }
2255 
2256 StaticRuntime::StaticRuntime(const StaticModule& sm)
2257     : async_task_launcher_(at::launch), values_(sm.value_buffer_size()) {
2258   std::copy(sm.constants().begin(), sm.constants().end(), values_.data());
2259   // default task launcher set to inter-op thread pool
2260 
2261   block_ = std::make_unique<BlockRunner>(
2262       sm,
2263       values_.data(),
2264       sm.root_block(),
2265       &async_task_launcher_,
2266       true /*is_root_block*/);
2267 }
2268 
2269 c10::IValue StaticRuntime::operator()(
2270     const std::vector<c10::IValue>& args,
2271     const KeywordArgs& kwargs) {
2272   return (*block_)(args, kwargs);
2273 }
2274 
2275 c10::IValue StaticRuntime::operator()(
2276     std::vector<c10::IValue>&& args,
2277     const KeywordArgs& kwargs) {
2278   return (*block_)(std::move(args), kwargs);
2279 }
2280 
2281 c10::intrusive_ptr<c10::ivalue::Future> StaticRuntime::runAsync(
2282     const std::vector<c10::IValue>& args,
2283     const KeywordArgs& kwargs,
2284     torch::jit::TaskLauncher taskLauncher) {
2285   async_task_launcher_ = std::move(taskLauncher);
2286   return block_->runAsync(args, kwargs);
2287 }
2288 
2289 c10::intrusive_ptr<c10::ivalue::Future> StaticRuntime::runAsync(
2290     std::vector<c10::IValue>&& args,
2291     const KeywordArgs& kwargs,
2292     torch::jit::TaskLauncher taskLauncher) {
2293   async_task_launcher_ = std::move(taskLauncher);
2294   return block_->runAsync(std::move(args), kwargs);
2295 }
2296 
2297 bool StaticRuntime::check_for_memory_leak(bool output_returned) {
2298   return block_->check_for_memory_leak(
2299       output_returned, /* recurse_on_sub_blocks */ true);
2300 }
2301 
2302 bool StaticRuntime::checkOutputTensorMemoryLeaks() {
2303   return block_->checkOutputTensorMemoryLeaks();
2304 }
2305 
2306 void StaticRuntime::deallocateOutputTensors() {
2307   block_->deallocateOutputTensors();
2308 }
2309 
2310 bool StaticRuntime::isManagedOutputTensor(const IValue& ivalue) const {
2311   return block_->isManagedOutputTensor(ivalue);
2312 }
2313 
2314 void StaticRuntime::disableManageOutputTensors() {
2315   block_->disableManageOutputTensors();
2316 }
2317 
2318 const MemoryPlanner* StaticRuntime::get_memory_planner() const {
2319   return block_->get_memory_planner();
2320 }
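// End-to-end usage sketch (hypothetical model path and input shape; see
// torch/csrc/jit/runtime/static/README.md for the supported workflow):
//
//   torch::jit::Module module = torch::jit::load("model.pt");
//   module.eval();
//   torch::jit::StaticModule smodule(module);
//   torch::jit::StaticRuntime runtime(smodule);
//   std::vector<c10::IValue> args{at::randn({1, 64})};
//   c10::IValue out = runtime(args, /*kwargs=*/{});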
2321 
2322 } // namespace torch::jit
2323