> :warning: **This is an experimental feature**

# Static Runtime

Static Runtime is an optimized CPU inference runtime for PyTorch models.
It can be used as a drop-in replacement for the TorchScript JIT interpreter
in either C++ or Python.

Static Runtime is mainly useful if the following conditions are met:
1. The model has very little control flow.
2. PyTorch overhead (tensor creation, etc.) accounts for
a non-trivial fraction of the model's runtime. In particular, if
tensor allocation consumes a significant amount of time, Static
Runtime can help. Memory for intermediate tensors is coalesced into
a single slab, so most dynamic allocations are avoided during
inference.
3. Inference performance is extremely important.

## Assumptions

This is a list of current assumptions for use with
this feature.

- Inference-only execution, CPU only
- Static input dtypes
- Static input shapes (the runtime supports dynamic shapes, but excessive dynamic shapes may degrade performance)

## Threading model
Static runtime supports two execution modes.

Mode 1: single-threaded with no parallelism except for intra-op parallelism.
For this mode, you can do either:
```
  // m is the TorchScript module
  auto runtime = StaticRuntime(m, opts);
  auto output = runtime.run(args, kwargs);
```
or
```
  auto mod = PrepareForStaticRuntime(m);
  auto runtime = StaticRuntime(mod, opts);
  auto output = runtime.run(args, kwargs);
```
Mode 2: similar to data parallelism, run the same model for different inputs
on different threads at the same time. In this case, run
`PrepareForStaticRuntime` to prepare the graph for Static Runtime. You
should have one InferenceModule instance per model, and one Static Runtime instance
per running thread. To avoid creating StaticRuntime instances on the fly, use a
synchronized stack (e.g. `boost::lockfree::stack`) to cache all the Static
Runtime instances in your code.
```
  // initialization
  auto mod = PrepareForStaticRuntime(m);
  // 128 is good for most cases. Pick a number that works for you
  boost::lockfree::stack<std::shared_ptr<StaticRuntime>,
    boost::lockfree::fixed_sized<true>> pool(128);

  // inference
  std::shared_ptr<StaticRuntime> runtime = nullptr;
  pool.pop(runtime);
  if (!runtime) {
    runtime = std::make_shared<StaticRuntime>(mod, opts);
  }
  auto output = runtime->run(args, kwargs);
  pool.push(runtime);
```

**In both modes, `StaticRuntime` may not be used after its associated `StaticModule` is destructed!**

## Memory Planning
Static runtime's memory planner does two things:

1) Coalesces internal allocations for tensor storage
2) Does static analysis to figure out how to efficiently re-use memory.

### Standard Resizing
Static runtime will record the space required for each intermediate managed tensor it sees
on the first inference iteration. An intermediate tensor is *managed* if two conditions
are satisfied:

1) The op that produces it has an out variant. Out variants are wrappers around ops that
conceptually transform the op's signature from `Tensor some_op(const Tensor& some_arg)`
into `void some_op(Tensor& output, const Tensor& some_arg)`. Out variants are registered
with static runtime via the `REGISTER_OPERATOR_FUNCTOR` macro; see "Registering Ops" for
more info.

2) The tensor does not alias a graph output. Output tensors are handled separately by
the memory planner; see "Managed Output Tensors" for details.

With this algorithm, static analysis is used to group the tensors into `StorageGroup`s.
Tensors in the same storage group share memory, and two tensors can be in the same storage group
if their lifetimes do not overlap.

On subsequent iterations, static runtime allocates the tensor buffer at the start of the run.
The amount of memory allocated is `sum(max(tensor.size() for tensor in group) for group in storage_groups)`.

If a tensor needs to be bigger than the allocated space on subsequent runs, a dynamic allocation
will occur. This is why dynamic shapes will degrade performance. With the standard resizing
strategy, static runtime will record the new largest tensor size in each storage group at the
end of the iteration and allocate a buffer that is possibly bigger on the next iteration.
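
The slab-size computation described above can be illustrated with the following sketch. It is purely illustrative; the `StorageGroup` type and byte-size bookkeeping here are hypothetical stand-ins for the real planner.
```
// Illustrative only: compute the size of the slab backing all managed intermediate tensors.
#include <algorithm>
#include <cstddef>
#include <vector>

struct StorageGroup {
  std::vector<size_t> tensor_nbytes; // recorded size of each tensor assigned to this group
};

size_t computeSlabSize(const std::vector<StorageGroup>& storage_groups) {
  size_t total = 0;
  for (const auto& group : storage_groups) {
    // Tensors in a group share one allocation, so a group needs as many bytes
    // as its largest member (groups are assumed non-empty here).
    total += *std::max_element(group.tensor_nbytes.begin(), group.tensor_nbytes.end());
  }
  return total;
}
```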

### Managed Output Tensors

`StaticRuntime` can optionally manage output tensors via the `manage_output_tensors` option in `StaticModuleOptions`.
When this flag is turned on, we coalesce allocations for output tensors together. Note that the buffer containing
output tensors is separate from the one containing intermediate tensors. The former needs to live past the end
of the inference run, but the latter needs to be deallocated at the end of the run.

Under the hood, we store a refcounted pointer to the output arena in each returned `Tensor`. The arena is destroyed
explicitly.
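
A usage sketch follows. It assumes an explicit deallocation call (`deallocateOutputTensors()` here) along the lines described above; verify the exact option and method names against the headers for your version.
```
// Sketch: opting in to managed output tensors (option/method names assumed as described above).
torch::jit::StaticModuleOptions opts;
opts.manage_output_tensors = true;

// `module` is a loaded torch::jit::Module; args/kwargs as in the earlier examples.
torch::jit::StaticModule smod(module, /*is_frozen=*/false, opts);
torch::jit::StaticRuntime runtime(smod);

auto output = runtime(args, kwargs);
// ... consume `output` ...

// The output arena must be released explicitly once the outputs are no longer needed.
runtime.deallocateOutputTensors();
```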

## Registering Ops
Static runtime has three op execution modes:

1) Out variants: ops that return tensors which we may be able to manage. See "Memory Planning" for more
details. Out variants are registered via the `REGISTER_OPERATOR_FUNCTOR` macro in `ops.h`.
```
REGISTER_OPERATOR_FUNCTOR(
    aten::op_name,
    aten_op_name, // This macro generates a struct, this field names it
    [](torch::jit::Node* n) -> SROperator {
      // This mechanism lets us support a subset of schemas
      if (n->matches(some_schema)) {
        return some_overload;
      } else if (n->matches(another_schema)) {
        return another_overload;
      }
      return nullptr;
    })
```

A `SROperator` is a type alias for `std::function<void(ProcessedNode*)>`. See "Implementation Details" for more
details on `ProcessedNode`.
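
As a concrete illustration, here is a hedged sketch of what an out-variant registration for `aten::add.Tensor` could look like. It follows the macro pattern above; the specific ATen calls and `ProcessedNode` accessors are written from memory and should be checked against `ops.cpp`.
```
// Illustrative out-variant sketch for aten::add.Tensor (details may differ from the real one).
REGISTER_OPERATOR_FUNCTOR(
    aten::add,
    aten_add,
    [](torch::jit::Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"))) {
        return nullptr; // unsupported overload; another execution mode will handle it
      }
      return [](ProcessedNode* p_node) {
        const auto& self = p_node->Input(0).toTensor();
        const auto& other = p_node->Input(1).toTensor();
        const auto alpha = p_node->Input(2).toScalar();
        if (p_node->Output(0).isNone()) {
          // First iteration: the output has no storage yet, so allocate it normally.
          p_node->Output(0) = at::add(self, other, alpha);
          return;
        }
        // Later iterations: write into the existing (possibly memory-planned) output tensor.
        auto& out = p_node->Output(0).toTensor();
        at::add_out(out, self, other, alpha);
      };
    });
```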

2) Native functions: just like out variants, except their outputs cannot be managed. This is because the op's return
type is not a tensor or it is a view op (returns a tensor alias instead of a new tensor). Registration is done with
`REGISTER_NATIVE_OPERATOR_FUNCTOR`. This macro is used in the same way as `REGISTER_OPERATOR_FUNCTOR`.

3) JIT fallback: static runtime has no implementation for this op, so the implementation that the JIT interpreter uses
is selected instead.

When loading a model, ops are selected for each `torch::jit::Node` in the graph as follows (a rough sketch follows the list):

1) If an out variant is registered, pass the node to the function that produces the `SROperator`. If
the result is not `nullptr`, use that op.
2) If a native function is registered, pass the node to the function that produces the `SROperator`. If
the result is not `nullptr`, use that op.
3) Use the JIT implementation. Static runtime will throw an exception if it does not exist.
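
In pseudocode, the selection logic behaves roughly like this; the lookup helpers below are hypothetical names, not the actual registry API.
```
// Rough sketch of op selection; helper names are hypothetical stand-ins for the registry lookups.
SROperator resolveOp(torch::jit::Node* n) {
  if (auto op = lookupOutVariant(n)) { // 1) out variant whose functor returned non-nullptr
    return op;
  }
  if (auto op = lookupNativeFunction(n)) { // 2) native function whose functor returned non-nullptr
    return op;
  }
  // 3) Fall back to the implementation used by the JIT interpreter;
  //    this is where an exception is thrown if no implementation exists.
  return lookupJitFallback(n);
}
```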

## Implementation Details

### Structure and Lifetime Details

The following diagram shows the core data structures. An arrow from `A` to `B` means that
`A` stores a reference to `B`. If the reference is unowned,
`A` may not outlive `B` or anything that `B` stores a reference to (directly or indirectly).
If the reference is owned, the lifetimes of `A` and `B` are the same.
```
  Ownership (owner and owned object share the same lifetime):

    StaticRuntime ──owns──► IValue array
    StaticRuntime ──owns──► BlockRunner (top level)
    BlockRunner   ──owns──► MemoryPlanner
    BlockRunner   ──owns──► ProcessedNode(s)
    ProcessedNode ──owns──► BlockRunner(s) for sub-blocks ──► ...
    StaticModule  ──owns──► BlockInfo
    BlockInfo     ──owns──► ProcessedFunction(s)

  Unowned references (the holder may not outlive the target):

    StaticRuntime              ──► StaticModule
    BlockRunner, ProcessedNode ──► IValue array
    BlockRunner                ──► BlockInfo
    ProcessedNode              ──► ProcessedFunction
```

Each class is described in detail below.

### `StaticModule` and `StaticRuntime`

`StaticModule`s are constructed from `torch::jit::Module`s and can be used to construct `StaticRuntime`
instances. Each `StaticModule` caches exactly one `StaticRuntime` instance; it is lazily initialized when
you access it via `runtime()`.

`StaticModule::operator()` can be used directly to make predictions. Under the hood, this method just
forwards to the cached runtime's `StaticRuntime::operator()`. One upshot of this behavior is that
`StaticModule::operator()` is not thread-safe.

The way to use static runtime in a multi-threaded context is to give each thread its own `StaticRuntime`
instance. New runtime instances can be created directly (`StaticRuntime(static_module)`) or `clone()`'d from
an existing runtime.
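
For example, a minimal multi-threaded setup might look like the sketch below. The include path, loading code, and `operator()` call are assumptions based on this README; check `impl.h` for the exact signatures.
```
// Sketch: one StaticRuntime per thread, all sharing a single StaticModule.
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/script.h>

#include <thread>
#include <vector>

void runMultiThreaded(const std::string& model_path, const std::vector<c10::IValue>& args) {
  torch::jit::Module module = torch::jit::load(model_path);
  module.eval();
  torch::jit::StaticModule smod(module); // shared; must outlive every runtime below

  std::vector<std::thread> workers;
  for (int i = 0; i < 4; ++i) {
    workers.emplace_back([&smod, &args] {
      torch::jit::StaticRuntime runtime(smod); // per-thread instance
      auto output = runtime(args, {});
      (void)output; // consume the result here
    });
  }
  for (auto& w : workers) {
    w.join();
  }
}
```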

`StaticModule` takes a set of options that control the behavior of the runtime instances that it spawns;
see `StaticModuleOptions` for more details.

Internally, `StaticRuntime` owns an array of `IValue`s that is referenced from all `BlockRunner`s and
`ProcessedNode`s. All values that are generated at runtime are stored in this array.

### `BlockRunner`

A `BlockRunner` represents a single sub-block in the graph. Every graph has at least one `BlockRunner`
corresponding to the top-level block, and `StaticRuntime` starts its inference run by invoking
`(*top_level_block)(args, kwargs)`. Each `BlockRunner` has its own `MemoryPlanner` and set of `ProcessedNode`s.
Special nodes that have sub-blocks (like `prim::If`) might own `BlockRunner`s. The op implementations are responsible
for invoking `BlockRunner`s corresponding to sub-blocks.

### `MemoryPlanner`

See the "Memory Planning" section. `MemoryPlanner` is an abstract base class. Each sub-class implements a different
memory planning algorithm.

In addition to the memory planning we do for tensors, `MemoryPlanner` encapsulates a few other optimizations:

* Managed output tensors (see "Managed Output Tensors")
* Borrowed `IValue`s: ops that just unpack their inputs (e.g. `dict_unpack`) might produce weak references to
avoid refcount bumps; the `MemoryPlanner` needs to destroy these borrows appropriately.

### `ProcessedNode` and `ProcessedFunction`

`ProcessedNode` is our abstraction for a single op. Each `ProcessedNode` stores an unowned reference to `StaticRuntime`'s
`IValue` array. It knows how to map input/output indices to indices in this array (so `processed_node->output(i)` returns
a reference to `ivalue_array[some_set_of_indices[i]]`).

Each `ProcessedNode` stores a `ProcessedFunction`, which represents the actual op to execute. `ProcessedFunction`s are initialized
upon `StaticModule` construction according to the out variant/native/JIT fallback lookup rules described in "Registering Ops".
**Note that all `ProcessedFunction`s are shared amongst all runtime instances**, so all `ProcessedFunction`s must be thread-safe.

### `ProcessedNodeMetadata`

`ProcessedNodeMetadata` holds various "extra" fields on behalf of `ProcessedNode`. Typically, this field is unused. But a few ops need extra machinery to work:
* `prim::If` operations have two `BlockRunner`s for the execution of the true and false sub-blocks depending upon the condition check.
* `prim::Loop` operations have a `BlockRunner` for the execution of the looping sub-block.
* `prim::fork` operations have a `torch::jit::TaskLauncher` (`std::function<void(std::function<void()>)>`) responsible for forked graph execution.

### Asynchronous Execution

The `StaticRuntime::runAsync()` API allows the execution of asynchronous operations on the `TaskLauncher` passed as an argument.
`StaticRuntime::runAsync()` performs inline execution of the parent graph on the caller thread. Asynchronous operations like `prim::fork` are executed
on the launcher passed in. If no launcher is provided, the execution happens via `at::launch`, i.e. on the inter-op thread pool.
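
A hedged usage sketch follows; it assumes `runAsync()` returns a future that can be waited on, and the custom launcher here is purely illustrative.
```
// Sketch: asynchronous execution with a custom TaskLauncher (signatures assumed; see impl.h).
torch::jit::TaskLauncher launcher = [](std::function<void()> work) {
  // Route forked sub-graph execution onto a thread of our own choosing.
  std::thread(std::move(work)).detach();
};

// The parent graph runs inline on the calling thread; prim::fork work goes to `launcher`.
auto future = runtime.runAsync(args, kwargs, launcher);
future->wait();
auto output = future->value();
```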