#pragma once
#include <ATen/core/ivalue.h>
#include <ATen/core/symbol.h>
#include <c10/core/CPUAllocator.h>
#include <c10/macros/Macros.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/FbcodeMaps.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/graph_node_list.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/static/ProcessedNodeInputs.h>
#include <torch/custom_class.h>
#include <limits>

#ifdef FBCODE_CAFFE2
#include <folly/container/F14Map.h>
#include <folly/container/F14Set.h>
#endif

namespace torch::jit {

TORCH_API bool canEnableStaticRuntime(
    const std::shared_ptr<torch::jit::Graph>& graph);

TORCH_API std::string dumpValueSet(
    const c10::FastSet<const Value*>& value_set,
    const char* set_name = "");

TORCH_API inline bool doesNotHeapAllocateWhenStoredInIValue(const Type& type) {
  switch (type.kind()) {
    // NOTE: NumberType may allocate because it includes complex.
    case TypeKind::NoneType:
    case TypeKind::IntType:
    case TypeKind::FloatType:
    case TypeKind::BoolType:
    case TypeKind::DeviceObjType:
    case TypeKind::StreamObjType:
      return true;
    default:
      return false;
  }
}
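
// For example, ints and None do not heap-allocate when stored in an IValue,
// while tensors and complex numbers may. A hedged illustration:
//   doesNotHeapAllocateWhenStoredInIValue(*IntType::get());    // true
//   doesNotHeapAllocateWhenStoredInIValue(*TensorType::get()); // false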

TORCH_API inline c10::Symbol getStaticRuntimeMetadataSymbol() {
  return Symbol::attr("static_runtime::metadata");
}

TORCH_API inline bool borrowsOutputs(c10::Symbol kind) {
  static const std::array<c10::Symbol, 4> symbols_with_borrowed_outputs = {
      c10::Symbol::fromQualString("static_runtime::select_tensor"),
      c10::Symbol::fromQualString("static_runtime::dict_unpack"),
      c10::Symbol::fromQualString("static_runtime::VarTupleUnpack"),
      c10::Symbol::fromQualString("prim::IfThenElse"),
  };
  return std::find(
             symbols_with_borrowed_outputs.begin(),
             symbols_with_borrowed_outputs.end(),
             kind) != symbols_with_borrowed_outputs.end();
}

// Group values used by `graph` into three categories:
//
// - output_aliases:
//     values that are either outputs or contain aliases of outputs
// - external_aliases:
//     values that are inputs, constants, or their aliases.
//     Output aliases can end up here when the alias DB fails to recognize
//     them as outputs because a collection object (e.g., Tuple) aliases
//     the inputs.
//
// Values that don't show up in output_aliases or external_aliases are created
// and consumed within the graph.
class ValueGroup {
 public:
  explicit ValueGroup() = default;
  void init(const Block& block, const AliasDb& db);

  bool isExternalAlias(const Value* value) const {
    return external_aliases_.find(value) != external_aliases_.end();
  }

  bool isOutputAlias(const Value* value) const {
    return output_aliases_.find(value) != output_aliases_.end();
  }

  bool isAlwaysAlive(const Value* value) const {
    return isExternalAlias(value) || isOutputAlias(value);
  }

  std::string toString() const {
    return c10::str(
        dumpValueSet(output_aliases_, "ValueGroup::output_aliases_"),
        "\n",
        dumpValueSet(external_aliases_, "ValueGroup::external_aliases_"));
  }

 private:
  c10::FastSet<const Value*> output_aliases_;
  c10::FastSet<const Value*> external_aliases_;
};

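// A hedged usage sketch (variable names illustrative): during StaticModule
// construction, a ValueGroup is initialized from a block and its alias
// analysis, then queried when deciding which values must stay alive.
//
//   AliasDb alias_db(graph);                      // graph: std::shared_ptr<Graph>
//   ValueGroup value_group;
//   value_group.init(*graph->block(), alias_db);
//   bool keep = value_group.isAlwaysAlive(value); // value: const Value*
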
class TORCH_API ManagedTensorRanges {
 public:
  ManagedTensorRanges() = default;
  ManagedTensorRanges(
      Block& block,
      const AliasDb& alias_db,
      const c10::FastSet<const Value*>& managed_tensor_values);

  // If true, then this node is the last use of at least one
  // managed tensor. availableTensorValuesAfterNode(node) will return a vector
  // of the managed tensors that are available for re-use
  // in the nodes following this one.
  bool nodeFreesManagedTensors(Node* node) const;
  const std::vector<const Value*>& availableTensorValuesAfterNode(
      Node* node) const;

  // For testing. True if v1 and v2 are both mutable types and have lifetimes
  // that overlap.
  bool lifetimesOverlap(const Value* v1, const Value* v2) const;

 private:
  struct Lifetime {
    Lifetime(size_t start_, size_t end_) : start(start_), end(end_) {}
    size_t start;
    size_t end;
  };

  // Returns nullptr if we are not tracking the lifetime of value
  Lifetime* getLifetime(const Value* value);
  const Lifetime* getLifetime(const Value* value) const;
  // Collect all values in the input that have tracked lifetimes.
  // A value's lifetime may not be tracked if it is a graph input
  // or immutable type (containers with at least one mutable
  // type are mutable)
  std::vector<const Value*> collectValuesWithTrackedLifetimes(
      at::ArrayRef<const Value*> values);

  // Maps Node* to the set of managed tensors that are now available
  // for re-use after this node.
  c10::FastMap<Node*, std::vector<const Value*>> node_to_newly_free_tensors_{};
  // Maps each Value* to its lifetime (start node index, end node index)
  c10::FastMap<const Value*, Lifetime> value_lifetimes_{};
};

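// A hedged sketch of how this information might be consumed when reusing
// tensor storage (names illustrative):
//
//   if (ranges.nodeFreesManagedTensors(node)) {
//     for (const Value* v : ranges.availableTensorValuesAfterNode(node)) {
//       // v's storage becomes available for reuse by later allocations
//     }
//   }
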
struct TORCH_API StaticModuleOptions {
  // enabling out variants allows Static Runtime to do memory planning
  bool enable_out_variant{true};
  // reuse tensor storage for tensors whose live ranges do not overlap, to
  // reduce memory footprint (enable_out_variant must be true)
  bool optimize_memory{true};
  // batch allocate tensor storage for output tensors of the
  // graph, where storage is deallocated outside static runtime
  // (enable_out_variant must be true)
  bool manage_output_tensors{false};
  // Gates the ReplaceWithCopy pass, which replaces ops that
  // sometimes alias their outputs with out variants that
  // always copy (so the output may participate in memory planning).
  // Since replacing with copies is done after TensorExpr fusion, the
  // resulting graph does not conform to the assumptions made in the fuser.
  // So, even if this flag is turned on, the ReplaceWithCopy pass will not
  // be executed if TensorExpr fusion is enabled.
  bool use_copy_variants{true};
  // Gates the ReplaceWithMaybeCopy pass, which replaces ops that
  // sometimes alias their outputs with subgraphs that include an out
  // variant.
  // For the same reason as `use_copy_variants`, the ReplaceWithMaybeCopy pass
  // will not be executed if TensorExpr fusion is enabled, even if this flag
  // is turned on.
  bool use_maybe_copy_variants{true};
  // enable TensorExpr fusion of ops at model loading time
  bool enable_tensorexpr_fusion{false};
};

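// A hedged configuration sketch (values illustrative): tweak the options
// above before constructing a StaticModule from a scripted module `m`.
//
//   StaticModuleOptions opts;
//   opts.enable_out_variant = true;     // required by the two flags below
//   opts.optimize_memory = true;        // reuse storage of non-overlapping tensors
//   opts.manage_output_tensors = false; // keep default output management
//   StaticModule smod(m, /*is_frozen=*/false, opts);
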
/*
  Responsible for plugging StaticRuntime metadata onto the
  IR nodes. StaticRuntimeMetadata extends CustomClassHolder,
  so it can be cast to an IValue and attached to an IR node.
  This is needed to pass parent graph metadata to the forked
  graph in the presence of the prim::fork operator.
*/
class TORCH_API StaticRuntimeMetadata : public torch::CustomClassHolder {
 public:
  explicit StaticRuntimeMetadata(const StaticModuleOptions& opts)
      : opts_(opts) {}

  const StaticModuleOptions& get_opts() {
    return opts_;
  }

 private:
  StaticModuleOptions opts_;
};

/// The static runtime supports two execution modes.
///
/// Mode 1: single-threaded with no parallelism except for intra-op parallelism.
/// For this mode, you can do either:
/// @code
///   // m is a TorchScript module
///   auto module = StaticModule(m, opts);
///   auto output = module(args, kwargs);
/// @endcode
///
/// or
///
/// @code
///   // g is the TorchScript graph
///   auto module = StaticModule(g, opts);
///   auto output = module(args, kwargs);
/// @endcode
///
/// Mode 2: similar to data parallelism, run the same model for different
/// inputs on different threads at the same time.
/// You should have one StaticModule per model, and one StaticRuntime instance
/// per running thread. To avoid creating StaticRuntimes on the fly, use a
/// synchronized stack (e.g., boost::lockfree::stack) to cache all the
/// StaticRuntime instances in your code.
/// @code
///   // initialization
///   auto module = std::make_shared<StaticModule>(m, opts);
///
///   // 128 is good for most cases. Pick a number that works for you
///   boost::lockfree::stack<std::shared_ptr<StaticRuntime>,
///     boost::lockfree::fixed_sized<true>> pool(128);
///
///   // inference
///   std::shared_ptr<StaticRuntime> runtime = nullptr;
///   pool.pop(runtime);
///   if (!runtime) {
///     // holds a reference to the underlying module
///     // but does its own memory management
///     runtime = std::make_shared<StaticRuntime>(*module);
///   }
///   auto output = (*runtime)(args, kwargs);
///   pool.push(runtime);
/// @endcode
///
class MemoryPlanner;
class StaticNodeInfo;
class ProcessedNode;
class StaticRuntime;

using SROperator = std::function<void(ProcessedNode*)>;

#ifdef FBCODE_CAFFE2
struct TORCH_API SROperatorObserver {
  using OperatorCallback = void (*)(const Node*);
  OperatorCallback startCb = nullptr;
  OperatorCallback endCb = nullptr;

  static void setCurrentThreadObserver(SROperatorObserver* observer);
  static SROperatorObserver* getCurrentThreadObserver();
  static void onStart(const Node* name);
  static void onEnd(const Node* name);
};
#endif

class TORCH_API ProcessedFunction {
 public:
  ProcessedFunction(
      Node* node,
      bool enable_out_variant,
      bool check_memory_overlap);

  enum class Kind : uint8_t {
    kOutVariant,
    kNativeFunction,
    kInterpreterFallback,
  };

  void run(ProcessedNode* pnode) const {
    return f_(pnode);
  }

  Kind kind() const {
    return kind_;
  }

  bool checkMemoryOverlap() const {
    return check_memory_overlap_;
  }

  size_t num_outputs() const {
    return num_outputs_;
  }

 private:
  SROperator f_;
  Kind kind_{ProcessedFunction::Kind::kOutVariant};
  bool check_memory_overlap_{false};
  size_t num_outputs_{0};
};

// A `BlockInfo` instance stores all of the shared state that each
// `BlockRunner` will need to access. Most of this information is
// read-only and shared between threads.
// - Each `BlockInfo` corresponds to one block in the graph.
// - Each `BlockInfo` may be used by multiple block runners (when there are many
//   threads).
// - All of the `BlockInfo`s are stored in a vector in the `StaticModule` and
//   are initialized during `StaticModule` construction.
// - Most of the information stored is used to initialize the block's memory
//   planner.
class BlockInfo {
 public:
  BlockInfo(uint32_t input_idx, Block& block);

  void set_nodes(
      std::vector<StaticNodeInfo> nodes,
      const c10::FastMap<Node*, bool>& node_has_out_variant);

  const std::vector<StaticNodeInfo>& nodes() const {
    return nodes_;
  }

  size_t num_nodes() const;

  size_t num_inputs() const {
    return block_.inputs().size();
  }

  size_t num_outputs() const {
    return block_.outputs().size();
  }

  graph_node_list node_ptrs() const {
    return block_.nodes();
  }

  void set_output_indices(std::vector<uint16_t> indices) {
    output_indices_ = std::move(indices);
  }

  const std::vector<uint16_t>& block_output_indices() const {
    return output_indices_;
  }

  auto block_inputs_idx() const {
    return input_idx_;
  }

  bool node_is_optimizable_container_type(const Node* node) const {
    return node_is_optimizable_container_type_.find(node) !=
        node_is_optimizable_container_type_.end();
  }

  bool value_is_managed_tensor(const Value* value) const {
    return managed_tensor_values_.find(value) != managed_tensor_values_.end();
  }

  bool value_is_leaked_container(const Value* value) const {
    return leaked_values_.find(value) != leaked_values_.end();
  }

  const ValueGroup& value_group() const {
    return value_group_;
  }

  const ManagedTensorRanges& managed_tensor_ranges() const {
    return managed_tensor_ranges_;
  }

  void init_value_group(const AliasDb& alias_db) {
    value_group_.init(block_, alias_db);
  }

  void prepare_for_memory_planner(
      const AliasDb& alias_db,
      const StaticModuleOptions& opt);

  const auto& managed_output_tensor_values() const {
    return managed_output_tensor_values_;
  }

  const auto& managed_tensor_values() const {
    return managed_tensor_values_;
  }

  const auto& leaked_values() const {
    return leaked_values_;
  }

 private:
  std::vector<StaticNodeInfo> nodes_;

  ValueGroup value_group_;

  c10::FastSet<const Node*> node_is_optimizable_container_type_;
  c10::FastSet<const Value*> managed_tensor_values_;
  c10::FastSet<const Value*> managed_output_tensor_values_;
  c10::FastSet<const Value*> leaked_values_;

  ManagedTensorRanges managed_tensor_ranges_{};

  // The index of this block's inputs in the shared values_ array.
  const uint16_t input_idx_;
  // The indices of this block's outputs in the shared values_ array.
  std::vector<uint16_t> output_indices_;
  Block& block_;
};

class TORCH_API StaticModule {
 public:
  explicit StaticModule(
      const std::shared_ptr<torch::jit::Graph>& g,
      const StaticModuleOptions& opts = StaticModuleOptions(),
      std::vector<IValue> sample_inputs = {});

  explicit StaticModule(
      const torch::jit::Module& m,
      bool is_frozen = false,
      const StaticModuleOptions& opts = StaticModuleOptions(),
      std::vector<IValue> sample_inputs = {});

 private:
  explicit StaticModule(
      std::pair<std::shared_ptr<torch::jit::Graph>, std::optional<Module>>
          graph_and_module,
      const StaticModuleOptions& opts);

 public:
  using KeywordArgs = std::unordered_map<std::string, c10::IValue>;
  c10::IValue operator()(
      const std::vector<c10::IValue>& args,
      const KeywordArgs& kwargs = KeywordArgs());
  c10::IValue operator()(
      std::vector<c10::IValue>&& args,
      const KeywordArgs& kwargs = KeywordArgs());

  const Graph& graph() const {
    return *graph_;
  }

  const Module& module() const {
    DCHECK(module_.has_value());
    return *module_;
  }

  const StaticModuleOptions& opts() const;

  size_t num_inputs() const;
  size_t num_outputs() const;

  size_t num_constants() const {
    return constants_.size();
  }

  size_t num_intermediate_values() const {
    return num_intermediate_values_;
  }

  size_t total_num_values() const {
    return num_inputs() + num_constants() + num_intermediate_values();
  }

  C10_NODISCARD const std::vector<uint16_t>& output_indices() const {
    return output_indices_;
  }

  const std::vector<IValue>& constants() const {
    return constants_;
  }

  const BlockInfo& block_info(Block* block) const {
    return block_infos_.at(block);
  }

  Block* root_block() const {
    return graph_->block();
  }

 private:
  friend class StaticRuntime;
  friend class BlockRunner;

 public:
  auto num_nodes() const {
    return std::accumulate(
        block_infos_.begin(),
        block_infos_.end(),
        0,
        [](size_t sum, const auto& block_and_info) {
          auto& block_info = block_and_info.second;
          return sum + block_info.num_nodes();
        });
  }

  C10_NODISCARD Node* findNodeWithKindForTesting(const std::string& kind) const;

  const std::optional<c10::FunctionSchema>& schema() const {
    return schema_;
  }

  bool first_input_is_self() const {
    return module_.has_value();
  }

  StaticRuntime& runtime();

  // See [Shared values array]
  size_t value_buffer_size() const {
    return value_buffer_size_;
  }

 private:
  // Recursively prepares the BlockInfo array.
  // - Populates `value_to_index` with the indices of each intermediate value
  // - Returns the number of Value* processed, including sub-blocks.
  size_t prepareBlockInfo(
      Block* block,
      const size_t start_idx,
      c10::FastMap<const Value*, uint32_t>& value_to_index);

  void prepareFunctionsAndConstants(
      Block* block,
      const AliasDb& alias_db,
      c10::FastMap<const Value*, uint32_t>& value_to_index);

  // Recursively traverse the graph and attach SR metadata
  // to the prim::fork nodes as additional attributes
  void attachNodeMetadata(Block* block);

  // Recurses on sub-blocks and populates the array of ProcessedNodes
  // Returns (number of nodes processed, number of blocks processed)
  size_t prepareStaticNodeInfos(
      Block* block,
      const c10::FastMap<const Value*, uint32_t>& value_to_index,
      const AliasDb& alias_db,
      size_t node_idx = 0);

  // Initialize various attributes that the memory planner will need.
  // To be called at the tail of the ctor.
  void prepareForMemoryPlanner();

  StaticModuleOptions opts_;
  // metadata that is stored in IR nodes as attribute
  at::intrusive_ptr<jit::StaticRuntimeMetadata> sr_metadata_;
  std::shared_ptr<torch::jit::Graph> graph_;
  std::optional<torch::jit::Module> module_;
  std::optional<c10::FunctionSchema> schema_;
  std::unique_ptr<StaticRuntime> cached_runtime_;

  // Bookkeeping for creating new StaticRuntime instances
  // IValue table (defined by prim::Constant nodes)
  std::vector<IValue> constants_;
  // The functions to be called by corresponding ProcessedNode.
  std::vector<ProcessedFunction> functions_{};
  // A list of pre-processed nodes from which ProcessedNode are created per
  // StaticRuntime instance.
  std::vector<StaticNodeInfo> nodes_;
  // Indices of graph outputs in the single values array.
  std::vector<uint16_t> output_indices_;

  size_t num_intermediate_values_ = 0;

  // Includes self if module_ != std::nullopt.
  // Note that we might have num_inputs_ == 0 even if the schema has a `self`
  // argument. In this case, `self` isn't used in the graph, but the schema
  // includes it anyways to be consistent with the JIT interpreter.
  size_t num_inputs_;
  // See `BlockInfo` definition. The blocks are stored in depth-first order.
  c10::FastMap<Block*, BlockInfo> block_infos_;
  size_t value_buffer_size_ = 0;
};

// `BlockRunner` contains the core runtime logic. Each block runner
// corresponds to one block in the graph and has its own memory planner.
// `StaticRuntime` will initialize all `BlockRunner`s
// upon construction. Each block runner only directly executes nodes from its
// block. Special ops with sub-blocks like `prim::If` may have
// `BlockRunner`s stored in their `ProcessedNode`s; these
// sub-blocks get executed in the op's implementation.
// `StaticRuntime` stores a vector of IValues that all
// `BlockRunner`s share. This vector is used to store all
// constants, inputs, and intermediate tensors.
class TORCH_API BlockRunner {
 public:
  BlockRunner(
      const StaticModule& sm,
      IValue* values,
      Block* block,
      torch::jit::TaskLauncher* launcher,
      bool is_root_block = false);
  BlockRunner(BlockRunner&&) noexcept;
  BlockRunner& operator=(BlockRunner&&) = delete;
  ~BlockRunner();

  C10_DISABLE_COPY_AND_ASSIGN(BlockRunner);

  using KeywordArgs = std::unordered_map<std::string, c10::IValue>;
  c10::IValue operator()(
      const std::vector<c10::IValue>& args,
      const KeywordArgs& kwargs = KeywordArgs());
  c10::IValue operator()(
      std::vector<c10::IValue>&& args,
      const KeywordArgs& kwargs = KeywordArgs());

  c10::intrusive_ptr<c10::ivalue::Future> runAsync(
      const std::vector<c10::IValue>& args,
      const KeywordArgs& kwargs);

  c10::intrusive_ptr<c10::ivalue::Future> runAsync(
      std::vector<c10::IValue>&& args,
      const KeywordArgs& kwargs);

  void benchmark(
      const std::vector<std::vector<c10::IValue>>& args_list,
      const std::vector<KeywordArgs>& kwargs_list,
      const uint32_t warmup_runs,
      const uint32_t main_runs,
      bool print_per_node_time = false,
      bool generate_ai_pep_output = false);

  struct IndividualMetrics {
    float setup_time{0.0};
    float memory_alloc_time{0.0};
    float memory_dealloc_time{0.0};
    float output_dealloc_time{0.0};
    float first_iter_time{0.0};
    float total_time{0.0};
    size_t out_nodes_count{0};
    size_t total_nodes_count{0};
    std::vector<float> time_per_node;
    std::unordered_map<std::string, float> time_per_node_type;
    std::unordered_map<std::string, float> percent_per_node_type;
    std::unordered_map<std::string, int> instances_per_node_type;
    std::unordered_set<std::string> out_nodes;
    std::unordered_set<std::string> native_nodes;
  };

  IndividualMetrics benchmark_individual_ops(
      const std::vector<std::vector<c10::IValue>>& args_list,
      const std::vector<KeywordArgs>& kwargs_list,
      const uint32_t warmup_runs,
      const uint32_t main_runs);

  // Input is read-write.
  IValue& Input(uint32_t i) {
    TORCH_DCHECK_LT(i, block_info_.num_inputs());
    return values_[i + block_info_.block_inputs_idx()];
  }

  // Output is read-only. The writing process happens inside ProcessedNodes.
  C10_NODISCARD const IValue& Output(uint32_t i) const {
    DCHECK(i < outputs_.size());
    return *outputs_[i];
  }

  const std::vector<IValue*> outputs() const {
    return outputs_;
  }

  const std::vector<ProcessedNode>& nodes() const {
    return nodes_;
  }

  std::vector<ProcessedNode>& nodes() {
    return nodes_;
  }

  graph_node_list node_ptrs() const {
    return block_info_.node_ptrs();
  }

  const Graph& graph() const {
    return static_module_.graph();
  }

  const MemoryPlanner* get_memory_planner() const {
    return planner_.get();
  }

  bool check_for_memory_leak(
      bool output_returned = true,
      bool recurse_on_sub_blocks = false);

  // WARNING: Deallocate managed output tensors. A client receiving Static
  // Runtime-managed Tensors needs to be very careful to call
  // `StaticRuntime::deallocateOutputTensors` after all references to the
  // output Tensors are gone.
  void deallocateOutputTensors();
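  // A hedged sketch of the expected call pattern when
  // StaticModuleOptions::manage_output_tensors is true (names illustrative):
  //
  //   auto out = runtime(args, kwargs);
  //   // ... consume `out`, then drop all references to the output tensors ...
  //   out = IValue();                     // release the last reference
  //   runtime.deallocateOutputTensors();  // now safe to reclaim their storage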

  bool checkOutputTensorMemoryLeaks();

  bool isManagedOutputTensor(const IValue& ivalue) const;
  bool isManagedOutputTensorValue(const Value* value) const;

  void disableManageOutputTensors();

  // This is the fallback path taken if we can't construct the memory planner
  // on the first iteration.
  // IMPORTANT: Nothing here should be able to throw!!!
  // This function can be called from the (implicitly) `noexcept` destructor
  // of Deallocator, meaning that std::terminate will be called
  // if any exception escapes. Even if resetMemory and ~Deallocator were
  // `noexcept(false)`, it's possible that when ~Deallocator is called, the
  // stack is already unwinding, so there's still danger of calling
  // std::terminate.
  void resetMemory() noexcept;

 private:
  // A helper object that invokes memory planner deallocation code
  // when destructed.
  class Deallocator {
   public:
    explicit Deallocator(BlockRunner& block_runner)
        : block_runner_(block_runner) {}

    Deallocator(Deallocator&&) = default;
    Deallocator(const Deallocator&) = default;
    Deallocator& operator=(const Deallocator&) = delete;
    Deallocator& operator=(Deallocator&&) = delete;
    ~Deallocator();

    void setFinished() {
      finished_ = true;
    }

   private:
    void cleanupImpl();

    bool finished_ = false;
    BlockRunner& block_runner_;
  };

  template <typename IValueList>
  c10::IValue run_impl(IValueList&& args, const KeywordArgs& kwargs);

  template <typename IValueList>
  c10::IValue run_impl_record_functions(
      IValueList&& args,
      const KeywordArgs& kwargs);

  template <typename IValueList>
  c10::intrusive_ptr<c10::ivalue::Future> run_impl_async(
      IValueList&& args,
      const KeywordArgs& kwargs);

  template <typename IValueList>
  c10::intrusive_ptr<c10::ivalue::Future> run_impl_record_functions_async(
      IValueList&& args,
      const KeywordArgs& kwargs);

  // helper method for copying input args/kwargs into inputs_
  template <typename IValueList>
  void set_inputs(IValueList&& args, const KeywordArgs& kwargs);

  // Set Input(idx) to args[idx]. Invoked by set_inputs. Copies or moves
  // depending on overload.
  void set_arg(const size_t idx, std::vector<IValue>&& args);
  void set_arg(const size_t idx, const std::vector<IValue>& args);

  // Set Input(idx) to arg. Always copies. Used for kwargs.
  void set_arg(const size_t idx, const IValue& arg);

  bool fast_check_and_correct_overlap_with(
      ProcessedNode& n,
      c10::IValue& tensor_ival);
  void verify_and_correct_memory_overlap(ProcessedNode& n);

  // clean up owning refs of input IValues
  void clean_up_input_ivalues() noexcept {
    for (const auto idx : c10::irange(block_info_.num_inputs())) {
      values_[idx + inputs_begin_] = IValue();
    }
  }

  void clean_up_intermediate_ivalues() noexcept;

  IValue move_outputs_to_tuple(uint32_t num_outputs);

  void create_memory_planner();

  float benchmark_model(
      const std::vector<std::vector<c10::IValue>>& args_list,
      const std::vector<KeywordArgs>& kwargs_list,
      const uint32_t warmup_runs,
      const uint32_t main_runs);

  void display_nodes(
      const std::vector<c10::IValue>& args,
      const KeywordArgs& kwargs);

  const StaticModule& static_module_;
  const BlockInfo& block_info_;

  const bool is_root_block_;
  // Cache this so we don't have to call static_module_.first_input_is_self()
  const bool first_input_is_self_;
  // Index of the start of this block's inputs in the shared values_ array.
  const uint16_t inputs_begin_;

  bool manage_output_tensors_enabled_ = false;
  std::unique_ptr<MemoryPlanner> planner_;
  // [Shared values array]
  // ProcessedNodes reference their inputs and outputs with
  // offsets into this array, which saves memory.
  // All BlockRunners share the same array. The layout is as
  // follows:
  // [constants][block_0][block_1]...[block_N]
  // Note that constants from all blocks are pooled together at the start.
  // The block ordering is depth-first.
  // Each block is further divided into inputs and intermediates:
  // [block_i] = [inputs_i][intermediates_i]
  // Each BlockRunner knows where its inputs start. Each ProcessedNode
  // knows how to find the indices of its outputs/inputs in this array.
  IValue* values_;

  std::vector<IValue*> outputs_;
  std::vector<ProcessedNode> nodes_;
};

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class TORCH_API StaticNodeInfo {
 public:
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  StaticNodeInfo(
      Node* n,
      ProcessedFunction* fn,
      ProcessedNodeInputs inputs,
      uint16_t outputs_offset);

  Node* node() const {
    return node_;
  }

  size_t num_outputs() const {
    DCHECK(fn_ != nullptr);
    return fn_->num_outputs();
  }

  bool has_out_variant() const {
    return fn_->kind() == ProcessedFunction::Kind::kOutVariant;
  }

 private:
  friend class ProcessedNode;

  Node* node_;
  const ProcessedFunction* fn_;
  ProcessedNodeInputs inputs_;
  uint16_t outputs_offset_;
};

inline size_t BlockInfo::num_nodes() const {
  return nodes_.size();
}

/*
  ProcessedNodeMetadata wraps the possible metadata for a ProcessedNode.
  Depending on the nature of the op, a ProcessedNode can hold one of the
  following kinds of metadata:
  - prim::If/prim::Loop ops contain block_runners_ as their metadata
  - the prim::fork op contains a TaskLauncher (std::function) responsible
    for execution of the forked subgraph
*/
class TORCH_API ProcessedNodeMetadata {
 public:
  ProcessedNodeMetadata(
      std::vector<BlockRunner> runners,
      torch::jit::TaskLauncher* launcher)
      : block_runners_(std::move(runners)), launcher_(launcher) {}

  ProcessedNodeMetadata() : launcher_(nullptr) {}

  // deleted copy ctor/assignment as standard containers (vector) always
  // have copy constructors, but their instantiation is not well-formed
  // if the contained type (BlockRunner) is not copyable
  ProcessedNodeMetadata(const ProcessedNodeMetadata&) = delete;
  ProcessedNodeMetadata& operator=(const ProcessedNodeMetadata&) = delete;

  std::vector<BlockRunner>& block_runners() {
    return block_runners_;
  }

  void set_block_runners(std::vector<BlockRunner> runners) {
    block_runners_ = std::move(runners);
  }

  void set_launcher(torch::jit::TaskLauncher* launcher) {
    launcher_ = launcher;
  }

  torch::jit::TaskLauncher* launcher() {
    return launcher_;
  }

 private:
  std::vector<BlockRunner> block_runners_;
  torch::jit::TaskLauncher* launcher_;
};

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class TORCH_API ProcessedNode {
 public:
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  ProcessedNode() = default;

  ProcessedNode(const StaticNodeInfo& other, IValue* values)
      : node_(other.node_),
        fn_(other.fn_),
        inputs_(other.inputs_),
        outputs_offset_(other.outputs_offset_),
        values_(values),
        metadata_(nullptr) {}

  // These should be noexcept, but some Android build is failing
  // saying the noexcept specification doesn't match the calculated
  // one. Maybe std::variant is throwing it off?
  ProcessedNode(ProcessedNode&&) = default;

  ProcessedNode(const ProcessedNode&) = delete;
  ProcessedNode& operator=(const ProcessedNode& other) = delete;
  ProcessedNode& operator=(ProcessedNode&&) = default;

  void run();

  Node* node() const {
    return node_;
  }

  // Input is read-only.
  C10_NODISCARD const IValue& Input(uint32_t i) const {
    return values_[inputs_[i]];
  }

  // Output is read-write.
  IValue& Output(uint32_t i) {
    DCHECK(i < num_outputs());
    return values_[outputs_offset_ + i];
  }

  C10_NODISCARD const IValue& Output(uint32_t i) const {
    DCHECK(i < num_outputs());
    return values_[outputs_offset_ + i];
  }

  uint32_t num_outputs() const {
    DCHECK(fn_ != nullptr);
    return static_cast<uint32_t>(fn_->num_outputs());
  }

  C10_NODISCARD c10::ArrayRef<const IValue> outputs() const {
    return c10::ArrayRef<const IValue>(
        values_ + outputs_offset_, num_outputs());
  }

  C10_NODISCARD uint16_t num_inputs() const {
    return inputs_.size();
  }

  std::vector<IValue> inputs_ivalue_vec() const;

  bool has_out_variant() const {
    return fn_->kind() == ProcessedFunction::Kind::kOutVariant;
  }

  bool has_native() const {
    return fn_->kind() == ProcessedFunction::Kind::kNativeFunction;
  }

#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
  const char* get_op_name() const {
    return node_->kind().toQualString();
  }
#endif

  bool check_outputs_for_memory_overlap() const {
    return fn_->checkMemoryOverlap();
  }

  void set_outputs_memory_overlap_detected() {
    overlap_detected_ = true;
  }

  bool outputs_memory_overlap_detected() {
    return overlap_detected_;
  }

  bool check_and_correct_overlap_with(
      const at::Tensor& input,
      c10::IValue& output);
  void verify_and_correct_memory_overlap();

  void set_values(IValue* values) {
    DCHECK(values_ == nullptr);
    values_ = values;
  }

  C10_NODISCARD uint16_t output_ivalue_index(uint16_t i) const {
    DCHECK(i < num_outputs());
    return outputs_offset_ + i;
  }
  // used in debug mode
  bool verify_no_memory_overlap(bool force_check = false) const;

  // returns pointer to ProcessedNodeMetadata or nullptr if no object is owned
  ProcessedNodeMetadata* metadata() {
    return metadata_.get();
  }

  // attach block_runners to metadata of ProcessedNode
  void set_metadata(std::vector<BlockRunner> block_runners) {
    if (metadata_ == nullptr) {
      metadata_ = std::make_unique<ProcessedNodeMetadata>();
    }
    metadata_->set_block_runners(std::move(block_runners));
  }

  // attach TaskLauncher to metadata of ProcessedNode
  void set_metadata(torch::jit::TaskLauncher* launcher) {
    if (metadata_ == nullptr) {
      metadata_ = std::make_unique<ProcessedNodeMetadata>();
    }
    metadata_->set_launcher(launcher);
  }

 private:
  C10_NODISCARD bool verify_outputs_dont_overlap_each_other() const;

  C10_NODISCARD bool verify_inputs_dont_overlap_outputs(bool force_check) const;

  Node* node_;
  const ProcessedFunction* fn_;
  ProcessedNodeInputs inputs_;
  uint16_t outputs_offset_;
  bool overlap_detected_{false};
  IValue* values_ = nullptr; // unowned
  // Metadata for ProcessedNode:
  // 1. prim::If/Loop nodes contain sub-blocks as metadata
  // 2. prim::fork nodes contain a custom executor for async execution
  std::unique_ptr<ProcessedNodeMetadata> metadata_;
};

// `StaticRuntime` is the owner of the array of IValues (used for constants,
// inputs, and intermediate tensors) that all `BlockRunner`s share.
// Upon construction, it initializes all block runners. `operator()` simply
// forwards the inputs to the top-level block runner. Each `StaticRuntime`
// instance corresponds to one `StaticModule`. Multiple `StaticRuntime`
// instances can be created; this is useful for multi-threaded execution, since
// `operator()` is not thread-safe.
class TORCH_API StaticRuntime {
 public:
  explicit StaticRuntime(const StaticModule& sm);

  using KeywordArgs = std::unordered_map<std::string, c10::IValue>;
  c10::IValue operator()(
      const std::vector<c10::IValue>& args,
      const KeywordArgs& kwargs = KeywordArgs());
  c10::IValue operator()(
      std::vector<c10::IValue>&& args,
      const KeywordArgs& kwargs = KeywordArgs());

  // runAsync performs inline execution of the graph on the caller thread and
  // async execution on the taskLauncher. If no custom taskLauncher is
  // specified, execution is done on the inter-op thread pool.
  c10::intrusive_ptr<c10::ivalue::Future> runAsync(
      const std::vector<c10::IValue>& args,
      const KeywordArgs& kwargs = KeywordArgs(),
      torch::jit::TaskLauncher taskLauncher = at::launch);

  c10::intrusive_ptr<c10::ivalue::Future> runAsync(
      std::vector<c10::IValue>&& args,
      const KeywordArgs& kwargs = KeywordArgs(),
      torch::jit::TaskLauncher taskLauncher = at::launch);
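  // A hedged usage sketch for runAsync (names illustrative): launch on the
  // default inter-op pool and block until the result is ready.
  //
  //   auto future = runtime.runAsync(args, kwargs);
  //   future->wait();
  //   auto output = future->value();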

  bool check_for_memory_leak(bool output_returned = true);
  bool checkOutputTensorMemoryLeaks();

  void deallocateOutputTensors();
  bool isManagedOutputTensor(const IValue& ivalue) const;
  void disableManageOutputTensors();

  // Gets the top-level memory planner. Used for testing.
  const MemoryPlanner* get_memory_planner() const;

  void benchmark(
      const std::vector<std::vector<c10::IValue>>& args_list,
      const std::vector<KeywordArgs>& kwargs_list,
      const uint32_t warmup_runs,
      const uint32_t main_runs,
      bool print_per_node_time = false,
      bool generate_ai_pep_output = false) {
    block_->benchmark(
        args_list,
        kwargs_list,
        warmup_runs,
        main_runs,
        print_per_node_time,
        generate_ai_pep_output);
  }
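
  // A hedged sketch of calling benchmark() above (values illustrative):
  //
  //   std::vector<std::vector<c10::IValue>> args_list{args};
  //   std::vector<StaticRuntime::KeywordArgs> kwargs_list{{}};
  //   runtime.benchmark(args_list, kwargs_list, /*warmup_runs=*/10,
  //                     /*main_runs=*/100);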

  using IndividualMetrics = BlockRunner::IndividualMetrics;

  IndividualMetrics benchmark_individual_ops(
      const std::vector<std::vector<c10::IValue>>& args_list,
      const std::vector<KeywordArgs>& kwargs_list,
      const int warmup_runs,
      const int main_runs) {
    return block_->benchmark_individual_ops(
        args_list, kwargs_list, warmup_runs, main_runs);
  }

 private:
  // An array of IValues with unchanging size/data ptr.
  class IValueArray {
   public:
    IValueArray() = default;
    explicit IValueArray(size_t size) : array_(allocate(size)), size_(size) {}

    IValue* data() const {
      return array_.get();
    }

    size_t size() const {
      return size_;
    }

   private:
    // NOLINTNEXTLINE(modernize-avoid-c-arrays)
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays)
    static std::unique_ptr<IValue[]> allocate(size_t size) {
      if (size) {
        return std::make_unique<IValue[]>(size);
      }
      return nullptr;
    }

    // NOLINTNEXTLINE(modernize-avoid-c-arrays)
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays)
    std::unique_ptr<IValue[]> array_ = nullptr;
    size_t size_ = 0;
  };

  std::unique_ptr<BlockRunner> block_;
  // for execution of async operations present in graph
  torch::jit::TaskLauncher async_task_launcher_;
  IValueArray values_;
};

} // namespace torch::jit