#pragma once
#include <ATen/core/ivalue.h>
#include <ATen/core/symbol.h>
#include <c10/core/CPUAllocator.h>
#include <c10/macros/Macros.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/FbcodeMaps.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/graph_node_list.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/static/ProcessedNodeInputs.h>
#include <torch/custom_class.h>
#include <limits>

#ifdef FBCODE_CAFFE2
#include <folly/container/F14Map.h>
#include <folly/container/F14Set.h>
#endif

namespace torch::jit {

TORCH_API bool canEnableStaticRuntime(
    const std::shared_ptr<torch::jit::Graph>& graph);

TORCH_API std::string dumpValueSet(
    const c10::FastSet<const Value*>& value_set,
    const char* set_name = "");

TORCH_API inline bool doesNotHeapAllocateWhenStoredInIValue(const Type& type) {
  switch (type.kind()) {
    // NOTE: NumberType may allocate because it includes complex.
    case TypeKind::NoneType:
    case TypeKind::IntType:
    case TypeKind::FloatType:
    case TypeKind::BoolType:
    case TypeKind::DeviceObjType:
    case TypeKind::StreamObjType:
      return true;
    default:
      return false;
  }
}
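// Illustrative use (not a prescribed API): values of these kinds can be stored
// in and cleared from an IValue without touching the heap, e.g.
//   doesNotHeapAllocateWhenStoredInIValue(*IntType::get());    // true
//   doesNotHeapAllocateWhenStoredInIValue(*StringType::get()); // false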

TORCH_API inline c10::Symbol getStaticRuntimeMetadataSymbol() {
  return Symbol::attr("static_runtime::metadata");
}

TORCH_API inline bool borrowsOutputs(c10::Symbol kind) {
  static const std::array<c10::Symbol, 4> symbols_with_borrowed_outputs = {
      c10::Symbol::fromQualString("static_runtime::select_tensor"),
      c10::Symbol::fromQualString("static_runtime::dict_unpack"),
      c10::Symbol::fromQualString("static_runtime::VarTupleUnpack"),
      c10::Symbol::fromQualString("prim::IfThenElse"),
  };
  return std::find(
             symbols_with_borrowed_outputs.begin(),
             symbols_with_borrowed_outputs.end(),
             kind) != symbols_with_borrowed_outputs.end();
}
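// Example: borrowsOutputs(c10::Symbol::fromQualString("static_runtime::dict_unpack"))
// returns true. The outputs of such nodes are borrows of other values, so
// their lifetimes have to be handled differently from owning outputs.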

// Group values used by `graph` into three categories:
//
// - output_aliases:
//     values that are either outputs or contain aliases of outputs
// - external_aliases:
//     values that are inputs, constants, or their aliases.
//     The output aliases that end up here are a result of the AliasDb failing
//     to recognize them as outputs due to collection objects (e.g., Tuple)
//     aliasing inputs.
// - Values that don't show up in output_aliases or external_aliases are
//   created and consumed within the graph.
class ValueGroup {
 public:
  explicit ValueGroup() = default;
  void init(const Block& block, const AliasDb& db);

  bool isExternalAlias(const Value* value) const {
    return external_aliases_.find(value) != external_aliases_.end();
  }

  bool isOutputAlias(const Value* value) const {
    return output_aliases_.find(value) != output_aliases_.end();
  }

  bool isAlwaysAlive(const Value* value) const {
    return isExternalAlias(value) || isOutputAlias(value);
  }

  std::string toString() const {
    return c10::str(
        dumpValueSet(output_aliases_, "ValueGroup::output_aliases_"),
        "\n",
        dumpValueSet(external_aliases_, "ValueGroup::external_aliases_"));
  }

 private:
  c10::FastSet<const Value*> output_aliases_;
  c10::FastSet<const Value*> external_aliases_;
};
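
// A rough usage sketch (variable names are illustrative, not a prescribed
// API): values that are neither output aliases nor external aliases are
// candidates for memory planning.
//
//   ValueGroup value_group;
//   AliasDb alias_db(graph);
//   value_group.init(*graph->block(), alias_db);
//   for (Node* node : graph->block()->nodes()) {
//     for (Value* output : node->outputs()) {
//       if (!value_group.isAlwaysAlive(output)) {
//         // `output` is created and consumed inside the graph, so its
//         // storage may be reused by the memory planner.
//       }
//     }
//   }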

class TORCH_API ManagedTensorRanges {
 public:
  ManagedTensorRanges() = default;
  ManagedTensorRanges(
      Block& block,
      const AliasDb& alias_db,
      const c10::FastSet<const Value*>& managed_tensor_values);

  // If true, then this node is the last use of at least one
  // managed tensor. availableTensorValuesAfterNode(node) will return a vector
  // of the managed tensors that are available for re-use
  // in the nodes following this one.
  bool nodeFreesManagedTensors(Node* node) const;
  const std::vector<const Value*>& availableTensorValuesAfterNode(
      Node* node) const;

  // For testing. True if v1 and v2 are both mutable types and have lifetimes
  // that overlap.
  bool lifetimesOverlap(const Value* v1, const Value* v2) const;

 private:
  struct Lifetime {
    Lifetime(size_t start_, size_t end_) : start(start_), end(end_) {}
    size_t start;
    size_t end;
  };

  // Returns nullptr if we are not tracking the lifetime of value
  Lifetime* getLifetime(const Value* value);
  const Lifetime* getLifetime(const Value* value) const;
  // Collect all values in the input that have tracked lifetimes.
  // A value's lifetime may not be tracked if it is a graph input
  // or immutable type (containers with at least one mutable
  // type are mutable)
  std::vector<const Value*> collectValuesWithTrackedLifetimes(
      at::ArrayRef<const Value*> values);

  // Maps Node* to the set of managed tensors that are now available
  // for re-use after this node.
  c10::FastMap<Node*, std::vector<const Value*>> node_to_newly_free_tensors_{};
  // Maps each Value* to its lifetime (start node index, end node index)
  c10::FastMap<const Value*, Lifetime> value_lifetimes_{};
};
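
// A rough usage sketch (illustrative only; `managed_tensor_values` is assumed
// to be the set of tensor values chosen for management): a planner can walk
// the nodes in order and ask which managed tensors die at each node, so their
// storage can be handed to later tensors with non-overlapping lifetimes.
//
//   ManagedTensorRanges ranges(
//       *graph->block(), alias_db, managed_tensor_values);
//   for (Node* node : graph->block()->nodes()) {
//     if (ranges.nodeFreesManagedTensors(node)) {
//       for (const Value* freed : ranges.availableTensorValuesAfterNode(node)) {
//         // the storage backing `freed` may be reused by tensors allocated
//         // by nodes that run after `node`
//       }
//     }
//   }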

struct TORCH_API StaticModuleOptions {
  // enabling out variants allows Static Runtime to do memory planning
  bool enable_out_variant{true};
  // to reuse tensor storage for tensors whose live ranges do not overlap to
  // reduce memory footprint (enable_out_variant must be true)
  bool optimize_memory{true};
  // to batch allocate tensor storage for output tensors of the
  // graph, where storage is deallocated outside static runtime
  // (enable_out_variant must be true)
  bool manage_output_tensors{false};
  // Gates the ReplaceWithCopy pass, which replaces ops that
  // sometimes alias their outputs with out variants that
  // always copy (so the output may participate in memory planning).
  // Since replacing with copies is done after TensorExpr fusion, the
  // resulting graph does not conform to the assumptions made in the fuser.
  // So, even if this flag is turned on, the ReplaceWithCopy pass will not
  // be executed if TensorExpr fusion is enabled.
  bool use_copy_variants{true};
  // Gates the ReplaceWithMaybeCopy pass, which replaces ops that
  // sometimes alias their outputs with subgraphs that include an out
  // variant.
  // For the same reason as `use_copy_variants`, the ReplaceWithMaybeCopy pass
  // will not be executed if TensorExpr fusion is enabled, even if this flag
  // is turned on.
  bool use_maybe_copy_variants{true};
  // enable TensorExpr fusion of ops at model loading time
  bool enable_tensorexpr_fusion{false};
};
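
// A minimal construction sketch, assuming `module` is a loaded TorchScript
// module (see the StaticModule/StaticRuntime usage examples below):
//
//   torch::jit::StaticModuleOptions opts;
//   opts.enable_out_variant = true;   // required by the next two options
//   opts.optimize_memory = true;      // reuse storage of non-overlapping tensors
//   opts.manage_output_tensors = false;
//   torch::jit::StaticModule smod(module, /*is_frozen=*/false, opts);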

/*
  Responsible for plugging StaticRuntime metadata onto the
  IR nodes. StaticRuntimeMetadata extends CustomClassHolder,
  so it can be cast to an IValue and attached to an IR node.
  This is needed to pass parent graph metadata to the forked
  graph in the presence of the prim::fork operator.
*/
class TORCH_API StaticRuntimeMetadata : public torch::CustomClassHolder {
 public:
  explicit StaticRuntimeMetadata(const StaticModuleOptions& opts)
      : opts_(opts) {}

  const StaticModuleOptions& get_opts() {
    return opts_;
  }

 private:
  StaticModuleOptions opts_;
};
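
// A rough sketch of the intended use (see attachNodeMetadata below; names are
// illustrative): the options are wrapped in this custom-class holder,
// converted to an IValue, and stored as an attribute on prim::fork nodes so
// the forked subgraph can recover its parent's StaticModuleOptions.
//
//   auto metadata = c10::make_intrusive<StaticRuntimeMetadata>(opts);
//   fork_node->ival_(getStaticRuntimeMetadataSymbol(), IValue(metadata));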

/// The static runtime supports two execution modes.
///
/// Mode 1: single-threaded with no parallelism except for intra-op parallelism
/// For this mode, you can do either:
/// @code
///   // m is a TorchScript module
///   auto module = StaticModule(m, opts);
///   auto output = module(args, kwargs);
/// @endcode
///
/// or
///
/// @code
///   // g is the TorchScript graph
///   auto module = StaticModule(g, opts);
///   auto output = module(args, kwargs);
/// @endcode
///
/// Mode 2: similar to data parallelism, run the same model for different
/// inputs on different threads at the same time.
/// You should have one StaticModule per model, and one StaticRuntime instance
/// per running thread. To avoid creating StaticRuntimes on the fly, use a
/// synchronized stack (e.g., boost::lockfree::stack) to cache all the
/// StaticRuntime instances in your code.
/// @code
///   // initialization
///   auto module = std::make_shared<StaticModule>(m, opts);
///
///   // 128 is good for most cases. Pick a number that works for you
///   boost::lockfree::stack<std::shared_ptr<StaticRuntime>,
///       boost::lockfree::fixed_sized<true>> pool(128);
///
///   // inference
///   std::shared_ptr<StaticRuntime> runtime = nullptr;
///   pool.pop(runtime);
///   if (!runtime) {
///     // holds a reference to the underlying module
///     // but does its own memory management
///     runtime = std::make_shared<StaticRuntime>(*module);
///   }
///   auto output = (*runtime)(args, kwargs);
///   pool.push(runtime);
/// @endcode
///
class MemoryPlanner;
class StaticNodeInfo;
class ProcessedNode;
class StaticRuntime;

using SROperator = std::function<void(ProcessedNode*)>;
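
// An SROperator is just a callable that reads its inputs from and writes its
// outputs to a ProcessedNode. A minimal out-variant-style sketch (purely
// illustrative, not a registered op):
//
//   SROperator add_out = [](ProcessedNode* p_node) {
//     const auto& a = p_node->Input(0).toTensor();
//     const auto& b = p_node->Input(1).toTensor();
//     if (p_node->Output(0).isNone()) {
//       p_node->Output(0) = at::add(a, b);  // first iteration: allocate
//     } else {
//       auto& out = p_node->Output(0).toTensor();
//       at::add_out(out, a, b);             // later iterations: reuse storage
//     }
//   };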

#ifdef FBCODE_CAFFE2
struct TORCH_API SROperatorObserver {
  using OperatorCallback = void (*)(const Node*);
  OperatorCallback startCb = nullptr;
  OperatorCallback endCb = nullptr;

  static void setCurrentThreadObserver(SROperatorObserver* observer);
  static SROperatorObserver* getCurrentThreadObserver();
  static void onStart(const Node* name);
  static void onEnd(const Node* name);
};
#endif

class TORCH_API ProcessedFunction {
 public:
  ProcessedFunction(
      Node* node,
      bool enable_out_variant,
      bool check_memory_overlap);

  enum class Kind : uint8_t {
    kOutVariant,
    kNativeFunction,
    kInterpreterFallback,
  };

  void run(ProcessedNode* pnode) const {
    return f_(pnode);
  }

  Kind kind() const {
    return kind_;
  }

  bool checkMemoryOverlap() const {
    return check_memory_overlap_;
  }

  size_t num_outputs() const {
    return num_outputs_;
  }

 private:
  SROperator f_;
  Kind kind_{ProcessedFunction::Kind::kOutVariant};
  bool check_memory_overlap_{false};
  size_t num_outputs_{0};
};

// A `BlockInfo` instance stores all of the shared state that each
// `BlockRunner` will need to access. Most of this information is
// read-only and shared between threads.
// - Each `BlockInfo` corresponds to one block in the graph.
// - Each `BlockInfo` may be used by multiple block runners (when there are
//   many threads).
// - All of the `BlockInfo`s are stored in a vector in the `StaticModule` and
//   are initialized during `StaticModule` construction.
// - Most of the information stored is used to initialize the block's memory
//   planner.
class BlockInfo {
 public:
  BlockInfo(uint32_t input_idx, Block& block);

  void set_nodes(
      std::vector<StaticNodeInfo> nodes,
      const c10::FastMap<Node*, bool>& node_has_out_variant);

  const std::vector<StaticNodeInfo>& nodes() const {
    return nodes_;
  }

  size_t num_nodes() const;

  size_t num_inputs() const {
    return block_.inputs().size();
  }

  size_t num_outputs() const {
    return block_.outputs().size();
  }

  graph_node_list node_ptrs() const {
    return block_.nodes();
  }

  void set_output_indices(std::vector<uint16_t> indices) {
    output_indices_ = std::move(indices);
  }

  const std::vector<uint16_t>& block_output_indices() const {
    return output_indices_;
  }

  auto block_inputs_idx() const {
    return input_idx_;
  }

  bool node_is_optimizable_container_type(const Node* node) const {
    return node_is_optimizable_container_type_.find(node) !=
        node_is_optimizable_container_type_.end();
  }

  bool value_is_managed_tensor(const Value* value) const {
    return managed_tensor_values_.find(value) != managed_tensor_values_.end();
  }

  bool value_is_leaked_container(const Value* value) const {
    return leaked_values_.find(value) != leaked_values_.end();
  }

  const ValueGroup& value_group() const {
    return value_group_;
  }

  const ManagedTensorRanges& managed_tensor_ranges() const {
    return managed_tensor_ranges_;
  }

  void init_value_group(const AliasDb& alias_db) {
    value_group_.init(block_, alias_db);
  }

  void prepare_for_memory_planner(
      const AliasDb& alias_db,
      const StaticModuleOptions& opt);

  const auto& managed_output_tensor_values() const {
    return managed_output_tensor_values_;
  }

  const auto& managed_tensor_values() const {
    return managed_tensor_values_;
  }

  const auto& leaked_values() const {
    return leaked_values_;
  }

 private:
  std::vector<StaticNodeInfo> nodes_;

  ValueGroup value_group_;

  c10::FastSet<const Node*> node_is_optimizable_container_type_;
  c10::FastSet<const Value*> managed_tensor_values_;
  c10::FastSet<const Value*> managed_output_tensor_values_;
  c10::FastSet<const Value*> leaked_values_;

  ManagedTensorRanges managed_tensor_ranges_{};

  // The index of this block's inputs in the shared values_ array.
  const uint16_t input_idx_;
  // The indices of this block's outputs in the shared values_ array.
  std::vector<uint16_t> output_indices_;
  Block& block_;
};

class TORCH_API StaticModule {
 public:
  explicit StaticModule(
      const std::shared_ptr<torch::jit::Graph>& g,
      const StaticModuleOptions& opts = StaticModuleOptions(),
      std::vector<IValue> sample_inputs = {});

  explicit StaticModule(
      const torch::jit::Module& m,
      bool is_frozen = false,
      const StaticModuleOptions& opts = StaticModuleOptions(),
      std::vector<IValue> sample_inputs = {});

 private:
  explicit StaticModule(
      std::pair<std::shared_ptr<torch::jit::Graph>, std::optional<Module>>
          graph_and_module,
      const StaticModuleOptions& opts);

 public:
  using KeywordArgs = std::unordered_map<std::string, c10::IValue>;
  c10::IValue operator()(
      const std::vector<c10::IValue>& args,
      const KeywordArgs& kwargs = KeywordArgs());
  c10::IValue operator()(
      std::vector<c10::IValue>&& args,
      const KeywordArgs& kwargs = KeywordArgs());

  const Graph& graph() const {
    return *graph_;
  }

  const Module& module() const {
    DCHECK(module_.has_value());
    return *module_;
  }

  const StaticModuleOptions& opts() const;

  size_t num_inputs() const;
  size_t num_outputs() const;

  size_t num_constants() const {
    return constants_.size();
  }

  size_t num_intermediate_values() const {
    return num_intermediate_values_;
  }

  size_t total_num_values() const {
    return num_inputs() + num_constants() + num_intermediate_values();
  }

  C10_NODISCARD const std::vector<uint16_t>& output_indices() const {
    return output_indices_;
  }

  const std::vector<IValue>& constants() const {
    return constants_;
  }

  const BlockInfo& block_info(Block* block) const {
    return block_infos_.at(block);
  }

  Block* root_block() const {
    return graph_->block();
  }

 private:
  friend class StaticRuntime;
  friend class BlockRunner;

 public:
  auto num_nodes() const {
    return std::accumulate(
        block_infos_.begin(),
        block_infos_.end(),
        0,
        [](size_t sum, const auto& block_and_info) {
          auto& block_info = block_and_info.second;
          return sum + block_info.num_nodes();
        });
  }

  C10_NODISCARD Node* findNodeWithKindForTesting(const std::string& kind) const;

  const std::optional<c10::FunctionSchema>& schema() const {
    return schema_;
  }

  bool first_input_is_self() const {
    return module_.has_value();
  }

  StaticRuntime& runtime();

  // See [Shared values array]
  size_t value_buffer_size() const {
    return value_buffer_size_;
  }

 private:
  // Recursively prepares the BlockInfo array.
  // - Populates `value_to_index` with the indices of each intermediate value
  // - Returns the number of Value* processed, including sub-blocks.
  size_t prepareBlockInfo(
      Block* block,
      const size_t start_idx,
      c10::FastMap<const Value*, uint32_t>& value_to_index);

  void prepareFunctionsAndConstants(
      Block* block,
      const AliasDb& alias_db,
      c10::FastMap<const Value*, uint32_t>& value_to_index);

  // Recursively traverse the graph and attach SR metadata
  // to the prim::fork nodes as additional attributes
  void attachNodeMetadata(Block* block);

  // Recurses on sub-blocks and populates the array of ProcessedNodes
  // Returns (number of nodes processed, number of blocks processed)
  size_t prepareStaticNodeInfos(
      Block* block,
      const c10::FastMap<const Value*, uint32_t>& value_to_index,
      const AliasDb& alias_db,
      size_t node_idx = 0);

  // Initialize various attributes that the memory planner will need.
  // To be called at the tail of the ctor.
  void prepareForMemoryPlanner();

  StaticModuleOptions opts_;
  // metadata that is stored in IR nodes as attribute
  at::intrusive_ptr<jit::StaticRuntimeMetadata> sr_metadata_;
  std::shared_ptr<torch::jit::Graph> graph_;
  std::optional<torch::jit::Module> module_;
  std::optional<c10::FunctionSchema> schema_;
  std::unique_ptr<StaticRuntime> cached_runtime_;

  // Bookkeeping for creating new StaticRuntime instances
  // IValue table (defined by prim::Constant nodes)
  std::vector<IValue> constants_;
  // The functions to be called by corresponding ProcessedNode.
  std::vector<ProcessedFunction> functions_{};
  // A list of pre-processed nodes from which ProcessedNode are created per
  // StaticRuntime instance.
  std::vector<StaticNodeInfo> nodes_;
  // Indices of graph outputs in the single values array.
  std::vector<uint16_t> output_indices_;

  size_t num_intermediate_values_ = 0;

  // Includes self if module_ != std::nullopt.
  // Note that we might have num_inputs_ == 0 even if the schema has a `self`
  // argument. In this case, `self` isn't used in the graph, but the schema
  // includes it anyway to be consistent with the JIT interpreter.
  size_t num_inputs_;
  // See `BlockInfo` definition. The blocks are stored in depth-first order.
  c10::FastMap<Block*, BlockInfo> block_infos_;
  size_t value_buffer_size_ = 0;
};

// `BlockRunner` contains the core runtime logic. Each block runner
// corresponds to one block in the graph and has its own memory planner.
// `StaticRuntime` will initialize all `BlockRunner`s
// upon construction. Each block runner only directly executes nodes from its
// block. Special ops with sub-blocks like `prim::If` may have
// `BlockRunner`s stored in their `ProcessedNode`s; these
// sub-blocks get executed in the op's implementation.
// `StaticRuntime` stores a vector of IValues that all
// `BlockRunner`s share. This vector is used to store all
// constants, inputs, and intermediate tensors.
class TORCH_API BlockRunner {
 public:
  BlockRunner(
      const StaticModule& sm,
      IValue* values,
      Block* block,
      torch::jit::TaskLauncher* launcher,
      bool is_root_block = false);
  BlockRunner(BlockRunner&&) noexcept;
  BlockRunner& operator=(BlockRunner&&) = delete;
  ~BlockRunner();

  C10_DISABLE_COPY_AND_ASSIGN(BlockRunner);

  using KeywordArgs = std::unordered_map<std::string, c10::IValue>;
  c10::IValue operator()(
      const std::vector<c10::IValue>& args,
      const KeywordArgs& kwargs = KeywordArgs());
  c10::IValue operator()(
      std::vector<c10::IValue>&& args,
      const KeywordArgs& kwargs = KeywordArgs());

  c10::intrusive_ptr<c10::ivalue::Future> runAsync(
      const std::vector<c10::IValue>& args,
      const KeywordArgs& kwargs);

  c10::intrusive_ptr<c10::ivalue::Future> runAsync(
      std::vector<c10::IValue>&& args,
      const KeywordArgs& kwargs);

  void benchmark(
      const std::vector<std::vector<c10::IValue>>& args_list,
      const std::vector<KeywordArgs>& kwargs_list,
      const uint32_t warmup_runs,
      const uint32_t main_runs,
      bool print_per_node_time = false,
      bool generate_ai_pep_output = false);

  struct IndividualMetrics {
    float setup_time{0.0};
    float memory_alloc_time{0.0};
    float memory_dealloc_time{0.0};
    float output_dealloc_time{0.0};
    float first_iter_time{0.0};
    float total_time{0.0};
    size_t out_nodes_count{0};
    size_t total_nodes_count{0};
    std::vector<float> time_per_node;
    std::unordered_map<std::string, float> time_per_node_type;
    std::unordered_map<std::string, float> percent_per_node_type;
    std::unordered_map<std::string, int> instances_per_node_type;
    std::unordered_set<std::string> out_nodes;
    std::unordered_set<std::string> native_nodes;
  };

  IndividualMetrics benchmark_individual_ops(
      const std::vector<std::vector<c10::IValue>>& args_list,
      const std::vector<KeywordArgs>& kwargs_list,
      const uint32_t warmup_runs,
      const uint32_t main_runs);

  // Input is readwrite
  IValue& Input(uint32_t i) {
    TORCH_DCHECK_LT(i, block_info_.num_inputs());
    return values_[i + block_info_.block_inputs_idx()];
  }

  // Output is readonly. The writing process happens inside ProcessedNodes
  C10_NODISCARD const IValue& Output(uint32_t i) const {
    DCHECK(i < outputs_.size());
    return *outputs_[i];
  }

  const std::vector<IValue*> outputs() const {
    return outputs_;
  }

  const std::vector<ProcessedNode>& nodes() const {
    return nodes_;
  }

  std::vector<ProcessedNode>& nodes() {
    return nodes_;
  }

  graph_node_list node_ptrs() const {
    return block_info_.node_ptrs();
  }

  const Graph& graph() const {
    return static_module_.graph();
  }

  const MemoryPlanner* get_memory_planner() const {
    return planner_.get();
  }

  bool check_for_memory_leak(
      bool output_returned = true,
      bool recurse_on_sub_blocks = false);

  // WARNING: Deallocate managed output tensors. A client receiving Static
  // Runtime-managed Tensors needs to be very careful to call
  // `StaticRuntime::deallocateOutputTensors` after all references of output
  // Tensors are gone.
  void deallocateOutputTensors();

  bool checkOutputTensorMemoryLeaks();

  bool isManagedOutputTensor(const IValue& ivalue) const;
  bool isManagedOutputTensorValue(const Value* value) const;

  void disableManageOutputTensors();

  // This is the fallback path taken if we can't construct the memory planner
  // on the first iteration.
  // IMPORTANT: Nothing here should be able to throw!!!
  // This function can be called from the (implicitly) `noexcept` destructor
  // of Deallocator, meaning that std::terminate will be called
  // if any exception escapes. Even if resetMemory and ~Deallocator were
  // `noexcept(false)`, it's possible that when ~Deallocator is called, the
  // stack is already unwinding, so there's still danger of calling
  // std::terminate.
  void resetMemory() noexcept;

 private:
  // A helper object that invokes memory planner deallocation code
  // when destructed.
  class Deallocator {
   public:
    explicit Deallocator(BlockRunner& block_runner)
        : block_runner_(block_runner) {}

    Deallocator(Deallocator&&) = default;
    Deallocator(const Deallocator&) = default;
    Deallocator& operator=(const Deallocator&) = delete;
    Deallocator& operator=(Deallocator&&) = delete;
    ~Deallocator();

    void setFinished() {
      finished_ = true;
    }

   private:
    void cleanupImpl();

    bool finished_ = false;
    BlockRunner& block_runner_;
  };

  template <typename IValueList>
  c10::IValue run_impl(IValueList&& args, const KeywordArgs& kwargs);

  template <typename IValueList>
  c10::IValue run_impl_record_functions(
      IValueList&& args,
      const KeywordArgs& kwargs);

  template <typename IValueList>
  c10::intrusive_ptr<c10::ivalue::Future> run_impl_async(
      IValueList&& args,
      const KeywordArgs& kwargs);

  template <typename IValueList>
  c10::intrusive_ptr<c10::ivalue::Future> run_impl_record_functions_async(
      IValueList&& args,
      const KeywordArgs& kwargs);

  // helper method for copying input args/kwargs into inputs_
  template <typename IValueList>
  void set_inputs(IValueList&& args, const KeywordArgs& kwargs);

  // Set Input(idx) to args[idx]. Invoked by set_inputs. Copies or moves
  // depending on overload.
  void set_arg(const size_t idx, std::vector<IValue>&& args);
  void set_arg(const size_t idx, const std::vector<IValue>& args);

  // Set Input(idx) to arg. Always copies. Used for kwargs.
  void set_arg(const size_t idx, const IValue& arg);

  bool fast_check_and_correct_overlap_with(
      ProcessedNode& n,
      c10::IValue& tensor_ival);
  void verify_and_correct_memory_overlap(ProcessedNode& n);

  // clean up owning refs of input IValues
  void clean_up_input_ivalues() noexcept {
    for (const auto idx : c10::irange(block_info_.num_inputs())) {
      values_[idx + inputs_begin_] = IValue();
    }
  }

  void clean_up_intermediate_ivalues() noexcept;

  IValue move_outputs_to_tuple(uint32_t num_outputs);

  void create_memory_planner();

  float benchmark_model(
      const std::vector<std::vector<c10::IValue>>& args_list,
      const std::vector<KeywordArgs>& kwargs_list,
      const uint32_t warmup_runs,
      const uint32_t main_runs);

  void display_nodes(
      const std::vector<c10::IValue>& args,
      const KeywordArgs& kwargs);

  const StaticModule& static_module_;
  const BlockInfo& block_info_;

  const bool is_root_block_;
  // Cache this so we don't have to call static_module_.first_input_is_self()
  const bool first_input_is_self_;
  // Index of the start of this block's inputs in the shared values_ array.
  const uint16_t inputs_begin_;

  bool manage_output_tensors_enabled_ = false;
  std::unique_ptr<MemoryPlanner> planner_;
  // [Shared values array]
  // ProcessedNodes reference their inputs and outputs with
  // offsets into this array, which saves memory.
  // All BlockRunners share the same array. The layout is as
  // follows:
  //   [constants][block_0][block_1]...[block_N]
  // Note that constants from all blocks are pooled together at the start.
  // The block ordering is depth-first.
  // Each block is further divided into inputs and intermediates:
  //   [block_i] = [inputs_i][intermediates_i]
  // Each BlockRunner knows where its inputs start. Each ProcessedNode
  // knows how to find the indices of its outputs/inputs in this array.
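  // Example (illustrative only; exact indices are computed during
  // StaticModule construction): for 2 constants, a root block with 3 inputs
  // and 4 intermediates, and one sub-block with 1 input and 2 intermediates,
  // the layout is
  //   [c0 c1 | in0 in1 in2 | t0 t1 t2 t3 | sub_in0 | sub_t0 sub_t1]
  // so the root BlockRunner's inputs_begin_ is 2 and the sub-block's is 9.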
  IValue* values_;

  std::vector<IValue*> outputs_;
  std::vector<ProcessedNode> nodes_;
};

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class TORCH_API StaticNodeInfo {
 public:
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  StaticNodeInfo(
      Node* n,
      ProcessedFunction* fn,
      ProcessedNodeInputs inputs,
      uint16_t outputs_offset);

  Node* node() const {
    return node_;
  }

  size_t num_outputs() const {
    DCHECK(fn_ != nullptr);
    return fn_->num_outputs();
  }

  bool has_out_variant() const {
    return fn_->kind() == ProcessedFunction::Kind::kOutVariant;
  }

 private:
  friend class ProcessedNode;

  Node* node_;
  const ProcessedFunction* fn_;
  ProcessedNodeInputs inputs_;
  uint16_t outputs_offset_;
};

inline size_t BlockInfo::num_nodes() const {
  return nodes_.size();
}

/*
  ProcessedNodeMetadata wraps the possible metadata for a ProcessedNode.
  Depending on the kind of op, a ProcessedNode holds one of the following:
  - prim::If/prim::Loop ops contain block_runners_ as their metadata
  - the prim::fork op contains a TaskLauncher (std::function) responsible for
    executing the forked subgraph
*/
class TORCH_API ProcessedNodeMetadata {
 public:
  ProcessedNodeMetadata(
      std::vector<BlockRunner> runners,
      torch::jit::TaskLauncher* launcher)
      : block_runners_(std::move(runners)), launcher_(launcher) {}

  ProcessedNodeMetadata() : launcher_(nullptr) {}

  // deleted copy ctor/assignment, as standard containers (vector) always
  // have copy constructors, but their instantiation is not well-formed
  // if the contained type (BlockRunner) is not copyable
  ProcessedNodeMetadata(const ProcessedNodeMetadata&) = delete;
  ProcessedNodeMetadata& operator=(const ProcessedNodeMetadata&) = delete;

  std::vector<BlockRunner>& block_runners() {
    return block_runners_;
  }

  void set_block_runners(std::vector<BlockRunner> runners) {
    block_runners_ = std::move(runners);
  }

  void set_launcher(torch::jit::TaskLauncher* launcher) {
    launcher_ = launcher;
  }

  torch::jit::TaskLauncher* launcher() {
    return launcher_;
  }

 private:
  std::vector<BlockRunner> block_runners_;
  torch::jit::TaskLauncher* launcher_;
};

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class TORCH_API ProcessedNode {
 public:
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  ProcessedNode() = default;

  ProcessedNode(const StaticNodeInfo& other, IValue* values)
      : node_(other.node_),
        fn_(other.fn_),
        inputs_(other.inputs_),
        outputs_offset_(other.outputs_offset_),
        values_(values),
        metadata_(nullptr) {}

  // These should be noexcept, but some Android build is failing
  // saying the noexcept specification doesn't match the calculated
  // one. Maybe std::variant is throwing it off?
  ProcessedNode(ProcessedNode&&) = default;

  ProcessedNode(const ProcessedNode&) = delete;
  ProcessedNode& operator=(const ProcessedNode& other) = delete;
  ProcessedNode& operator=(ProcessedNode&&) = default;

  void run();

  Node* node() const {
    return node_;
  }

  // Input is readonly
  C10_NODISCARD const IValue& Input(uint32_t i) const {
    return values_[inputs_[i]];
  }

  // Output is readwrite
  IValue& Output(uint32_t i) {
    DCHECK(i < num_outputs());
    return values_[outputs_offset_ + i];
  }

  C10_NODISCARD const IValue& Output(uint32_t i) const {
    DCHECK(i < num_outputs());
    return values_[outputs_offset_ + i];
  }

  uint32_t num_outputs() const {
    DCHECK(fn_ != nullptr);
    return static_cast<uint32_t>(fn_->num_outputs());
  }

  C10_NODISCARD c10::ArrayRef<const IValue> outputs() const {
    return c10::ArrayRef<const IValue>(
        values_ + outputs_offset_, num_outputs());
  }

  C10_NODISCARD uint16_t num_inputs() const {
    return inputs_.size();
  }

  std::vector<IValue> inputs_ivalue_vec() const;

  bool has_out_variant() const {
    return fn_->kind() == ProcessedFunction::Kind::kOutVariant;
  }

  bool has_native() const {
    return fn_->kind() == ProcessedFunction::Kind::kNativeFunction;
  }

#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
  const char* get_op_name() const {
    return node_->kind().toQualString();
  }
#endif

  bool check_outputs_for_memory_overlap() const {
    return fn_->checkMemoryOverlap();
  }

  void set_outputs_memory_overlap_detected() {
    overlap_detected_ = true;
  }

  bool outputs_memory_overlap_detected() {
    return overlap_detected_;
  }

  bool check_and_correct_overlap_with(
      const at::Tensor& input,
      c10::IValue& output);
  void verify_and_correct_memory_overlap();

  void set_values(IValue* values) {
    DCHECK(values_ == nullptr);
    values_ = values;
  }

  C10_NODISCARD uint16_t output_ivalue_index(uint16_t i) const {
    DCHECK(i < num_outputs());
    return outputs_offset_ + i;
  }
  // used in debug mode
  bool verify_no_memory_overlap(bool force_check = false) const;

  // Returns a pointer to the ProcessedNodeMetadata, or nullptr if no object
  // is owned.
  ProcessedNodeMetadata* metadata() {
    return metadata_.get();
  }

  // attach block runners to the metadata of this ProcessedNode
  void set_metadata(std::vector<BlockRunner> block_runners) {
    if (metadata_ == nullptr) {
      metadata_ = std::make_unique<ProcessedNodeMetadata>();
    }
    metadata_->set_block_runners(std::move(block_runners));
  }

  // attach a TaskLauncher to the metadata of this ProcessedNode
  void set_metadata(torch::jit::TaskLauncher* launcher) {
    if (metadata_ == nullptr) {
      metadata_ = std::make_unique<ProcessedNodeMetadata>();
    }
    metadata_->set_launcher(launcher);
  }

 private:
  C10_NODISCARD bool verify_outputs_dont_overlap_each_other() const;

  C10_NODISCARD bool verify_inputs_dont_overlap_outputs(bool force_check) const;

  Node* node_;
  const ProcessedFunction* fn_;
  ProcessedNodeInputs inputs_;
  uint16_t outputs_offset_;
  bool overlap_detected_{false};
  IValue* values_ = nullptr; // unowned
  // Metadata for ProcessedNode.
  // 1. prim::If/Loop nodes contain sub-blocks as metadata
  // 2. prim::fork nodes contain a custom executor for async execution
  std::unique_ptr<ProcessedNodeMetadata> metadata_;
};

// `StaticRuntime` is the owner of the array of IValues (used for constants,
// inputs, and intermediate tensors) that all `BlockRunner`s share.
// Upon construction, it initializes all block runners. `operator()` simply
// forwards the inputs to the top-level block runner. Each `StaticRuntime`
// instance corresponds to one `StaticModule`. Multiple `StaticRuntime`
// instances can be created; this is useful for multi-threaded execution, since
// `operator()` is not thread-safe.
class TORCH_API StaticRuntime {
 public:
  explicit StaticRuntime(const StaticModule& sm);

  using KeywordArgs = std::unordered_map<std::string, c10::IValue>;
  c10::IValue operator()(
      const std::vector<c10::IValue>& args,
      const KeywordArgs& kwargs = KeywordArgs());
  c10::IValue operator()(
      std::vector<c10::IValue>&& args,
      const KeywordArgs& kwargs = KeywordArgs());

  // runAsync performs inline execution of the graph on the caller thread and
  // async execution on the taskLauncher. If no custom taskLauncher is
  // specified, execution is done on the inter-op thread pool.
  c10::intrusive_ptr<c10::ivalue::Future> runAsync(
      const std::vector<c10::IValue>& args,
      const KeywordArgs& kwargs = KeywordArgs(),
      torch::jit::TaskLauncher taskLauncher = at::launch);

  c10::intrusive_ptr<c10::ivalue::Future> runAsync(
      std::vector<c10::IValue>&& args,
      const KeywordArgs& kwargs = KeywordArgs(),
      torch::jit::TaskLauncher taskLauncher = at::launch);

  bool check_for_memory_leak(bool output_returned = true);
  bool checkOutputTensorMemoryLeaks();

  void deallocateOutputTensors();
  bool isManagedOutputTensor(const IValue& ivalue) const;
  void disableManageOutputTensors();
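
  // When `manage_output_tensors` is enabled, a typical caller-side sequence
  // looks like the following sketch (illustrative only):
  //
  //   auto outputs = runtime(args, kwargs);
  //   // ... consume `outputs`, then drop all references to the returned
  //   // Tensors ...
  //   runtime.deallocateOutputTensors();  // typically before the next run
  //                                       // on this StaticRuntime instance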

  // Gets the top-level memory planner. Used for testing.
  const MemoryPlanner* get_memory_planner() const;

  void benchmark(
      const std::vector<std::vector<c10::IValue>>& args_list,
      const std::vector<KeywordArgs>& kwargs_list,
      const uint32_t warmup_runs,
      const uint32_t main_runs,
      bool print_per_node_time = false,
      bool generate_ai_pep_output = false) {
    block_->benchmark(
        args_list,
        kwargs_list,
        warmup_runs,
        main_runs,
        print_per_node_time,
        generate_ai_pep_output);
  }

  using IndividualMetrics = BlockRunner::IndividualMetrics;

  IndividualMetrics benchmark_individual_ops(
      const std::vector<std::vector<c10::IValue>>& args_list,
      const std::vector<KeywordArgs>& kwargs_list,
      const int warmup_runs,
      const int main_runs) {
    return block_->benchmark_individual_ops(
        args_list, kwargs_list, warmup_runs, main_runs);
  }

 private:
  // An array of IValues with unchanging size/data ptr.
  class IValueArray {
   public:
    IValueArray() = default;
    explicit IValueArray(size_t size) : array_(allocate(size)), size_(size) {}

    IValue* data() const {
      return array_.get();
    }

    size_t size() const {
      return size_;
    }

   private:
    // NOLINTNEXTLINE(modernize-avoid-c-arrays)
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays)
    static std::unique_ptr<IValue[]> allocate(size_t size) {
      if (size) {
        return std::make_unique<IValue[]>(size);
      }
      return nullptr;
    }

    // NOLINTNEXTLINE(modernize-avoid-c-arrays)
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays)
    std::unique_ptr<IValue[]> array_ = nullptr;
    size_t size_ = 0;
  };

  std::unique_ptr<BlockRunner> block_;
  // for execution of async operations present in the graph
  torch::jit::TaskLauncher async_task_launcher_;
  IValueArray values_;
};
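
/// A usage sketch for async execution (illustrative; see `runAsync` above):
/// @code
///   torch::jit::StaticRuntime runtime(smod);
///   // at::launch (the inter-op thread pool) is the default launcher
///   auto future = runtime.runAsync(args, kwargs);
///   future->wait();
///   auto output = future->value();
/// @endcode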

} // namespace torch::jit