1 #include <torch/csrc/jit/runtime/static/impl.h>
2
3 #include <ATen/MemoryOverlap.h>
4 #include <ATen/core/symbol.h>
5 #include <ATen/record_function.h>
6 #include <c10/core/CPUAllocator.h>
7 #include <c10/core/InferenceMode.h>
8 #include <c10/macros/Macros.h>
9 #include <c10/util/MaybeOwned.h>
10 #include <c10/util/irange.h>
11 #include <caffe2/core/timer.h>
12 #include <torch/csrc/jit/ir/alias_analysis.h>
13 #include <torch/csrc/jit/jit_log.h>
14 #include <torch/csrc/jit/passes/add_if_then_else.h>
15 #include <torch/csrc/jit/passes/canonicalize.h>
16 #include <torch/csrc/jit/passes/dead_code_elimination.h>
17 #include <torch/csrc/jit/passes/eliminate_no_ops.h>
18 #include <torch/csrc/jit/passes/freeze_module.h>
19 #include <torch/csrc/jit/passes/remove_mutation.h>
20 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
21 #include <torch/csrc/jit/passes/variadic_ops.h>
22 #include <torch/csrc/jit/runtime/graph_iterator.h>
23 #include <torch/csrc/jit/runtime/static/fusion.h>
24 #include <torch/csrc/jit/runtime/static/memory_planner.h>
25 #include <torch/csrc/jit/runtime/static/ops.h>
26 #include <torch/csrc/jit/runtime/static/passes.h>
27 #include <torch/csrc/jit/runtime/vararg_functions.h>
28 #include <algorithm>
29 #include <cstdint>
30 #include <iostream>
31
32 #ifndef AT_PER_OPERATOR_HEADERS
33 #include <ATen/NativeFunctions.h>
34 #else
35 #include <ATen/ops/clone_native.h>
36 #endif
37
38 #include <iterator>
39 #include <limits>
40 #include <sstream>
41 #include <stdexcept>
42
43 #ifdef FBCODE_CAFFE2
44 #include <common/logging/logging.h>
45 #include <folly/dynamic.h>
46 #include <folly/json.h>
47 #endif
48
49 // used in test only
50 C10_DEFINE_bool(
51 static_runtime_disable_debug_memory_overlap_check,
52 false,
53 "If true, disable the memory overlap check in debug mode in ProcessedNode::run()");
54
55 namespace torch::jit {
56
57 namespace {
58 #ifndef STRIP_ERROR_MESSAGES
59 std::string iValueToString(const c10::IValue& val) {
60 std::ostringstream oss;
61 oss << val;
62 return oss.str();
63 }
64 #endif
65
66 bool allArgsAreTensors(const Node* node) {
67 const auto& inputs = node->inputs();
68 return std::all_of(inputs.begin(), inputs.end(), [](const Value* value) {
69 return value->type()->kind() == TypeKind::TensorType;
70 });
71 }
72
73 } // namespace
74
75 // A manually curated set of ops that are disallowed in static runtime.
76 // These are rarely-used ops. Disallowing them typically eliminates
77 // corner cases in graph optimizations, allowing for more aggressive
78 // optimizations and better performance.
79 static bool isUnsupportedOp(const Node* node) {
80 auto kind = node->kind();
81 if (kind != aten::__is__ && kind != aten::__isnot__) {
82 return false;
83 }
84
85 // We can't support aten::__is__ (and __isnot__) with tensor arguments.
86 // Consider the following graph:
87 // def forward(x):
88 // y = x.detach()
89 // return x is y
90 // We have a graph optimization that removes the `detach` node since it is
91 // a no-op during inference. But this affects the result - we get true
92 // instead of false! There are many other graph passes affected by this
93 // issue.
94 return allArgsAreTensors(node);
95 }
96
97 namespace {
98
99 bool canEnableStaticRuntimeImpl(const Block* block) {
100 if (block == nullptr) {
101 return false;
102 }
103
104 bool can_support = true;
105 for (auto* node : block->nodes()) {
106 for (auto* subblock : node->blocks()) {
107 // Evaluate the recursive call first so that && cannot short-circuit -
108 // we want to log *all* the unsupported ops, not just the first one.
109 can_support = canEnableStaticRuntimeImpl(subblock) && can_support;
110 }
111
112 const auto kind = node->kind();
113 if (kind == prim::Constant) {
114 continue;
115 }
116 // check whether we can get an op from the Node
117 const Operator* op = node->maybeOperator();
118 if (isUnsupportedOp(node) || (!op && !nativeOpIsRegistered(kind))) {
119 can_support = false;
120 LOG(WARNING) << "Found unsupported op: " << kind.toQualString();
121 }
122 }
123 return can_support;
124 }
125
126 } // namespace
127
128 // Graph must be frozen. canEnableStaticRuntime will return false
129 // if there's any prim::CallMethod ops left in the graph.
130 bool canEnableStaticRuntime(const std::shared_ptr<torch::jit::Graph>& graph) {
131 return canEnableStaticRuntimeImpl(graph->block());
132 }
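// A minimal usage sketch for the check above (not code from this file; the
// model path and reliance on default constructor arguments are illustrative
// assumptions):
//
//   torch::jit::Module module = torch::jit::load("model.pt");
//   module.eval();
//   module = torch::jit::freeze_module(module);
//   auto graph = module.get_method("forward").graph();
//   if (torch::jit::canEnableStaticRuntime(graph)) {
//     torch::jit::StaticModule smod(module, /*is_frozen=*/true);
//     // smod(args, kwargs) now runs the optimized graph.
//   }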
133
134 namespace {
135
136 auto sr_metadata_registerer = torch::class_<StaticRuntimeMetadata>(
137 "StaticRuntime",
138 "StaticRuntimeMetadata");
139
140 } // namespace
141
142 std::string dumpValueSet(
143 const c10::FastSet<const Value*>& value_set,
144 const char* set_name) {
145 std::ostringstream oss;
146 oss << set_name << ": {";
147 for (const auto* val : value_set) {
148 oss << "%" << val->debugName() << ", ";
149 }
150 oss << "}";
151 return oss.str();
152 }
153
154 namespace {
155
156 void OptimizeGraph(
157 std::shared_ptr<torch::jit::Graph>& graph,
158 const StaticModuleOptions& opts,
159 std::vector<IValue> sample_inputs) {
160 GRAPH_DUMP("Before optimizations: ", graph);
161 if (opts.enable_tensorexpr_fusion) {
162 if (sample_inputs.empty()) {
163 VLOG(1) << "Cannot perform TensorExpr fusion - sample_inputs is empty";
164 } else {
165 VLOG(1) << "Performing TensorExpr fusion";
166 performTensorExprFusion(graph, std::move(sample_inputs));
167 }
168 }
169 Inline(*graph);
170 ConstantPropagation(graph);
171 Canonicalize(graph);
172 ConstantPropagation(graph);
173 RemoveTensorMutation(graph);
174 ConstantPropagation(graph);
175 EliminateNoOpSlice(graph);
176 EliminateDeadCode(graph);
177 FuseInferenceOpsForSparseNN(graph);
178 UseVariadicCat(graph);
179 UseVariadicStack(graph);
180 EliminateTrivialEquallySplit(graph);
181 EliminateExtraPermuteOps(graph);
182
183 if (opts.enable_out_variant) {
184 UseVariadicOp(
185 graph,
186 fromQualString("fb::sigrid_transforms_torch_bind"),
187 fromQualString("fb::variadic_sigrid_transforms_torch_bind"));
188 UseVariadicOp(
189 graph,
190 fromQualString("torcharrow::inference_wrapper_run_flat"),
191 fromQualString("torcharrow::variadic_inference_wrapper_run_flat"));
192 // These fused ops only have out variants - we can't do the fusion when
193 // out variants are disabled.
194 FuseSignLog1P(graph);
195 FuseClampNaNToNum(graph);
196
197 #ifdef FBCODE_CAFFE2
198 if (opts.use_copy_variants && !opts.enable_tensorexpr_fusion) {
199 ReplaceWithCopy(graph);
200 } else {
201 ReplacePermuteWithCopy(graph);
202 }
203 if (opts.use_maybe_copy_variants && !opts.enable_tensorexpr_fusion) {
204 ReplaceWithMaybeCopy(graph);
205 }
206 FuseListUnpack(graph);
207 RemoveUnnecessaryOutputs(graph);
208 PrepackWeights(graph);
209 #endif
210 }
211
212 ConstantPropagation(graph);
213 RemoveImmutableInputDictLookups(graph);
214 UseVariadicTupleUnpack(graph);
215 UseVariadicGroupedAccessor(graph);
216 EliminateNoOps(
217 graph, /* custom_ops */ {fromQualString("fb::scale_gradient")});
218 AddIfThenElseOp(graph);
219 UseSplitAndSqueeze(graph);
220 UseInPlaceGetRealInputsFromOptionalInputsV2(graph);
221 GRAPH_DUMP("Final graph after optimizations: ", graph);
222 }
223
224 bool IsSelfInGraphInput(std::shared_ptr<torch::jit::Graph>& graph) {
225 return !graph->inputs().empty() && graph->inputs().at(0)->type()->is_module();
226 }
227
228 // remove unused input 0 from graph
229 bool removeSelfFromGraphInput(std::shared_ptr<torch::jit::Graph>& graph) {
230 if (graph->inputs().at(0)->type()->is_module()) {
231 if (graph->inputs().at(0)->hasUses()) {
232 return false;
233 }
234 graph->eraseInput(0);
235 }
236 return true;
237 }
238
239 std::vector<Value*> valueVecFromFastSet(const c10::FastSet<const Value*>& s) {
240 std::vector<Value*> result;
241 result.reserve(s.size());
242 for (auto* v : s) {
243 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
244 result.emplace_back(const_cast<Value*>(v));
245 }
246 return result;
247 }
248
249 bool mayContainAlias(const AliasDb& db, const Value* v1, const Value* v2) {
250 // AliasDb is not const-correct here, so we have to const_cast
251 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
252 return db.mayContainAlias(const_cast<Value*>(v1), const_cast<Value*>(v2));
253 }
254
255 bool mayContainAlias(
256 const AliasDb& db,
257 const Value* a,
258 const c10::FastSet<const Value*>& b) {
259 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
260 return db.mayContainAlias(const_cast<Value*>(a), valueVecFromFastSet(b));
261 }
262
263 bool escapesScope(const AliasDb& db, const Value* a) {
264 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
265 return db.escapesScope({const_cast<Value*>(a)});
266 }
267
268 void PrepareGraphForStaticModule(
269 std::shared_ptr<torch::jit::Graph> graph,
270 const StaticModuleOptions& opts,
271 std::vector<IValue> sample_inputs) {
272 TORCH_CHECK(canEnableStaticRuntime(graph));
273 OptimizeGraph(graph, opts, std::move(sample_inputs));
274
275 // Static runtime moves its outputs out of the runtime
276 // by default. In some rare cases, this is not actually safe to
277 // do - for example, if the value is a constant, static runtime
278 // needs to hold onto a copy. Rather than adding special logic
279 // to handle this rare case, we use this pass to detect it and
280 // create an owned reference that can be safely moved out of the
281 // runtime.
282 CreateOwnedRefsForSpecialValues(*graph);
283
284 // We assume that each sub-block has at least one output. If we
285 // detect any that have 0, force the sub-block to return None.
286 ForceNonEmptyOutputs(*graph);
287 }
288
289 std::pair<std::shared_ptr<Graph>, std::optional<Module>> PrepareForStaticModule(
290 const torch::jit::Module& m,
291 bool is_frozen,
292 const StaticModuleOptions& opts,
293 std::vector<IValue> sample_inputs) {
294 LOG(INFO) << "StaticModuleOptions: enable_out_variant "
295 << opts.enable_out_variant << ", optimize_memory "
296 << opts.optimize_memory << ", manage_output_tensors "
297 << opts.manage_output_tensors << ", use_copy_variants "
298 << opts.use_copy_variants << ", use_maybe_copy_variants "
299 << opts.use_maybe_copy_variants << ", enable_tensorexpr_fusion "
300 << opts.enable_tensorexpr_fusion;
301
302 Module module = m.copy();
303 if (!is_frozen) {
304 module.eval();
305 module = freeze_module(module);
306 }
307
308 Method method = module.get_method("forward");
309 auto graph = module.get_method("forward").graph();
310
311 if (!sample_inputs.empty() && IsSelfInGraphInput(graph)) {
312 sample_inputs.insert(sample_inputs.begin(), m._ivalue());
313 }
314 PrepareGraphForStaticModule(graph, opts, std::move(sample_inputs));
315
316 return std::make_pair(graph, module);
317 }
318
319 std::pair<std::shared_ptr<Graph>, std::optional<Module>> PrepareForStaticModule(
320 const std::shared_ptr<torch::jit::Graph>& graph,
321 const StaticModuleOptions& opts,
322 std::vector<IValue> sample_inputs) {
323 PrepareGraphForStaticModule(graph, opts, std::move(sample_inputs));
324 return std::make_pair(graph, std::nullopt);
325 }
326
327 } // namespace
328
329 void ValueGroup::init(const Block& block, const AliasDb& db) {
330 external_aliases_.clear();
331 output_aliases_.clear();
332 // Build `external_aliases_` by walking the nodes forward from the
333 // block's inputs and adding every value created by a node that may
334 // alias one of those inputs.
335 external_aliases_.insert(block.inputs().begin(), block.inputs().end());
336 for (const auto* node : block.nodes()) {
337 if (node->kind() == prim::Constant) {
338 for (const auto* output : node->outputs()) {
339 external_aliases_.insert(output);
340 }
341 }
342 }
343 for (const auto* node : block.nodes()) {
344 if (node->kind() == prim::Constant) {
345 // Constants are already in `external_aliases`.
346 continue;
347 }
348 for (const auto* v : node->outputs()) {
349 if (escapesScope(db, v) || mayContainAlias(db, v, external_aliases_)) {
350 external_aliases_.insert(v);
351 }
352 }
353 }
354
355 // Build `output_aliases_` by walking the nodes in reverse, starting from
356 // the output values and following the data flow backward from there.
357 output_aliases_.insert(block.outputs().begin(), block.outputs().end());
358 for (const auto* node : block.nodes().reverse()) {
359 if (node->kind() == prim::Constant) {
360 // Constants cannot create any aliases.
361 continue;
362 }
363 for (const auto* v : node->outputs()) {
364 if (mayContainAlias(db, v, output_aliases_)) {
365 output_aliases_.insert(v);
366 }
367 }
368 }
369 }
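// A small worked example of the two passes above, assuming a hypothetical
// graph (not one constructed in this file):
//
//   graph(%x):
//     %c = prim::Constant[value=1]()
//     %y = aten::flatten(%x, 0, 1)   # may alias the input %x
//     %z = aten::add(%y, %c, 1)      # fresh tensor
//     return (%z)
//
// The forward pass collects %x, %c and %y into external_aliases_ (the
// flatten output may alias the block input), while the backward pass
// collects %z into output_aliases_. Values that land in neither set are
// purely intermediate and are the usual candidates for memory reuse.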
370
371 namespace {
372
373 bool isTensorList(const Value* value) {
374 auto* type = value->type()->castRaw<ListType>();
375 if (!type) {
376 return false;
377 }
378 return type->getElementType()->kind() == c10::TypeKind::TensorType;
379 }
380
381 bool containTensorsOnly(at::ArrayRef<Value*> values) {
382 // return true only if all outputs are tensors
383 return std::all_of(values.begin(), values.end(), [](const Value* value) {
384 return value->type()->kind() == c10::TypeKind::TensorType ||
385 isTensorList(value);
386 });
387 }
388
389 bool isPureFunction(const Node* node) {
390 auto* schema = node->maybeSchema();
391 return schema &&
392 schema->aliasAnalysis() == c10::AliasAnalysisKind::PURE_FUNCTION;
393 }
394
395 } // namespace
396
397 ManagedTensorRanges::ManagedTensorRanges(
398 Block& block,
399 const AliasDb& alias_db,
400 const c10::FastSet<const Value*>& managed_tensor_values) {
401 const std::vector<Node*> nodes(block.nodes().begin(), block.nodes().end());
402 const c10::FastSet<const Value*> graph_inputs(
403 block.inputs().begin(), block.inputs().end());
404
405 const auto num_nodes = static_cast<uint32_t>(nodes.size());
406 for (const auto i : c10::irange(num_nodes)) {
407 auto* node = nodes[i];
408 for (auto* input : node->inputs()) {
409 auto* lifetime = getLifetime(input);
410 if (!lifetime) {
411 continue;
412 }
413 DCHECK(lifetime->end <= i);
414 lifetime->end = i;
415 }
416 for (auto* output : node->outputs()) {
417 if (!alias_db.isMutableType(output)) {
418 continue;
419 }
420 value_lifetimes_.emplace(output, Lifetime(i, i));
421 }
422 }
423 for (auto* graph_output : block.outputs()) {
424 auto* lifetime = getLifetime(graph_output);
425 if (!lifetime) {
426 continue;
427 }
428 lifetime->end = num_nodes;
429 }
430
431 // Handle aliases. Aliases may extend a Value*'s lifetime. If a node
432 // has an input and output that may alias each other, set the input's
433 // lifetime end to max(input.lifetime_end, output.lifetime_end). Iterate
434 // backwards to handle chains of aliases.
435 for (const auto* node : block.nodes().reverse()) {
436 if (isPureFunction(node)) {
437 // If the node is a pure function, it doesn't create any aliases,
438 // so we can safely skip it.
439 continue;
440 }
441
442 auto inputs = collectValuesWithTrackedLifetimes(node->inputs());
443 auto outputs = collectValuesWithTrackedLifetimes(node->outputs());
444 for (auto* input : inputs) {
445 auto* input_lifetime = getLifetime(input);
446 DCHECK(input_lifetime != nullptr);
447 for (auto* output : outputs) {
448 if (mayContainAlias(alias_db, input, output)) {
449 auto* output_lifetime = getLifetime(output);
450 DCHECK(output_lifetime != nullptr);
451 input_lifetime->end =
452 std::max(output_lifetime->end, input_lifetime->end);
453 }
454 }
455 }
456 }
457 for (auto* managed_tensor : managed_tensor_values) {
458 auto* lifetime = getLifetime(managed_tensor);
459 DCHECK(lifetime && lifetime->end <= num_nodes);
460 Node* freeing_node = nullptr;
461 if (lifetime->end == num_nodes) {
462 freeing_node = block.return_node();
463 } else {
464 freeing_node = nodes[lifetime->end];
465 }
466 node_to_newly_free_tensors_[freeing_node].emplace_back(managed_tensor);
467 }
468 }
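// Illustrative sketch of the lifetime bookkeeping above; indices are node
// positions within the block for a hypothetical straight-line graph:
//
//   node 0: %a = op1(%x)   -> Lifetime(%a) = [0, 0]
//   node 1: %b = op2(%a)   -> Lifetime(%a) = [0, 1], Lifetime(%b) = [1, 1]
//   node 2: %c = op3(%b)   -> Lifetime(%b) = [1, 2], Lifetime(%c) = [2, 2]
//   return (%c)            -> Lifetime(%c).end = num_nodes
//
// If op3's output may alias its input, the reverse pass extends %b's end to
// %c's end so %b's storage is not handed out while %c is still live. The
// node at which each managed tensor's lifetime ends is recorded in
// node_to_newly_free_tensors_ as the point where that tensor becomes
// available for reuse.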
469
470 bool ManagedTensorRanges::nodeFreesManagedTensors(Node* node) const {
471 auto it = node_to_newly_free_tensors_.find(node);
472 return it != node_to_newly_free_tensors_.end() && !it->second.empty();
473 }
474
475 const std::vector<const Value*>& ManagedTensorRanges::
476 availableTensorValuesAfterNode(Node* node) const {
477 return node_to_newly_free_tensors_.at(node);
478 }
479
480 bool ManagedTensorRanges::lifetimesOverlap(const Value* v1, const Value* v2)
481 const {
482 const auto* v1_lifetime = getLifetime(v1);
483 const auto* v2_lifetime = getLifetime(v2);
484 if (!v1_lifetime || !v2_lifetime) {
485 return false;
486 }
487
488 if (v1_lifetime->start < v2_lifetime->start) {
489 return v1_lifetime->end >= v2_lifetime->start;
490 }
491 return v2_lifetime->end >= v1_lifetime->start;
492 }
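// Example: lifetimes [2, 5] and [5, 8] overlap (they share node 5), while
// [2, 4] and [5, 8] do not; only values with non-overlapping lifetimes are
// allowed to share storage during memory planning.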
493
494 const ManagedTensorRanges::Lifetime* ManagedTensorRanges::getLifetime(
495 const Value* value) const {
496 auto it = value_lifetimes_.find(value);
497 if (it != value_lifetimes_.end()) {
498 return &it->second;
499 }
500 return nullptr;
501 }
502
503 ManagedTensorRanges::Lifetime* ManagedTensorRanges::getLifetime(
504 const Value* value) {
505 // const_cast is safe here, this is just a way to avoid code duplication
506 // between the const/non-const versions of getLifetime.
507
508 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
509 const auto* const_this = const_cast<const ManagedTensorRanges*>(this);
510
511 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
512 return const_cast<ManagedTensorRanges::Lifetime*>(
513 const_this->getLifetime(value));
514 }
515
516 std::vector<const Value*> ManagedTensorRanges::
517 collectValuesWithTrackedLifetimes(at::ArrayRef<const Value*> values) {
518 std::vector<const Value*> mutable_values;
519 mutable_values.reserve(values.size());
520 std::copy_if(
521 values.begin(),
522 values.end(),
523 std::back_inserter(mutable_values),
524 [this](const Value* value) { return getLifetime(value) != nullptr; });
525 return mutable_values;
526 }
527
528 StaticModule::StaticModule(
529 const std::shared_ptr<torch::jit::Graph>& g,
530 const StaticModuleOptions& opts,
531 std::vector<IValue> sample_inputs)
532 : StaticModule(
533 PrepareForStaticModule(g->copy(), opts, std::move(sample_inputs)),
534 opts) {}
535
536 StaticModule::StaticModule(
537 const torch::jit::Module& m,
538 bool is_frozen,
539 const StaticModuleOptions& opts,
540 std::vector<IValue> sample_inputs)
541 : StaticModule(
542 PrepareForStaticModule(m, is_frozen, opts, std::move(sample_inputs)),
543 opts) {}
544
545 StaticModule::StaticModule(
546 std::pair<std::shared_ptr<torch::jit::Graph>, std::optional<Module>>
547 graph_and_module,
548 const StaticModuleOptions& opts)
549 : opts_(opts),
550 graph_(std::move(graph_and_module.first)),
551 module_(std::move(graph_and_module.second)),
552 num_inputs_(graph_->inputs().size()) {
553 sr_metadata_ = c10::make_intrusive<jit::StaticRuntimeMetadata>(opts_);
554 // recursively attach metadata to prim::fork nodes
555 attachNodeMetadata(graph_->block());
556
557 // check opt flags
558 if (opts.manage_output_tensors) {
559 TORCH_CHECK(
560 opts_.enable_out_variant,
561 "When manage_output_tensors is true, enable_out_variant must be set to true");
562 }
563 if (opts_.optimize_memory) {
564 TORCH_CHECK(
565 opts_.enable_out_variant,
566 "When optimize_memory is true, enable_out_variant must be set to true");
567 }
568
569 // handle schema
570 if (module_.has_value()) {
571 Method method = module_->get_method("forward");
572 schema_ = method.function().getSchema();
573 const auto num_schema_args = schema_->arguments().size();
574 DCHECK(num_schema_args > 0);
575 if (removeSelfFromGraphInput(graph_)) {
576 module_ = std::nullopt;
577 num_inputs_ = num_schema_args - 1;
578 }
579 }
580
581 {
582 size_t nodes_size = 0, constants_size = 0;
583 for (Node* node : graph_->nodes()) {
584 ++(node->kind() == prim::Constant ? constants_size : nodes_size);
585 }
586
587 constants_.reserve(constants_size);
588 functions_.reserve(nodes_size);
589 }
590
591 // Create ProcessedFunction instances first to freeze their addresses to pass
592 // to ProcessedNode.
593 AliasDb alias_db(graph_, /*isFrozen=*/false);
594 GRAPH_DEBUG("AliasDb: ", alias_db.toString());
595
596 // Maps each Value* in the graph to its index in the values_ array that will
597 // eventually be created by StaticRuntime.
598 c10::FastMap<const Value*, uint32_t> value_to_index;
599 prepareFunctionsAndConstants(graph_->block(), alias_db, value_to_index);
600
601 const auto constants_index_offset = 0;
602 const auto values_index_offset = constants_index_offset + constants().size();
603 value_buffer_size_ = values_index_offset;
604
605 value_buffer_size_ +=
606 prepareBlockInfo(graph_->block(), values_index_offset, value_to_index);
607
608 prepareStaticNodeInfos(graph_->block(), value_to_index, alias_db);
609
610 for (auto& block_and_info : block_infos_) {
611 auto& block_info = block_and_info.second;
612 block_info.prepare_for_memory_planner(alias_db, opts);
613 }
614 }
615
616 size_t StaticModule::prepareBlockInfo(
617 Block* block,
618 const size_t start_idx,
619 c10::FastMap<const Value*, uint32_t>& value_to_index) {
620 block_infos_.emplace(block, BlockInfo(start_idx, *block));
621
622 const auto num_inputs = static_cast<uint32_t>(block->inputs().size());
623 for (const auto i : c10::irange(num_inputs)) {
624 value_to_index.emplace(block->inputs()[i], start_idx + i);
625 }
626 auto cur_idx = start_idx + num_inputs;
627
628 for (auto* node : block->nodes()) {
629 for (auto* sub_block : node->blocks()) {
630 cur_idx += prepareBlockInfo(sub_block, cur_idx, value_to_index);
631 }
632
633 if (node->kind() == prim::Constant) {
634 continue;
635 }
636
637 TORCH_CHECK(
638 cur_idx < (1 << 16),
639 "outputs offset in values table",
640 cur_idx,
641 " would overflow 2-byte index storage");
642
643 const auto num_outputs = static_cast<uint32_t>(node->outputs().size());
644 for (const auto i : c10::irange(num_outputs)) {
645 value_to_index.emplace(node->outputs()[i], cur_idx + i);
646 }
647 cur_idx += num_outputs;
648 }
649
650 std::vector<uint16_t> output_indices;
651 output_indices.reserve(block->outputs().size());
652 for (auto* output : block->outputs()) {
653 const auto output_idx = value_to_index.at(output);
654 TORCH_CHECK(
655 output_idx < (1 << 16),
656 "outputs offset in values table",
657 output_idx,
658 " would overflow 2-byte index storage");
659 output_indices.push_back(output_idx);
660 }
661
662 block_infos_.at(block).set_output_indices(std::move(output_indices));
663 return cur_idx - start_idx;
664 }
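// Sketch of the values table layout this produces for a single-block graph
// (constants come first, as set up in the StaticModule constructor):
//
//   [0, num_constants)        constants_
//   [start_idx, +num_inputs)  block inputs
//   [..., cur_idx)            outputs of each non-constant node, in order
//
// Sub-blocks are laid out recursively at the position of their owning node,
// which is why this function returns the number of slots it consumed. All
// indices must fit in 16 bits, hence the overflow checks above.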
665
666 void StaticModule::attachNodeMetadata(Block* block) {
667 for (auto* node : block->nodes()) {
668 if (node->kind() == prim::fork) {
669 node->ival_(getStaticRuntimeMetadataSymbol(), IValue(sr_metadata_));
670 }
671 for (auto* sub_block : node->blocks()) {
672 attachNodeMetadata(sub_block);
673 }
674 }
675 }
676
677 void StaticModule::prepareFunctionsAndConstants(
678 Block* block,
679 const AliasDb& alias_db,
680 c10::FastMap<const Value*, uint32_t>& value_to_index) {
681 for (auto* node : block->nodes()) {
682 for (auto* sub_block : node->blocks()) {
683 prepareFunctionsAndConstants(sub_block, alias_db, value_to_index);
684 }
685
686 if (node->kind() == prim::Constant) {
687 auto* v = node->output();
688 TORCH_CHECK(
689 v->type()->kind() != FunctionType::Kind,
690 "got ",
691 typeKindToString(v->type()->kind()),
692 " instead of ",
693 typeKindToString(FunctionType::Kind));
694 value_to_index.emplace(v, constants_.size());
695 constants_.emplace_back(toIValue(v).value());
696 continue;
697 }
698
699 // see [Check and correct bad schema alias info at runtime]
700 bool check_outputs_for_overlap =
701 !alias_db.mayContainAlias(node->inputs(), node->outputs()) &&
702 containTensorsOnly(node->outputs());
703 // new ProcessedFunction
704 functions_.emplace_back(
705 node, opts_.enable_out_variant, check_outputs_for_overlap);
706 }
707 }
708
709 size_t StaticModule::prepareStaticNodeInfos(
710 Block* block,
711 const c10::FastMap<const Value*, uint32_t>& value_to_index,
712 const AliasDb& alias_db,
713 size_t node_idx) {
714 const auto node_start = node_idx;
715
716 auto& block_info = block_infos_.at(block);
717 std::vector<StaticNodeInfo> nodes;
718 c10::FastMap<Node*, bool> node_has_out_variant;
719
720 for (auto* node : block->nodes()) {
721 if (node->kind() == prim::Constant) {
722 continue;
723 }
724
725 for (auto* sub_block : node->blocks()) {
726 node_idx +=
727 prepareStaticNodeInfos(sub_block, value_to_index, alias_db, node_idx);
728 }
729 const auto num_inputs = static_cast<uint32_t>(node->inputs().size());
730 ProcessedNodeInputs input_indices(num_inputs);
731 for (const auto input_idx : c10::irange<uint32_t>(num_inputs)) {
732 auto* input = node->inputs()[input_idx];
733 auto input_ivalue_idx = value_to_index.at(input);
734 TORCH_CHECK(
735 input_ivalue_idx < (1 << 16),
736 "input index in values table ",
737 input_ivalue_idx,
738 " would overflow 2-byte index storage");
739 input_indices[input_idx] = input_ivalue_idx;
740 }
741
742 ProcessedFunction* fn = &functions_[node_idx];
743
744 // create a new ProcessedNode
745 const auto node_output_idx = node->outputs().empty()
746 // The index is unused if there are no outputs, so just create a
747 // placeholder value.
748 ? std::numeric_limits<uint16_t>::max()
749 : value_to_index.at(node->output(0));
750 nodes.emplace_back(node, fn, std::move(input_indices), node_output_idx);
751
752 node_has_out_variant.emplace(node, nodes.back().has_out_variant());
753 ++node_idx;
754 }
755
756 block_info.set_nodes(std::move(nodes), node_has_out_variant);
757 block_info.init_value_group(alias_db);
758
759 return node_idx - node_start;
760 }
761
762 #ifdef FBCODE_CAFFE2
763 thread_local SROperatorObserver* tlsOpObserver = nullptr;
764
765 void SROperatorObserver::setCurrentThreadObserver(
766 SROperatorObserver* observer) {
767 tlsOpObserver = observer;
768 }
769
770 SROperatorObserver* SROperatorObserver::getCurrentThreadObserver() {
771 return tlsOpObserver;
772 }
773
774 void SROperatorObserver::onStart(const Node* node) {
775 if (tlsOpObserver != nullptr && tlsOpObserver->startCb != nullptr) {
776 tlsOpObserver->startCb(node);
777 }
778 }
779
780 void SROperatorObserver::onEnd(const Node* node) {
781 if (tlsOpObserver != nullptr && tlsOpObserver->endCb != nullptr) {
782 tlsOpObserver->endCb(node);
783 }
784 }
785 #endif // FBCODE_CAFFE2
786
787 BlockInfo::BlockInfo(uint32_t input_idx, Block& block)
788 : input_idx_(input_idx), block_(block) {}
789
790 void BlockInfo::set_nodes(
791 std::vector<StaticNodeInfo> nodes,
792 const c10::FastMap<Node*, bool>& node_has_out_variant) {
793 nodes_ = std::move(nodes);
794
795 for (auto& node : nodes_) {
796 if (node.num_outputs() == 1 &&
797 isOptimizableContainerType(node.node(), node_has_out_variant)) {
798 node_is_optimizable_container_type_.emplace(node.node());
799 }
800 }
801 }
802 void BlockInfo::prepare_for_memory_planner(
803 const AliasDb& alias_db,
804 const StaticModuleOptions& opts) {
805 if (!opts.enable_out_variant) {
806 return;
807 }
808
809 // Never manage graph outputs so that we can do std::move(output_ivalue).
810 // This does not affect performance if the graph returns a collection object.
811 c10::FastSet<const Value*> graph_output_values(
812 block_.outputs().begin(), block_.outputs().end());
813
814 // collect register indices of outputs of ops with out variant
815 for (StaticNodeInfo& pnode : nodes_) {
816 if (!pnode.has_out_variant()) {
817 continue;
818 }
819 auto outputs = pnode.node()->outputs();
820 const auto num_outputs = static_cast<uint32_t>(outputs.size());
821 for (const auto i : c10::irange(num_outputs)) {
822 const Value* out_v = outputs[i];
823 // Types are stored in the underlying TorchScript IR
824 bool is_tensor_type = out_v->type()->castRaw<TensorType>();
825 if (opts.manage_output_tensors && is_tensor_type &&
826 graph_output_values.find(out_v) == graph_output_values.end() &&
827 value_group_.isOutputAlias(out_v)) {
828 managed_output_tensor_values_.insert(out_v);
829 continue;
830 }
831 if (value_group_.isAlwaysAlive(out_v)) {
832 continue;
833 }
834 if (is_tensor_type) {
835 managed_tensor_values_.insert(out_v);
836 } else if (node_is_optimizable_container_type(pnode.node())) {
837 // We "leak" certain container types because their allocations
838 // take a long time
839 leaked_values_.insert(out_v);
840 }
841 }
842 }
843
844 for (const Value* output : block_.outputs()) {
845 managed_tensor_values_.erase(output);
846 }
847 GRAPH_DEBUG("managed_tensor_values: ", dumpValueSet(managed_tensor_values_));
848 GRAPH_DEBUG(
849 "managed_output_tensor_values_: ",
850 dumpValueSet(managed_output_tensor_values_));
851
852 managed_tensor_ranges_ =
853 ManagedTensorRanges(block_, alias_db, managed_tensor_values_);
854 }
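// Summary of the classification above, as a sketch: each out-variant output
// ends up in one of three buckets.
//   managed_output_tensor_values_ : tensor aliases of graph outputs, handed
//                                   back to the caller and reclaimed later
//                                   via deallocateOutputTensors()
//   managed_tensor_values_        : intermediate tensors whose storage the
//                                   MemoryPlanner carves out of its arena
//   leaked_values_                : optimizable container outputs that are
//                                   deliberately never deallocated
// Everything else (always-alive values, outputs of non-out-variant nodes)
// is left to ordinary reference counting.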
855
856 const StaticModuleOptions& StaticModule::opts() const {
857 return opts_;
858 }
859
860 size_t StaticModule::num_outputs() const {
861 return graph_->outputs().size();
862 }
863
864 size_t StaticModule::num_inputs() const {
865 return num_inputs_;
866 }
867
868 StaticRuntime& StaticModule::runtime() {
869 if (!cached_runtime_) {
870 cached_runtime_ = std::make_unique<StaticRuntime>(*this);
871 }
872 return *cached_runtime_;
873 }
874
875 Node* StaticModule::findNodeWithKindForTesting(const std::string& kind) const {
876 for (auto& block_and_info : block_infos_) {
877 auto& block_info = block_and_info.second;
878 for (auto& pnode : block_info.nodes()) {
879 if (pnode.node()->kind().toQualString() == kind) {
880 return pnode.node();
881 }
882 }
883 }
884 return nullptr;
885 }
886
887 c10::IValue StaticModule::operator()(
888 const std::vector<c10::IValue>& args,
889 const KeywordArgs& kwargs) {
890 return runtime()(args, kwargs);
891 }
892
893 c10::IValue StaticModule::operator()(
894 std::vector<c10::IValue>&& args,
895 const KeywordArgs& kwargs) {
896 return runtime()(std::move(args), kwargs);
897 }
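// A hedged usage sketch for the call operators above; it assumes the default
// arguments declared in impl.h and an illustrative model path:
//
//   torch::jit::Module m = torch::jit::load("model.pt");
//   torch::jit::StaticModule smod(m, /*is_frozen=*/false);
//   std::vector<c10::IValue> inputs{at::randn({4, 8})};
//   c10::IValue out = smod(inputs, /*kwargs=*/{});
//
// The first call lazily builds the cached StaticRuntime (see runtime()
// above). For concurrent serving, the usual pattern is one shared
// StaticModule plus a StaticRuntime per thread rather than calling the
// StaticModule directly from many threads.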
898
899 BlockRunner::BlockRunner(
900 const StaticModule& sm,
901 IValue* values,
902 Block* block,
903 torch::jit::TaskLauncher* launcher,
904 bool is_root_block)
905 : static_module_(sm),
906 block_info_(static_module_.block_info(block)),
907 is_root_block_(is_root_block),
908 first_input_is_self_(
909 is_root_block_ && static_module_.first_input_is_self()),
910 inputs_begin_(block_info_.block_inputs_idx()),
911 // TODO(T108633124): Turn on manage output tensors for sub-blocks.
912 manage_output_tensors_enabled_(
913 is_root_block_ && sm.opts().manage_output_tensors),
914 values_(values) {
915 nodes_.reserve(block_info_.nodes().size());
916 for (auto& pre_pnode : block_info_.nodes()) {
917 nodes_.emplace_back(pre_pnode, values_);
918 }
919
920 for (auto index : block_info_.block_output_indices()) {
921 outputs_.emplace_back(&values_[index]);
922 }
923
924 for (auto& pnode : nodes_) {
925 auto* node = pnode.node();
926
927 // attach the async taskLauncher to processedNodes
928 pnode.set_metadata(launcher);
929 auto blocks = node->blocks();
930 const auto num_blocks = blocks.size();
931 if (num_blocks == 0) {
932 continue;
933 }
934 DCHECK(node->kind() == prim::If || node->kind() == prim::Loop);
935 std::vector<BlockRunner> block_runners;
936 block_runners.reserve(num_blocks);
937
938 for (auto* b : blocks) {
939 block_runners.emplace_back(sm, values_, b, launcher);
940 }
941 pnode.set_metadata(std::move(block_runners));
942 }
943 }
944
945 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
946 BlockRunner::BlockRunner(BlockRunner&&) noexcept = default;
947
948 BlockRunner::~BlockRunner() = default;
949
950 void BlockRunner::set_arg(const size_t idx, std::vector<IValue>&& args) {
951 DCHECK(idx < args.size());
952 Input(idx + first_input_is_self_) = std::move(args[idx]);
953 }
954
955 void BlockRunner::set_arg(const size_t idx, const std::vector<IValue>& args) {
956 DCHECK(idx < args.size());
957 Input(idx + first_input_is_self_) = args[idx];
958 }
959
960 void BlockRunner::set_arg(const size_t idx, const IValue& arg) {
961 Input(idx + first_input_is_self_) = arg;
962 }
963
964 namespace {
965 void check_type(const Argument& schema_arg, const IValue& arg) {
966 // Fast path for most common case
967 if (arg.isTensor() &&
968 schema_arg.type()->kind() == c10::TypeKind::TensorType) {
969 return;
970 }
971 TORCH_CHECK(
972 arg.type()->isSubtypeOf(schema_arg.type()),
973 arg.type()->annotation_str(),
974 " is not a subtype of ",
975 schema_arg.type()->annotation_str(),
976 "; schema arg name: '",
977 schema_arg.name(),
978 "', ivalue: ",
979 iValueToString(arg));
980 }
981 } // namespace
982
983 template <typename IValueList>
984 void BlockRunner::set_inputs(IValueList&& args, const KeywordArgs& kwargs) {
985 const auto& schema = static_module_.schema();
986 if (first_input_is_self_) {
987 Input(0) = static_module_.module()._ivalue();
988 }
989
990 if (!is_root_block_ || C10_UNLIKELY(!schema)) {
991 TORCH_CHECK(
992 kwargs.empty(),
993 "BlockRunner got kwargs; is_root_block: ",
994 std::to_string(is_root_block_),
995 "schema: ",
996 schema ? schema->name() : "(not available)");
997
998 const auto total_num_inputs = args.size() + first_input_is_self_;
999 TORCH_CHECK(
1000 total_num_inputs == block_info_.num_inputs(),
1001 "Block runner got ",
1002 std::to_string(total_num_inputs),
1003 " inputs; ",
1004 " first_input_is_self: ",
1005 std::to_string(first_input_is_self_),
1006 "; SR block expects ",
1007 std::to_string(block_info_.num_inputs()),
1008 " inputs for schema ",
1009 schema ? schema->name() : "(not available)");
1010
1011 for (const auto i_arg : c10::irange(args.size())) {
1012 set_arg(i_arg, std::forward<IValueList>(args));
1013 }
1014 return;
1015 }
1016
1017 const auto& schema_args = schema->arguments();
1018 size_t consumed_kwargs = 0;
1019 DCHECK(!schema_args.empty());
1020 TORCH_CHECK(
1021 args.size() < schema_args.size(),
1022 "Static runtime got ",
1023 std::to_string(args.size()),
1024 " arguments, expects ",
1025 std::to_string(schema_args.size() - 1),
1026 " for schema ",
1027 schema->name());
1028
1029 for (const auto i_arg : c10::irange(1, schema_args.size())) {
1030 // Start at 1 since the schema always contains `self`.
1031 const auto& schema_arg = schema_args[i_arg];
1032
1033 if (i_arg - 1 < args.size()) {
1034 check_type(schema_arg, std::forward<IValueList>(args)[i_arg - 1]);
1035 set_arg(i_arg - 1, std::forward<IValueList>(args));
1036 continue;
1037 }
1038
1039 auto it = kwargs.find(schema_arg.name());
1040 if (it != kwargs.end()) {
1041 check_type(schema_arg, it->second);
1042 set_arg(i_arg - 1, it->second);
1043 ++consumed_kwargs;
1044 continue;
1045 }
1046
1047 auto maybe_default_val = schema_arg.default_value();
1048 if (maybe_default_val) {
1049 set_arg(i_arg - 1, *maybe_default_val);
1050 continue;
1051 }
1052
1053 TORCH_CHECK(
1054 false,
1055 "Static runtime is missing required kwarg ",
1056 schema_arg.name(),
1057 " i_arg: ",
1058 std::to_string(i_arg),
1059 " for schema ",
1060 schema->name());
1061 }
1062 TORCH_CHECK(
1063 consumed_kwargs == kwargs.size(),
1064 "kwargs size mismatch (consumed ",
1065 std::to_string(consumed_kwargs),
1066 ", expected ",
1067 std::to_string(kwargs.size()),
1068 " for schema ",
1069 schema->name());
1070 }
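// Worked example of the argument resolution above (schema and values are
// hypothetical): for
//   forward(self, Tensor x, Tensor y, int k=1)
// a call runner({x_val}, {{"y", y_val}}) binds x positionally, y from kwargs
// (consumed_kwargs becomes 1), and k from its schema default; an unrecognized
// kwarg would trip the final consumed_kwargs == kwargs.size() check.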
1071
1072 void BlockRunner::create_memory_planner() {
1073 if (!planner_) {
1074 planner_ = std::make_unique<StandardMemoryPlanner>(
1075 this,
1076 block_info_,
1077 static_module_.opts().enable_out_variant,
1078 manage_output_tensors_enabled_,
1079 static_module_.opts().optimize_memory);
1080 }
1081 }
1082
1083 namespace {
1084
1085 void destroyNodeOutputs(ProcessedNode& p_node) {
1086 const auto borrows_outputs = borrowsOutputs(p_node.node()->kind());
1087 const auto num_outputs = static_cast<uint32_t>(p_node.num_outputs());
1088 for (const auto i : c10::irange<uint32_t>(num_outputs)) {
1089 auto& output = p_node.Output(i);
1090 if (doesNotHeapAllocateWhenStoredInIValue(*output.type())) {
1091 continue;
1092 }
1093
1094 if (borrows_outputs) {
1095 // NB: No need to incref here. This codepath is only hit if the run didn't
1096 // finish, so we shouldn't be returning anything to the client.
1097 c10::MaybeOwnedTraits<IValue>::destroyBorrow(output);
1098 } else {
1099 output = IValue();
1100 }
1101 }
1102 }
1103
1104 } // namespace
1105
1106 void BlockRunner::clean_up_intermediate_ivalues() noexcept {
1107 // We have to iterate in reverse order here due to borrowed
1108 // IValues - we don't want to destroy a value until all of its
1109 // borrows are cleaned up!
1110 for (auto it = nodes_.rbegin(); it != nodes_.rend(); ++it) {
1111 destroyNodeOutputs(*it);
1112 }
1113 }
1114
1115 void BlockRunner::resetMemory() noexcept {
1116 planner_.reset();
1117 // We must clean up intermediate values before inputs in case
1118 // there are borrowed inputs and static runtime owns the only
1119 // reference (e.g. the inputs were std::move'd into the runtime)
1120 clean_up_intermediate_ivalues();
1121 clean_up_input_ivalues();
1122 }
1123
1124 c10::IValue BlockRunner::move_outputs_to_tuple(uint32_t num_outputs) {
1125 switch (num_outputs) {
1126 case 1:
1127 return c10::ivalue::Tuple::create(IValue(std::move(*outputs_[0])));
1128 case 2:
1129 return c10::ivalue::Tuple::create(
1130 IValue(std::move(*outputs_[0])), IValue(std::move(*outputs_[1])));
1131 case 3:
1132 return c10::ivalue::Tuple::create(
1133 IValue(std::move(*outputs_[0])),
1134 IValue(std::move(*outputs_[1])),
1135 IValue(std::move(*outputs_[2])));
1136 default: {
1137 std::vector<c10::IValue> outputs;
1138 outputs.reserve(num_outputs);
1139 for (const auto i : c10::irange(num_outputs)) {
1140 // use move here. Otherwise, clean up outputs_[i] explicitly
1141 outputs.emplace_back(std::move(*outputs_[i]));
1142 }
1143 return c10::ivalue::Tuple::create(std::move(outputs));
1144 }
1145 }
1146 }
1147
1148 /// [Check and correct bad schema alias info at runtime]
1149 /// Static runtime relies on the operator schema's alias info to be correct for
1150 /// memory planning. Because it's hard to enforce the alias info to be correct,
1151 /// we need to do runtime detection for accidental aliases that do not comply
1152 /// with the schema. Only aliases of managed tensors are problematic. To avoid
1153 /// runtime crashes, we can add runtime detection and force the op to comply
1154 /// with its schema by cloning the alias. Because all managed tensors' data_ptrs
1155 /// are part of the internal buffer that the MemoryPlanner allocates, we can
1156 /// check aliases by checking the memory overlap with this internal buffer. But
1157 /// a tensor's storage can be resized during inference, so we need another way to
1158 /// handle the resized case.
1159 ///
1160 /// There are two ways for incorrect schema to break memory planning. Let's look
1161 /// at two examples:
1162 ///
1163 /// Example 1:
1164 /// @code
1165 /// def forward(x):
1166 /// a = x + x
1167 /// b = bad_op(a) # b ends up aliasing a incorrectly
1168 /// return (b)
1169 /// @endcode
1170 /// bad_op: its schema says it returns a new Tensor, but it actually returns an
1171 /// alias. In this case, the memory planner would recognize `a` as a managed
1172 /// tensor and clean up its memory before returning `b`. But `b` is actually an
1173 /// alias of `a`, when `a`'s data_ptr get reset, `b`'s data_ptr gets reset too.
1174 ///
1175 /// Example 2:
1176 /// @code
1177 /// def forward(x):
1178 /// a = x + x
1179 /// a2 = bad_op(a) # a2 ends up aliasing a incorrectly
1180 /// b = a + a
1181 /// c = b * b # c shares storage with a
1182 /// d = c + 2 # d shares storage with b
1183 /// e = a2 * a2
1184 /// return (d, e)
1185 /// @endcode
1186 /// With the memory reuse algorithm, `c` could end up sharing storage with `a`,
1187 /// but because of bad_op, `a2` now aliases `a`. `c` overwrites `a` and
1188 /// therefore `a2`, leading to the wrong results. We solve this problem with two
1189 /// steps. Note this doesn't happen with the current memory reuse algorithm
1190 /// because of the way it's implemented. Things could change with a different
1191 /// implementation.
1192 ///
1193 /// Step 1, annotate the ProcessedNodes with a flag `check_memory_overlap_` set
1194 /// to true if its outputs do not alias its inputs as indicated by the AliasDb
1195 /// and all of its outputs are Tensors. Then at runtime, we check that the
1196 /// nodes' output tensors do not overlap with the internal buffer that the
1197 /// MemoryPlanner allocates. For latency concerns, we only run this check for
1198 /// fallback ops. The schemas of native ops and out variants are vetted and
1199 /// enforced with static runtime unit tests. For the first iteration, we do a
1200 /// full memory overlap check with
1201 /// ProcessedNode::verify_and_correct_memory_overlap() because the internal
1202 /// buffer doesn't exist yet.
1203 ///
1204 /// Step 2, if a managed tensor gets resized during inference, it gets a new
1205 /// data_ptr which is not from the buffer. We can tackle this corner case by
1206 /// delaying the deallocation of the managed tensors to after the outputs are no
1207 /// longer used (essentially merging the internal/output buffers into one).
1208 /// Before the merging is implemented, we add another flag `overlap_detected_`
1209 /// to flag any node with overlap detected in Step 1 and do a full memory
1210 /// overlap check if the fast check (by checking memory overlap with internal
1211 /// buffer) fails. There is still a corner case that fails with the added flag:
1212 /// if the same op both triggers a resize and creates an alias, the current
1213 /// checks would fail to detect the alias.
1214 void BlockRunner::verify_and_correct_memory_overlap(ProcessedNode& n) {
1215 // The slow check can be removed once the internal/output buffers are merged
1216 if (C10_UNLIKELY(n.check_outputs_for_memory_overlap())) {
1217 if (C10_UNLIKELY(!planner_)) {
1218 // slow check, for first iter only
1219 n.verify_and_correct_memory_overlap();
1220 } else {
1221 bool overlap_detected_with_fast_check = false;
1222 const auto n_outputs = static_cast<uint32_t>(n.outputs().size());
1223 for (auto i : c10::irange(n_outputs)) {
1224 auto& output = n.Output(i);
1225 if (output.isTensor()) {
1226 overlap_detected_with_fast_check |=
1227 fast_check_and_correct_overlap_with(n, output);
1228 } else if (output.isTensorList()) {
1229 auto tensor_list = output.toListRef();
1230 for (auto& ival : tensor_list) {
1231 overlap_detected_with_fast_check |=
1232 fast_check_and_correct_overlap_with(
1233 n,
1234 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
1235 const_cast<c10::IValue&>(ival));
1236 }
1237 }
1238 }
1239 if (n.outputs_memory_overlap_detected() &&
1240 !overlap_detected_with_fast_check) {
1241 // slow check. Only run when the fast check fails.
1242 n.verify_and_correct_memory_overlap();
1243 }
1244 }
1245 }
1246 }
1247
1248 bool BlockRunner::fast_check_and_correct_overlap_with(
1249 ProcessedNode& n,
1250 c10::IValue& tensor_ival) {
1251 auto& tensor = tensor_ival.toTensor();
1252 if (planner_->overlapWithInternalBuffer(tensor.data_ptr())) {
1253 DLOG(INFO) << "Detected alias for node: " << PrintNode(n.node());
1254 tensor_ival = at::native::clone(tensor, std::nullopt);
1255 n.set_outputs_memory_overlap_detected();
1256 return true;
1257 }
1258 return false;
1259 }
1260
1261 BlockRunner::Deallocator::~Deallocator() {
1262 // Assume cleanup cannot throw.
1263 cleanupImpl();
1264 #ifndef NDEBUG
1265 block_runner_.check_for_memory_leak(/*output_returned*/ false);
1266 #endif
1267 }
1268
1269 void BlockRunner::Deallocator::cleanupImpl() {
1270 // MemoryPlanner is created after the first invocation of `run()`. This
1271 // is done intentionally because MemoryPlanner uses `Tensor` sizes of
1272 // the previous `run()` for memory planning of subsequent runs
1273 if (C10_LIKELY(finished_)) {
1274 block_runner_.create_memory_planner();
1275 }
1276
1277 if (C10_LIKELY(block_runner_.planner_)) {
1278 block_runner_.planner_->deallocate();
1279 } else {
1280 // This is the first run, and it didn't finish, so we can't use a
1281 // `MemoryPlanner` to deallocate stuff. Just reset everything manually.
1282 block_runner_.resetMemory();
1283 }
1284 // clean up owning refs of input tensors
1285 block_runner_.clean_up_input_ivalues();
1286 if (C10_UNLIKELY(!finished_)) {
1287 block_runner_.deallocateOutputTensors();
1288 }
1289 }
1290
1291 template <typename IValueList>
1292 c10::IValue BlockRunner::run_impl(
1293 IValueList&& args,
1294 const KeywordArgs& kwargs) {
1295 // We assume inference workloads, so we do not need
1296 // autograd. Enabling this is a significant win on dispatcher
1297 // overhead because it saves a round of dispatch for at least some
1298 // functions, such as resize_ and resize_as_.
1299 c10::InferenceMode mode;
1300
1301 {
1302 auto on_exit = Deallocator(*this);
1303
1304 if (planner_) {
1305 DCHECK(!manage_output_tensors_enabled_ || checkOutputTensorMemoryLeaks());
1306 planner_->allocate();
1307 }
1308
1309 set_inputs(std::forward<IValueList>(args), kwargs);
1310
1311 for (auto& n : nodes_) {
1312 // LOG(INFO) << "Running node: " << PrintNode(n.node());
1313 n.run();
1314 // Check for incorrect schema alias info.
1315 verify_and_correct_memory_overlap(n);
1316 }
1317 on_exit.setFinished();
1318 }
1319
1320 // no need to keep references of outputs in static runtime anymore
1321 if (block_info_.num_outputs() > 1) {
1322 return move_outputs_to_tuple(block_info_.num_outputs());
1323 }
1324
1325 DCHECK(check_for_memory_leak(/*output_returned*/ false));
1326
1327 // use move here. Otherwise, clean up outputs_[0] explicitly
1328 return std::move(*outputs_[0]);
1329 }
1330
1331 template <typename IValueList>
1332 c10::IValue BlockRunner::run_impl_record_functions(
1333 IValueList&& args,
1334 const KeywordArgs& kwargs) {
1335 auto step_callbacks =
1336 at::getStepCallbacksUnlessEmpty(at::RecordScope::STATIC_RUNTIME_MODEL);
1337 if (C10_UNLIKELY(step_callbacks.has_value())) {
1338 at::RecordFunction guard(std::move(*step_callbacks));
1339 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(guard.isActive());
1340 guard.needsInputs()
1341 ? guard.before(
1342 "forward", c10::ArrayRef<const IValue>(args.data(), args.size()))
1343 : guard.before("forward");
1344
1345 return run_impl(std::forward<IValueList>(args), kwargs);
1346 }
1347 return run_impl(std::forward<IValueList>(args), kwargs);
1348 }
1349
1350 template <typename IValueList>
1351 c10::intrusive_ptr<c10::ivalue::Future> BlockRunner::run_impl_async(
1352 IValueList&& args,
1353 const KeywordArgs& kwargs) {
1354 // run the graph inline in the caller thread. Async ops will be
1355 // executed on taskLauncher attached to the metadata of ProcessedNodes
1356 c10::IValue output = run_impl(std::forward<IValueList>(args), kwargs);
1357
1358 // If the output is of type future, return it
1359 if (output.isFuture()) {
1360 return output.toFuture();
1361 }
1362
1363 // wrap the output into future, mark completed and return it
1364 TypePtr return_type;
1365 if (block_info_.num_outputs() > 1) {
1366 return_type = TupleType::create(
1367 fmap(outputs(), [](const IValue* v) { return v->type(); }));
1368 } else {
1369 return_type = outputs().at(0)->type();
1370 }
1371 c10::intrusive_ptr<Future> future = c10::make_intrusive<Future>(return_type);
1372 future->markCompleted(output);
1373 return future;
1374 }
1375
1376 template <typename IValueList>
1377 c10::intrusive_ptr<c10::ivalue::Future> BlockRunner::
1378 run_impl_record_functions_async(
1379 IValueList&& args,
1380 const KeywordArgs& kwargs) {
1381 auto step_callbacks =
1382 at::getStepCallbacksUnlessEmpty(at::RecordScope::STATIC_RUNTIME_MODEL);
1383 if (C10_UNLIKELY(step_callbacks.has_value())) {
1384 at::RecordFunction guard(std::move(*step_callbacks));
1385 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(guard.isActive());
1386 guard.needsInputs()
1387 ? guard.before(
1388 "forward", c10::ArrayRef<const IValue>(args.data(), args.size()))
1389 : guard.before("forward");
1390
1391 return run_impl_async(std::forward<IValueList>(args), kwargs);
1392 }
1393 return run_impl_async(std::forward<IValueList>(args), kwargs);
1394 }
1395
1396 c10::IValue BlockRunner::operator()(
1397 const std::vector<c10::IValue>& args,
1398 const KeywordArgs& kwargs) {
1399 #ifdef PYTORCH_DISABLE_NET_PROFILING
1400 return run_impl(args, kwargs);
1401 #else
1402 return run_impl_record_functions(args, kwargs);
1403 #endif
1404 }
1405
1406 c10::IValue BlockRunner::operator()(
1407 std::vector<c10::IValue>&& args,
1408 const KeywordArgs& kwargs) {
1409 #ifdef PYTORCH_DISABLE_NET_PROFILING
1410 return run_impl(std::move(args), kwargs);
1411 #else
1412 return run_impl_record_functions(std::move(args), kwargs);
1413 #endif
1414 }
1415
1416 c10::intrusive_ptr<c10::ivalue::Future> BlockRunner::runAsync(
1417 const std::vector<c10::IValue>& args,
1418 const KeywordArgs& kwargs) {
1419 #ifdef PYTORCH_DISABLE_NET_PROFILING
1420 return run_impl_async(args, kwargs);
1421 #else
1422 return run_impl_record_functions_async(args, kwargs);
1423 #endif
1424 }
1425
1426 c10::intrusive_ptr<c10::ivalue::Future> BlockRunner::runAsync(
1427 std::vector<c10::IValue>&& args,
1428 const KeywordArgs& kwargs) {
1429 #ifdef PYTORCH_DISABLE_NET_PROFILING
1430 return run_impl_async(std::move(args), kwargs);
1431 #else
1432 return run_impl_record_functions_async(std::move(args), kwargs);
1433 #endif
1434 }
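// Sketch of how a caller (normally StaticRuntime, which forwards here) might
// consume the returned future; the runner and inputs below are illustrative:
//
//   auto future = block_runner.runAsync(inputs, /*kwargs=*/{});
//   future->wait();
//   c10::IValue out = future->value();
//
// The graph itself runs inline in the calling thread; only async ops are
// handed to the TaskLauncher attached at construction. If the graph already
// returns a Future it is forwarded as-is, otherwise the result is wrapped in
// an already-completed Future.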
1435
1436 namespace {
1437
1438 std::string generate_latency_json(const std::string& label, double millis) {
1439 #ifdef FBCODE_CAFFE2
1440 folly::dynamic json = folly::dynamic::object();
1441 json["type"] = label;
1442 json["metric"] = "latency";
1443 json["unit"] = "ms";
1444 json["value"] = millis;
1445 return "PyTorchObserver " + folly::toJson(json);
1446 #else
1447 (void)label;
1448 (void)millis;
1449 return "";
1450 #endif
1451 }
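// In fbcode builds the helper above emits observer lines of the form
// (key order and values illustrative):
//   PyTorchObserver {"type":"aten::add","metric":"latency","unit":"ms","value":0.023}
// In OSS builds it intentionally returns an empty string.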
1452
1453 } // namespace
1454
1455 void BlockRunner::benchmark(
1456 const std::vector<std::vector<c10::IValue>>& args_list,
1457 const std::vector<KeywordArgs>& kwargs_list,
1458 const uint32_t warmup_runs,
1459 const uint32_t main_runs,
1460 bool print_per_node_time,
1461 bool generate_ai_pep_output) {
1462 TORCH_CHECK(kwargs_list.empty() || args_list.size() == kwargs_list.size());
1463 std::cout << "Input size: " << args_list.size() << '\n';
1464 float time_per_iter =
1465 benchmark_model(args_list, kwargs_list, warmup_runs, main_runs);
1466 std::cout << "Static runtime ms per iter: " << time_per_iter
1467 << ". Iters per second: " << 1000.0 / time_per_iter << '\n';
1468
1469 IndividualMetrics results =
1470 benchmark_individual_ops(args_list, kwargs_list, warmup_runs, main_runs);
1471
1472 if (print_per_node_time) {
1473 const auto num_nodes = static_cast<uint32_t>(nodes_.size());
1474 for (const auto i : c10::irange(num_nodes)) {
1475 const Node* node = nodes_[i].node();
1476 std::cout << "Node #" << i << ": " << results.time_per_node[i]
1477 << " ms/iter, ";
1478 node->print(std::cout, 0, nullptr, false);
1479 }
1480 }
1481
1482 std::vector<std::pair<std::string, double>> time_per_node_type_vec{
1483 results.time_per_node_type.begin(), results.time_per_node_type.end()};
1484 if (args_list.empty()) {
1485 std::sort(
1486 time_per_node_type_vec.begin(),
1487 time_per_node_type_vec.end(),
1488 [&results](auto& left, auto& right) {
1489 return results.instances_per_node_type[left.first] >
1490 results.instances_per_node_type[right.first];
1491 });
1492 } else {
1493 std::sort(
1494 time_per_node_type_vec.begin(),
1495 time_per_node_type_vec.end(),
1496 [](auto& left, auto& right) { return left.second > right.second; });
1497 }
1498 std::cout << "Time per node type:" << '\n';
1499 for (const auto& p : time_per_node_type_vec) {
1500 const std::string& kind = p.first;
1501 const double ms = p.second;
1502 std::cout << std::setw(15) << ms << " ms. " << std::setw(10)
1503 << results.percent_per_node_type[kind] << "%. " << kind << " ("
1504 << results.instances_per_node_type[kind] << " nodes";
1505 if (results.out_nodes.count(kind)) {
1506 std::cout << ", out variant)" << '\n';
1507 } else if (results.native_nodes.count(kind)) {
1508 std::cout << ", native)" << '\n';
1509 } else {
1510 std::cout << ")" << '\n';
1511 }
1512
1513 if (generate_ai_pep_output) {
1514 LOG(INFO) << generate_latency_json(kind, ms);
1515 }
1516 }
1517 if (generate_ai_pep_output) {
1518 LOG(INFO) << generate_latency_json(
1519 "static_runtime_first_iter", results.first_iter_time);
1520 }
1521 std::cout << std::setw(15) << results.total_time << " ms. in Total" << '\n';
1522 std::cout << "BlockRunner setup time: " << results.setup_time << " ms"
1523 << '\n';
1524 std::cout << "Memory allocation time: " << results.memory_alloc_time
1525 << " ms\n";
1526 std::cout << "Memory deallocation time: " << results.memory_dealloc_time
1527 << " ms" << '\n';
1528 std::cout << "Outputs deallocation time: " << results.output_dealloc_time
1529 << " ms" << '\n';
1530 std::cout << "First iter time: " << results.first_iter_time << " ms" << '\n';
1531 std::cout << "Number of operators: " << nodes_.size() << '\n';
1532
1533 if (planner_) {
1534 std::cout << "Total number of managed tensors: "
1535 << planner_->total_num_managed_tensors() << '\n';
1536 std::cout << "Total number of managed output tensors: "
1537 << planner_->total_num_managed_output_tensors() << '\n';
1538 std::cout << "Total number of unmanaged values: "
1539 << planner_->total_num_unmanaged() << '\n';
1540 std::cout << "Number of unmanaged values requiring cleanup: "
1541 << planner_->num_unmanaged_non_scalars() << '\n';
1542 std::cout << "Number of unmanaged values not requiring cleanup: "
1543 << planner_->num_unmanaged_scalars() << '\n';
1544 std::cout << "Total memory managed: " << planner_->total_managed()
1545 << " bytes" << '\n';
1546 if (static_module_.opts().optimize_memory) {
1547 std::cout << "Total number of reused tensors: "
1548 << planner_->total_reused_tensors() << '\n';
1549 }
1550 }
1551
1552 auto unsupported_nodes_count = results.total_nodes_count -
1553 results.out_nodes_count - results.native_nodes.size();
1554 std::cout << "Total number of 'out' variant nodes/total number of nodes: "
1555 << results.out_nodes_count << "/" << results.total_nodes_count
1556 << " ("
1557 << 100.0 * static_cast<float>(results.out_nodes_count) /
1558 static_cast<float>(results.total_nodes_count)
1559 << "%)" << '\n';
1560 std::cout << "Total number of nodes not covered by SR/total number of nodes: "
1561 << unsupported_nodes_count << "/" << results.total_nodes_count
1562 << " ("
1563 << 100.0 * static_cast<float>(unsupported_nodes_count) /
1564 static_cast<float>(results.total_nodes_count)
1565 << "%)" << '\n';
1566
1567 check_for_memory_leak();
1568
1569 #ifndef NDEBUG
1570 KeywordArgs empty_kwargs;
1571 display_nodes(
1572 args_list[0], kwargs_list.size() > 0 ? kwargs_list[0] : empty_kwargs);
1573 #endif
1574 }
1575
benchmark_model(const std::vector<std::vector<c10::IValue>> & args_list,const std::vector<KeywordArgs> & kwargs_list,const unsigned int warmup_runs,const unsigned int main_runs)1576 float BlockRunner::benchmark_model(
1577 const std::vector<std::vector<c10::IValue>>& args_list,
1578 const std::vector<KeywordArgs>& kwargs_list,
1579 const unsigned int warmup_runs,
1580 const unsigned int main_runs) {
1581 TORCH_CHECK(main_runs >= 1);
1582 TORCH_CHECK(kwargs_list.empty() || args_list.size() == kwargs_list.size());
1583
1584 const bool is_kwargs_empty = kwargs_list.empty();
1585 const KeywordArgs empty_kwargs;
1586 for (const auto _n_run : c10::irange(warmup_runs)) {
1587 (void)_n_run; // Suppress unused variable warning
1588 const auto num_args = static_cast<uint32_t>(args_list.size());
1589 for (const auto j : c10::irange(num_args)) {
1590 operator()(args_list[j], is_kwargs_empty ? empty_kwargs : kwargs_list[j]);
1591 if (manage_output_tensors_enabled_) {
1592 deallocateOutputTensors();
1593 }
1594 }
1595 }
1596 caffe2::Timer timer;
1597 for (const auto _n_run : c10::irange(main_runs)) {
1598 (void)_n_run; // Suppress unused variable warning
1599 const auto num_args = static_cast<uint32_t>(args_list.size());
1600 for (const auto j : c10::irange(num_args)) {
1601 operator()(args_list[j], is_kwargs_empty ? empty_kwargs : kwargs_list[j]);
1602 if (manage_output_tensors_enabled_) {
1603 deallocateOutputTensors();
1604 }
1605 }
1606 }
1607 float millis = timer.MilliSeconds();
1608 return millis /
1609 (static_cast<float>(main_runs) * static_cast<float>(args_list.size()));
1610 }
1611
display_ivalue(const IValue & iv)1612 static bool display_ivalue(const IValue& iv) {
1613 if (iv.isTensor()) {
1614 std::cout << "Tensor " << iv.toTensor().toString() << " {";
1615 const auto dims = iv.toTensor().sizes();
1616 const auto n_dims = static_cast<uint32_t>(dims.size());
1617 for (const auto i : c10::irange(n_dims)) {
1618 std::cout << iv.toTensor().sizes()[i];
1619 if (n_dims > i + 1) {
1620 std::cout << ", ";
1621 }
1622 }
1623 std::cout << "}\n";
1624 return true;
1625 } else if (iv.isTensorList()) {
1626 std::cout << "TensorList {" << iv.toTensorList().size() << "}\n";
1627 return true;
1628 } else if (iv.isGenericDict()) {
1629 std::cout << "Dict {" << iv.toGenericDict().size() << "}\n";
1630 return true;
1631 } else if (iv.isTuple()) {
1632 std::cout << "Tuple {" << iv.toTupleRef().elements().size() << "}\n";
1633 return true;
1634 } else if (iv.isInt()) {
1635 std::cout << "int {" << iv.toInt() << "}\n";
1636 return true;
1637 } else if (iv.isBool()) {
1638 std::cout << "bool {" << iv.toBool() << "}\n";
1639 return true;
1640 } else if (iv.isDouble()) {
1641 std::cout << "double {" << iv.toDouble() << "}\n";
1642 return true;
1643 }
1644 return false;
1645 }
1646
display_pnode_info(const ProcessedNode & pnode)1647 static void display_pnode_info(const ProcessedNode& pnode) {
1648 pnode.node()->print(std::cout, 0, nullptr, false);
1649 const auto num_inputs = static_cast<uint32_t>(pnode.num_inputs());
1650 for (const auto i : c10::irange(num_inputs)) {
1651 std::cout << "\ti" << i << ": ";
1652 if (!display_ivalue(pnode.Input(i))) {
1653 std::cout << *(pnode.node()->inputs()[i]->type()) << '\n';
1654 }
1655 }
1656 const auto outputs = pnode.outputs();
1657 const auto num_outputs = static_cast<uint32_t>(outputs.size());
1658 for (const auto i : c10::irange(num_outputs)) {
1659 std::cout << "\to" << i << ": ";
1660 if (!display_ivalue(outputs[i])) {
1661 std::cout << *(pnode.node()->outputs()[i]->type()) << '\n';
1662 }
1663 }
1664 }
1665
display_nodes(const std::vector<c10::IValue> & args,const KeywordArgs & kwargs)1666 void BlockRunner::display_nodes(
1667 const std::vector<c10::IValue>& args,
1668 const KeywordArgs& kwargs) {
1669 c10::InferenceMode mode;
1670
1671 auto on_exit = Deallocator(*this);
1672
1673 if (planner_) {
1674 planner_->allocate();
1675 }
1676 set_inputs(args, kwargs);
1677
1678 for (auto& node : nodes_) {
1679 node.run();
1680 display_pnode_info(node);
1681 }
1682 on_exit.setFinished();
1683 }
1684
benchmark_individual_ops(const std::vector<std::vector<c10::IValue>> & args_list,const std::vector<KeywordArgs> & kwargs_list,const uint32_t warmup_runs,const uint32_t main_runs)1685 BlockRunner::IndividualMetrics BlockRunner::benchmark_individual_ops(
1686 const std::vector<std::vector<c10::IValue>>& args_list,
1687 const std::vector<KeywordArgs>& kwargs_list,
1688 const uint32_t warmup_runs,
1689 const uint32_t main_runs) {
1690 TORCH_CHECK(kwargs_list.empty() || args_list.size() == kwargs_list.size());
1691 TORCH_CHECK(warmup_runs >= 1 && main_runs >= 1);
1692
1693 IndividualMetrics results;
1694 results.time_per_node.resize(nodes_.size(), 0);
1695 if (args_list.empty()) {
1696 // When the given input is empty, compute the op statistics from the given
1697 // graph without executing it.
1698 const auto num_nodes = static_cast<uint32_t>(nodes_.size());
1699 for (const auto i : c10::irange(num_nodes)) {
1700 const Node* node = nodes_[i].node();
1701 std::string kind(node->kind().toQualString());
1702 // TODO: Collect op statistics from sub-blocks here.
1703 results.time_per_node[i] = 0;
1704 results.time_per_node_type[kind] = 0;
1705 results.instances_per_node_type[kind]++;
1706 if (nodes_[i].has_out_variant()) {
1707 results.out_nodes.insert(kind);
1708 results.out_nodes_count++;
1709 } else if (nodes_[i].has_native()) {
1710 results.native_nodes.insert(kind);
1711 }
1712 results.total_time += results.time_per_node[i];
1713 }
1714 results.total_nodes_count = nodes_.size();
1715 results.memory_alloc_time = 0;
1716 results.memory_dealloc_time = 0;
1717 results.output_dealloc_time = 0;
1718 for (const auto& p : results.time_per_node_type) {
1719 const std::string& kind = p.first;
1720 results.percent_per_node_type[kind] = 0;
1721 }
1722 return results;
1723 }
1724
1725 const bool is_kwargs_empty = kwargs_list.empty();
1726 const KeywordArgs empty_kwargs;
1727 bool manage_output_tensors = static_module_.opts().manage_output_tensors;
1728 // See comment on above use of InferenceMode for
1729 // explanation.
1730 c10::InferenceMode mode;
1731
1732 // setup time
1733 caffe2::Timer timer;
1734
1735 set_inputs(args_list[0], is_kwargs_empty ? empty_kwargs : kwargs_list[0]);
1736
1737 results.setup_time = timer.MilliSeconds();
1738
1739 // The first iteration profiles each node's output Tensors' sizes and
1740 // initializes the memory planner with the profile information. Following
1741 // iterations just use the already established memory planning.
1742 timer.Start();
1743 operator()(args_list[0], is_kwargs_empty ? empty_kwargs : kwargs_list[0]);
1744 if (manage_output_tensors) {
1745 deallocateOutputTensors();
1746 }
1747 results.first_iter_time = timer.MilliSeconds();
1748
1749 // warmup runs
1750 for (const auto _n_run : c10::irange(warmup_runs)) {
1751 (void)_n_run; // Suppress unused variable warning
1752 const auto num_args = static_cast<uint32_t>(args_list.size());
1753 for (const auto j : c10::irange(num_args)) {
1754 operator()(args_list[j], is_kwargs_empty ? empty_kwargs : kwargs_list[j]);
1755 if (manage_output_tensors) {
1756 deallocateOutputTensors();
1757 }
1758 }
1759 }
1760
1761 // main runs
1762 for (const auto i : c10::irange(main_runs)) {
1763 (void)i; // Suppress unused variable warning
1764 const auto num_args = static_cast<uint32_t>(args_list.size());
1765 for (const auto j : c10::irange(num_args)) {
1766 set_inputs(args_list[j], is_kwargs_empty ? empty_kwargs : kwargs_list[j]);
1767
1768 timer.Start();
1769 if (planner_) {
1770 planner_->allocate();
1771 }
1772 float millis = timer.MilliSeconds();
1773 results.memory_alloc_time += millis;
1774 const auto num_nodes = static_cast<uint32_t>(nodes_.size());
1775 for (const auto k : c10::irange<uint32_t>(num_nodes)) {
1776 timer.Start();
1777 nodes_[k].run();
1778 millis = timer.MilliSeconds();
1779 results.time_per_node[k] += millis;
1780 verify_and_correct_memory_overlap(nodes_[k]);
1781 }
1782 timer.Start();
1783 create_memory_planner();
1784 planner_->deallocate();
1785 // clean up owning refs of input tensors
1786 clean_up_input_ivalues();
1787 if (manage_output_tensors) {
1788 deallocateOutputTensors();
1789 }
1790 millis = timer.MilliSeconds();
1791 results.memory_dealloc_time += millis;
1792
1793 timer.Start();
1794 // no need to keep references of outputs in static runtime anymore
1795 c10::IValue output;
1796 if (static_module_.num_outputs() > 1) {
1797 output = move_outputs_to_tuple(static_module_.num_outputs());
1798 }
1799
1800 DCHECK(check_for_memory_leak(/*output_returned*/ false));
1801
1802 // use move here. Otherwise, clean up outputs_[0] explicitly
1803 output = std::move(*outputs_[0]);
1804 // release outputs explicitly to measure the time it takes
1805 output = IValue();
1806 millis = timer.MilliSeconds();
1807 results.output_dealloc_time += millis;
1808 }
1809 }
1810
1811 // post processing
1812 const float num_total_iters =
1813 (static_cast<float>(main_runs) * static_cast<float>(args_list.size()));
1814 const auto num_nodes = static_cast<uint32_t>(nodes_.size());
1815 for (const auto i : c10::irange(num_nodes)) {
1816 const Node* node = nodes_[i].node();
1817 std::string kind = std::string(node->kind().toQualString());
1818 results.time_per_node[i] /= num_total_iters;
1819 results.time_per_node_type[kind] += results.time_per_node[i];
1820 results.instances_per_node_type[kind]++;
1821 if (nodes_[i].has_out_variant()) {
1822 results.out_nodes.insert(kind);
1823 results.out_nodes_count++;
1824 } else if (nodes_[i].has_native()) {
1825 results.native_nodes.insert(kind);
1826 }
1827 results.total_time += results.time_per_node[i];
1828 }
1829 results.total_nodes_count = nodes_.size();
1830 results.memory_alloc_time /= num_total_iters;
1831 results.memory_dealloc_time /= num_total_iters;
1832 results.output_dealloc_time /= num_total_iters;
1833 for (const auto& p : results.time_per_node_type) {
1834 const std::string& kind = p.first;
1835 results.percent_per_node_type[kind] = p.second / results.total_time * 100;
1836 }
1837 return results;
1838 }
1839
check_for_memory_leak(bool output_returned,bool recurse_on_sub_blocks)1840 bool BlockRunner::check_for_memory_leak(
1841 bool output_returned,
1842 bool recurse_on_sub_blocks) {
1843 // check for inputs
1844 const auto num_inputs = static_cast<uint32_t>(block_info_.num_inputs());
1845 for (const auto i : c10::irange(num_inputs)) {
1846 TORCH_CHECK(
1847 values_[i + block_info_.block_inputs_idx()].isNone(),
1848 "Input ",
1849 i,
1850 " was not cleaned up");
1851 }
1852 c10::FastSet<const IValue*> output_ivalues(outputs_.begin(), outputs_.end());
1853 const auto num_nodes = static_cast<uint32_t>(nodes_.size());
1854 for (const auto n : c10::irange(num_nodes)) {
1855 auto& pnode = nodes_[n];
1856 const auto num_outputs = static_cast<uint32_t>(pnode.num_outputs());
1857 for (const auto i : c10::irange(num_outputs)) {
1858 const IValue* ival = &pnode.Output(i);
1859 const Value* val = pnode.node()->output(i);
1860 // subtlety: isManagedOutputTensorValue may give a false
1861 // negative here if an output is an alias of this value, so
1862 // check the actual tensor!
1863 if (planner_ &&
1864 (isManagedOutputTensor(*ival) || isManagedOutputTensorValue(val))) {
1865 // `ival` contains a managed output tensor that the runtime doesn't
1866 // reclaim at the end of an iteration, but the client does so
1867 // by explicitly calling
1868 // `BlockRunner::deallocateOutputTensors`.
1869 continue;
1870 }
1871 const std::string error_msg = "Output " + std::to_string(i) + ", %" +
1872 val->debugName() + " of node " + std::to_string(n) +
1873 " which has kind " + pnode.node()->kind().toQualString() +
1874 " was not cleaned up";
1875 if (output_ivalues.count(ival) == 0) {
1876 // check for intermediates
1877 if (!ival->isNone()) {
1878 TORCH_CHECK(
1879 ival->isTensor() ||
1880 block_info_.node_is_optimizable_container_type(
1881 pnode.node()) ||
1882 doesNotHeapAllocateWhenStoredInIValue(*val->type()),
1883 error_msg);
1884 if (ival->isTensor()) {
1885 const auto& t = ival->toTensor();
1886 if (t.defined()) {
1887 auto* storage_impl = t.storage().unsafeGetStorageImpl();
1888 TORCH_CHECK(
1889 storage_impl->data() == nullptr ||
1890 (planner_ &&
1891 planner_->isManagedStorageImpl(storage_impl)),
1892 error_msg);
1893 }
1894 }
1895 }
1896 } else {
1897 // check for outputs
1898 if (output_returned) {
1899 TORCH_CHECK(ival->isNone(), error_msg);
1900 }
1901 }
1902 }
1903 auto* metadata = pnode.metadata();
1904 if (recurse_on_sub_blocks && metadata) {
1905 auto& block_runners = metadata->block_runners();
1906 for (auto& block_runner : block_runners) {
1907 block_runner.check_for_memory_leak(
1908 output_returned, recurse_on_sub_blocks);
1909 }
1910 }
1911 }
1912 VLOG(1) << "Finished checking for memory leak";
1913 return true;
1914 }
1915
deallocateOutputTensors()1916 void BlockRunner::deallocateOutputTensors() {
1917 if (!static_module_.opts().manage_output_tensors) {
1918 TORCH_CHECK(
1919 !planner_ || planner_->numOutputBufferBytes() == 0,
1920 "manage_output_tensors is disabled, but output tensor buffer is not empty.");
1921 return;
1922 }
1923 if (planner_) {
1924 planner_->deallocateOutputTensors();
1925 DCHECK(checkOutputTensorMemoryLeaks());
1926 }
1927 }
1928
checkOutputTensorMemoryLeaks()1929 bool BlockRunner::checkOutputTensorMemoryLeaks() {
1930 if (!static_module_.opts().manage_output_tensors || !planner_) {
1931 return true;
1932 }
1933 const auto num_nodes = static_cast<uint32_t>(nodes_.size());
1934 for (const auto n : c10::irange(num_nodes)) {
1935 auto& pnode = nodes_[n];
1936 const auto num_outputs = static_cast<uint32_t>(pnode.num_outputs());
1937 for (const auto i : c10::irange(num_outputs)) {
1938 const IValue* ival = &pnode.Output(i);
1939 const Value* val = pnode.node()->output(i);
1940 if (!isManagedOutputTensorValue(val) || !ival->isTensor()) {
1941 // ival can not be a tensor if it's being managed by ops like
1942 // to_maybe_copy_out; see ReplaceWithMaybeCopy for details.
1943 continue;
1944 }
1945 const auto& t = ival->toTensor();
1946 if (t.defined()) {
1947 auto* storage_impl = t.storage().unsafeGetStorageImpl();
1948 const std::string error_msg = "Output " + std::to_string(i) + ", %" +
1949 val->debugName() + " of node " + std::to_string(n) +
1950 " was not cleaned up";
1951 TORCH_CHECK(storage_impl->data() == nullptr, error_msg);
1952 }
1953 }
1954 }
1955 VLOG(1) << "Finished checking for memory leak from output tensors";
1956 return true;
1957 }
1958
isManagedOutputTensor(const IValue & ivalue) const1959 bool BlockRunner::isManagedOutputTensor(const IValue& ivalue) const {
1960 return planner_ && planner_->isManagedOutputTensor(ivalue);
1961 }
1962
isManagedOutputTensorValue(const Value * value) const1963 bool BlockRunner::isManagedOutputTensorValue(const Value* value) const {
1964 // It's possible that manage_output_tensors_ was disabled after initializing
1965 // managed_output_tensor_values, so we have to check that flag here.
1966 if (!planner_ || !manage_output_tensors_enabled_) {
1967 return false;
1968 }
1969 const auto& managed_outputs = block_info_.managed_output_tensor_values();
1970 return managed_outputs.find(value) != managed_outputs.end();
1971 }
1972
disableManageOutputTensors()1973 void BlockRunner::disableManageOutputTensors() {
1974 if (!manage_output_tensors_enabled_) {
1975 return;
1976 }
1977 manage_output_tensors_enabled_ = false;
1978 if (!planner_) {
1979 return;
1980 }
1981 // Reset all IValues and destruct planner_ so that it can be reconstructed in
1982 // the next run.
1983 for (auto& n : nodes_) {
1984 const auto num_outputs = static_cast<uint32_t>(n.outputs().size());
1985 for (const auto i : c10::irange(num_outputs)) {
1986 n.Output(i) = IValue();
1987 }
1988 }
1989 planner_.reset();
1990 }
1991
ProcessedFunction(Node * node,bool enable_out_variant,bool check_memory_overlap)1992 ProcessedFunction::ProcessedFunction(
1993 Node* node,
1994 bool enable_out_variant,
1995 bool check_memory_overlap)
1996 : check_memory_overlap_(check_memory_overlap),
1997 num_outputs_(node->outputs().size()) {
1998 if (enable_out_variant) {
1999 f_ = getOutOfPlaceOperation(node);
2000 if (f_) {
2001 kind_ = ProcessedFunction::Kind::kOutVariant;
2002 // do not check memory overlap for out variants
2003 check_memory_overlap_ = false;
2004 VLOG(1) << "Switch to out variant for node: " << PrintNode(node);
2005 return;
2006 }
2007 }
2008 {
2009 f_ = getNativeOperation(node);
2010 if (f_) {
2011 kind_ = ProcessedFunction::Kind::kNativeFunction;
2012 #ifdef NDEBUG
2013 // skip this check in opt mode because these ops are better vetted
2014 check_memory_overlap_ = false;
2015 #endif
2016 VLOG(1) << "Switch to native impl for node: " << PrintNode(node);
2017 return;
2018 }
2019 }
2020 {
2021 const Operator& op = node->getOperator();
2022 f_ = [node_op = op.getOperation(node),
2023 has_var_args = hasVarArgs(node)](ProcessedNode* pnode) mutable {
2024 std::vector<IValue> stack;
2025 const auto size = static_cast<uint32_t>(pnode->num_inputs());
2026 stack.reserve(size + has_var_args);
2027 for (const auto i : c10::irange(size)) {
2028 stack.emplace_back(pnode->Input(i));
2029 }
2030 // Need to store the number of inputs in stack for variadic ops.
2031 if (has_var_args) {
2032 stack.emplace_back(static_cast<int>(size));
2033 }
2034 node_op(stack);
2035 const auto num_outputs = static_cast<uint32_t>(pnode->num_outputs());
2036 TORCH_DCHECK_EQ(stack.size(), num_outputs);
2037 for (const auto i : c10::irange(num_outputs)) {
2038 pnode->Output(i) = std::move(stack[i]);
2039 }
2040 };
2041 kind_ = ProcessedFunction::Kind::kInterpreterFallback;
2042 VLOG(1) << "Fallback interpreter for node: " << PrintNode(node);
2043 }
2044 }
2045
StaticNodeInfo(Node * node,ProcessedFunction * fn,ProcessedNodeInputs inputs,uint16_t outputs_offset)2046 StaticNodeInfo::StaticNodeInfo(
2047 Node* node,
2048 ProcessedFunction* fn,
2049 ProcessedNodeInputs inputs,
2050 uint16_t outputs_offset)
2051 : node_(node),
2052 fn_(fn),
2053 inputs_(std::move(inputs)),
2054 outputs_offset_(outputs_offset) {
2055 TORCH_CHECK(
2056 num_outputs() == node->outputs().size(),
2057 "Node ",
2058 node->kind().toQualString(),
2059 " has ",
2060 std::to_string(num_outputs()),
2061 " outputs, expected ",
2062 std::to_string(node->outputs().size()));
2063 }
2064
inputs_ivalue_vec() const2065 std::vector<IValue> ProcessedNode::inputs_ivalue_vec() const {
2066 std::vector<IValue> result;
2067 const auto num_inputs = static_cast<uint32_t>(inputs_.size());
2068 result.reserve(num_inputs);
2069
2070 for (const auto idx : c10::irange(num_inputs)) {
2071 result.emplace_back(Input(idx));
2072 }
2073 return result;
2074 }
2075
run()2076 void ProcessedNode::run() {
2077 #ifdef FBCODE_CAFFE2
2078 SROperatorObserver::onStart(node());
2079 #endif
2080 #ifndef PYTORCH_DISABLE_PER_OP_PROFILING
2081 auto step_callbacks =
2082 at::getStepCallbacksUnlessEmpty(at::RecordScope::STATIC_RUNTIME_OP);
2083 if (C10_UNLIKELY(step_callbacks.has_value())) {
2084 at::RecordFunction guard(std::move(*step_callbacks));
2085 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(guard.isActive());
2086 if (guard.needsInputs()) {
2087 const auto inputs = inputs_ivalue_vec();
2088 guard.before(
2089 get_op_name(),
2090 c10::ArrayRef<const IValue>(inputs.data(), inputs.size()));
2091 } else {
2092 guard.before(get_op_name());
2093 }
2094 if (has_out_variant()) {
2095 guard._setStaticRuntimeOutVariant();
2096 }
2097
2098 fn_->run(this);
2099 } else {
2100 fn_->run(this);
2101 }
2102 #else
2103 fn_->run(this);
2104 #endif
2105 #ifndef NDEBUG
2106 if (FLAGS_static_runtime_disable_debug_memory_overlap_check) {
2107 // run check but do not enforce
2108 verify_no_memory_overlap();
2109 } else {
2110 DCHECK(verify_no_memory_overlap());
2111 }
2112 #endif
2113 #ifdef FBCODE_CAFFE2
2114 SROperatorObserver::onEnd(node());
2115 #endif
2116 }
2117
checkNoMemoryOverlap(const at::Tensor & a,const at::Tensor & b)2118 static bool checkNoMemoryOverlap(const at::Tensor& a, const at::Tensor& b) {
2119 at::MemOverlapStatus status = at::get_overlap_status(a, b);
2120 if (status == at::MemOverlapStatus::Full ||
2121 status == at::MemOverlapStatus::Partial) {
2122 return false;
2123 }
2124 if (status == at::MemOverlapStatus::TooHard) {
2125 VLOG(1) << "Detected TOO_HARD memory overlap status";
2126 }
2127 return true;
2128 }
2129
verify_no_memory_overlap(bool force_check) const2130 bool ProcessedNode::verify_no_memory_overlap(bool force_check) const {
2131 const static std::array<c10::Symbol, 7> special_case_ops = {
2132 fromQualString("prim::TypeCheck"),
2133 fromQualString("prim::IfThenElse"),
2134 fromQualString("static_runtime::select_tensor"),
2135 fromQualString("static_runtime::VarTupleUnpack"),
2136 fromQualString("static_runtime::dict_unpack"),
2137 fromQualString("static_runtime::fused_split_and_squeeze"),
2138 fromQualString("static_runtime::create_owned_ref")};
2139 if (!force_check &&
2140 std::find(
2141 begin(special_case_ops), end(special_case_ops), node()->kind()) !=
2142 end(special_case_ops)) {
2143 return true;
2144 }
2145
2146 return verify_outputs_dont_overlap_each_other() &&
2147 verify_inputs_dont_overlap_outputs(force_check);
2148 }
2149
verify_outputs_dont_overlap_each_other() const2150 bool ProcessedNode::verify_outputs_dont_overlap_each_other() const {
2151 const auto n_outputs = static_cast<uint32_t>(num_outputs());
2152 for (const auto i : c10::irange(n_outputs)) {
2153 if (!Output(i).isTensor()) {
2154 continue;
2155 }
2156 const auto& out0_t = Output(i).toTensor();
2157 for (const auto j : c10::irange(i + 1, n_outputs)) {
2158 if (!Output(j).isTensor()) {
2159 continue;
2160 }
2161 const auto& out1_t = Output(j).toTensor();
2162 if (!checkNoMemoryOverlap(out0_t, out1_t)) {
2163 LOG(INFO) << "Node output " << i << " overlaps with output " << j
2164 << ", " << PrintNode(node_);
2165 return false;
2166 }
2167 }
2168 }
2169 return true;
2170 }
2171
verify_inputs_dont_overlap_outputs(bool force_check) const2172 bool ProcessedNode::verify_inputs_dont_overlap_outputs(bool force_check) const {
2173 auto schema = node()->maybeSchema();
2174 // skip memory overlap check for mutable or view ops with only one output
2175 bool skip_check = !schema ||
2176 ((schema->is_mutable() || !fn_->checkMemoryOverlap()) &&
2177 num_outputs() == 1);
2178 if (!schema || (!force_check && skip_check)) {
2179 if (!schema) {
2180 VLOG(2) << "Detected that op schema is null";
2181 return true;
2182 }
2183 VLOG(2) << "schema->is_mutable: " << schema->is_mutable()
2184 << ", fn_->checkMemoryOverlap: " << fn_->checkMemoryOverlap()
2185 << ", num_outputs_: " << num_outputs();
2186 return true;
2187 }
2188 const auto n_inputs = static_cast<uint32_t>(inputs_.size());
2189 const auto n_outputs = static_cast<uint32_t>(num_outputs());
2190 for (const auto i : c10::irange<uint32_t>(n_inputs)) {
2191 const IValue* in = &Input(i);
2192 if (!in->isTensor()) {
2193 continue;
2194 }
2195 const auto& in_t = in->toTensor();
2196 for (const auto j : c10::irange(n_outputs)) {
2197 const IValue& out = Output(j);
2198 if (!out.isTensor()) {
2199 continue;
2200 }
2201 const auto& out_t = out.toTensor();
2202 if (!checkNoMemoryOverlap(in_t, out_t)) {
2203 LOG(INFO) << "Node input " << i << " overlaps with output " << j << ", "
2204 << PrintNode(node_);
2205 LOG(INFO) << *schema;
2206 return false;
2207 }
2208 }
2209 }
2210 return true;
2211 }
2212
check_and_correct_overlap_with(const at::Tensor & input,c10::IValue & output_ival)2213 bool ProcessedNode::check_and_correct_overlap_with(
2214 const at::Tensor& input,
2215 c10::IValue& output_ival) {
2216 auto& tensor = output_ival.toTensor();
2217 if (!checkNoMemoryOverlap(input, tensor)) {
2218 DLOG(INFO) << "Detected alias for node: " << PrintNode(node());
2219 output_ival = at::native::clone(tensor, std::nullopt);
2220 set_outputs_memory_overlap_detected();
2221 return true;
2222 }
2223 return false;
2224 }
2225
verify_and_correct_memory_overlap()2226 void ProcessedNode::verify_and_correct_memory_overlap() {
2227 const auto n_inputs = static_cast<uint32_t>(inputs_.size());
2228 const auto n_outputs = static_cast<uint32_t>(num_outputs());
2229 for (const auto i : c10::irange(n_inputs)) {
2230 const IValue& in = Input(i);
2231 if (!in.isTensor()) {
2232 continue;
2233 }
2234 const auto& in_t = in.toTensor();
2235 for (const auto j : c10::irange(n_outputs)) {
2236 auto& output = Output(j);
2237 if (output.isTensor()) {
2238 check_and_correct_overlap_with(in_t, output);
2239 } else if (output.isTensorList()) {
2240 auto tensors = output.toListRef();
2241 for (const auto& ival : tensors) {
2242 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
2243 check_and_correct_overlap_with(in_t, const_cast<c10::IValue&>(ival));
2244 }
2245 #ifdef FBCODE_CAFFE2
2246 if (outputs_memory_overlap_detected()) {
2247 LOG_EVERY_MS(WARNING, 60000)
2248 << "Detected alias for node: " << PrintNode(node());
2249 }
2250 #endif
2251 }
2252 }
2253 }
2254 }
2255
StaticRuntime(const StaticModule & sm)2256 StaticRuntime::StaticRuntime(const StaticModule& sm)
2257 : async_task_launcher_(at::launch), values_(sm.value_buffer_size()) {
2258 std::copy(sm.constants().begin(), sm.constants().end(), values_.data());
2259 // default task launcher set to inter-op thread pool
2260
2261 block_ = std::make_unique<BlockRunner>(
2262 sm,
2263 values_.data(),
2264 sm.root_block(),
2265 &async_task_launcher_,
2266 true /*is_root_block*/);
2267 }
2268
operator ()(const std::vector<c10::IValue> & args,const KeywordArgs & kwargs)2269 c10::IValue StaticRuntime::operator()(
2270 const std::vector<c10::IValue>& args,
2271 const KeywordArgs& kwargs) {
2272 return (*block_)(args, kwargs);
2273 }
2274
operator ()(std::vector<c10::IValue> && args,const KeywordArgs & kwargs)2275 c10::IValue StaticRuntime::operator()(
2276 std::vector<c10::IValue>&& args,
2277 const KeywordArgs& kwargs) {
2278 return (*block_)(std::move(args), kwargs);
2279 }
2280
runAsync(const std::vector<c10::IValue> & args,const KeywordArgs & kwargs,torch::jit::TaskLauncher taskLauncher)2281 c10::intrusive_ptr<c10::ivalue::Future> StaticRuntime::runAsync(
2282 const std::vector<c10::IValue>& args,
2283 const KeywordArgs& kwargs,
2284 torch::jit::TaskLauncher taskLauncher) {
2285 async_task_launcher_ = std::move(taskLauncher);
2286 return block_->runAsync(args, kwargs);
2287 }
2288
runAsync(std::vector<c10::IValue> && args,const KeywordArgs & kwargs,torch::jit::TaskLauncher taskLauncher)2289 c10::intrusive_ptr<c10::ivalue::Future> StaticRuntime::runAsync(
2290 std::vector<c10::IValue>&& args,
2291 const KeywordArgs& kwargs,
2292 torch::jit::TaskLauncher taskLauncher) {
2293 async_task_launcher_ = std::move(taskLauncher);
2294 return block_->runAsync(std::move(args), kwargs);
2295 }
2296
check_for_memory_leak(bool output_returned)2297 bool StaticRuntime::check_for_memory_leak(bool output_returned) {
2298 return block_->check_for_memory_leak(
2299 output_returned, /* recurse_on_sub_blocks */ true);
2300 }
2301
checkOutputTensorMemoryLeaks()2302 bool StaticRuntime::checkOutputTensorMemoryLeaks() {
2303 return block_->checkOutputTensorMemoryLeaks();
2304 }
2305
deallocateOutputTensors()2306 void StaticRuntime::deallocateOutputTensors() {
2307 block_->deallocateOutputTensors();
2308 }
2309
isManagedOutputTensor(const IValue & ivalue) const2310 bool StaticRuntime::isManagedOutputTensor(const IValue& ivalue) const {
2311 return block_->isManagedOutputTensor(ivalue);
2312 }
2313
disableManageOutputTensors()2314 void StaticRuntime::disableManageOutputTensors() {
2315 block_->disableManageOutputTensors();
2316 }
2317
get_memory_planner() const2318 const MemoryPlanner* StaticRuntime::get_memory_planner() const {
2319 return block_->get_memory_planner();
2320 }
2321
2322 } // namespace torch::jit
2323