# JIT Technical Overview

The JIT can run and optimize PyTorch programs separately from the Python interpreter. This overview is organized into sections that go over different independent components:

1. Core Program Representation - The JIT executes TorchScript, a subset of Python. This section describes how TorchScript programs are represented in the JIT; this representation serves as the interchange format between components of the JIT.
2. Generating Programs - TorchScript programs can be created either through tracing Python code or through directly writing TorchScript. This section describes how Models are created from these frontends.
3. Executing Programs - Once created, TorchScript models are optimized and run. Since this is a just-in-time compiler, programs are optimized as they are executed, so this section describes both how programs are optimized and how they get run.
4. Saving Programs - TorchScript is often created in Python and then used from C++. This section describes how the save and load process works.
5. Python Bindings - TorchScript code is normally created and used from Python, so this section describes how the Python components interact with the code in this directory.

For concepts that are actual classes in the JIT, we use capitalized words in code font, e.g. `Graph` or `Value`.

Sections start with a reference to the source file where the code related to the section resides.

## Table of Contents

<!-- toc -->

- [Core Program Representation](#core-program-representation)
  - [Modules](#modules)
  - [Parameters](#parameters)
  - [Method](#method)
  - [FunctionSchema](#functionschema)
  - [Graph](#graph)
  - [Node](#node)
  - [Block](#block)
    - [If](#if)
    - [Loops](#loops)
    - [With](#with)
  - [Value](#value)
  - [Type](#type)
- [Generating Programs](#generating-programs)
  - [Tracer](#tracer)
  - [Script](#script)
  - [Tree](#tree)
  - [Tree Views](#tree-views)
  - [frontend.py](#frontendpy)
  - [Lexer](#lexer)
  - [Tokens](#tokens)
  - [Parser](#parser)
  - [IR Emitter](#ir-emitter)
  - [SugaredValue](#sugaredvalue)
  - [Resolver](#resolver)
  - [Environment](#environment)
  - [Conversion To SSA](#conversion-to-ssa)
  - [Exit Transform](#exit-transform)
  - [Python-Compiler Interaction](#python-compiler-interaction)
- [Executing Programs](#executing-programs)
  - [Evaluation Semantics](#evaluation-semantics)
  - [IValue](#ivalue)
  - [Operation](#operation)
  - [Operator](#operator)
  - [Interpreter](#interpreter)
  - [Graph Executor](#graph-executor)
    - [Specialization](#specialization)
    - [Dynamic Shapes Options](#dynamic-shapes-options)
    - [Pre-derivative Optimization](#pre-derivative-optimization)
    - [Required Passes](#required-passes)
    - [Derivative Preserving Optimization](#derivative-preserving-optimization)
    - [Post-derivative optimization](#post-derivative-optimization)
    - [Derivate Splitting](#derivate-splitting)
    - [Fusers](#fusers)
    - [Disabling Optimizations](#disabling-optimizations)
  - [JIT Logging](#jit-logging)
  - [JIT Optimization Limiter](#jit-optimization-limiter)
  - [DifferentiableGraphOp](#differentiablegraphop)
  - [Handling Mutability](#handling-mutability)
    - [Aliasing and mutation in the PyTorch API](#aliasing-and-mutation-in-the-pytorch-api)
    - [Aliasing and mutation annotations in FunctionSchema](#aliasing-and-mutation-annotations-in-functionschema)
    - [Marking custom ops as side-effectful](#marking-custom-ops-as-side-effectful)
    - [Alias Analysis in the IR](#alias-analysis-in-the-ir)
    - [Writing optimization passes with `AliasDb`](#writing-optimization-passes-with-aliasdb)
- [Profiling Programs](#profiling-programs)
- [Saving Programs](#saving-programs)
- [Testing Programs](#testing-programs)
  - [Testing Autodiff](#testing-autodiff)
- [Python Printer](#python-printer)
- [Python Bindings](#python-bindings)
  - [Graph Manipulation](#graph-manipulation)

<!-- tocstop -->

# Core Program Representation

## Modules ##

[api/module.h](api/module.h)

At the top level, all TorchScript programs are represented as a Module. Modules contain:

* named Parameters - `Tensors` used in training such as `weight` or `bias`
* named Buffers - `Tensors` that are part of the training state of a module but do not appear in `module.parameters()` and do not participate in gradient descent.
* named sub-Modules - used for code organization.
* named Attributes - all other attributes that are not in the above three categories. Typically used for configuration and are not saved/restored in the module's `state_dict`.
* named Methods - functions that can be run on the module such as `forward`

This mirrors the `nn.Module` objects used in Python. All TorchScript code is a member of some module. This includes pure functions such as those created by annotating a Python function with `@torch.jit.script`, which are represented internally as a Module that has a single method `forward` that contains the implementation of the function.
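
For concreteness, here is a minimal sketch (the module and names are made up for illustration) of how these pieces map onto a scripted `nn.Module` via the public `torch.jit.script` API:

```python
import torch
import torch.nn as nn

class MyCell(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)                   # becomes a named sub-Module
        self.register_buffer("steps", torch.zeros(1))   # becomes a named Buffer

    def forward(self, x, h):
        # self.linear.weight and self.linear.bias are named Parameters
        return torch.tanh(self.linear(x) + h)

scripted = torch.jit.script(MyCell())       # a ScriptModule: a Module in the sense above
print(list(scripted.named_parameters()))    # the Parameters of the Module
print(scripted.forward.graph)               # the Graph behind the `forward` Method
```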

## Parameters ##

[api/module.h](api/module.h)

Modules contain Parameter objects, which simply hold a "slot" where a `Tensor` can be placed. These `Tensors` are accessible by the Methods of the Module or the parent Module.

## Method ##

[api/module.h](api/module.h)

A Method is a piece of TorchScript code that takes a number of arguments and produces an output value. Methods have several subcomponents. A FunctionSchema describes the types and names of the input arguments and return value. A list of `member_inputs` describes which Parameters are accessed by the method (this is blank for pure functions). A `Graph` object describes the actual code inside the method. The Method also maintains a GraphExecutor which is used to actually execute the `Graph` that defines the method.

The `Graph` inside the Method is a pure function. The Parameters used by the Method are added as additional inputs to this graph before it is run. This allows the GraphExecutor to treat method inputs and method parameters the same for the purposes of optimization and execution, simplifying the process for executing programs.

Methods also contain helper functions for inserting calls to the Method from other Method objects.

## FunctionSchema ##

[aten/src/ATen/core/function_schema.h](../../../aten/src/ATen/core/function_schema.h)

Each Method has a FunctionSchema that describes the Types of the arguments and return values of a function. Operators (builtin primitives that are called by the Interpreter) also have FunctionSchema. FunctionSchema are analogous to a function _declaration_ in C++. They describe how to call the function but do not provide an implementation.
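
As a rough illustration from the Python side, a scripted function carries a FunctionSchema describing its signature; this sketch assumes the `schema` accessor exposed on scripted functions in recent PyTorch builds:

```python
import torch

@torch.jit.script
def scale(x: torch.Tensor, factor: float = 2.0) -> torch.Tensor:
    return x * factor

# The FunctionSchema records argument names, types, defaults, and the return
# type, but not the implementation (it is a declaration, not a definition).
print(scale.schema)  # e.g. scale(Tensor x, float factor=2.) -> Tensor
```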

## Graph ##

[ir.h](ir/ir.h)

Graphs are the root of the intermediate representation (IR) used to define the implementation of TorchScript functions. If you are familiar with [LLVM](https://llvm.org), they are analogous to an `llvm::Function` object. A `Graph` is composed of `Nodes`, `Blocks`, and `Values`. `Nodes` are instructions (e.g. do a matrix multiply). `Nodes` are organized into `Blocks` of sequentially executed `Nodes`. Each `Node` produces a list of output `Values`, and also consumes a list of input `Values`. As an example, a user may write the following TorchScript code:

```python
@torch.jit.script
def f(a, b):
  c = a + b
  d = c * c
  e = torch.tanh(d * c)
  return d + (e + e)
```

The frontend, described later in this document, will turn this into a `Graph`:
```
graph(%0 : Double(2),
      %1 : Double(2)):
  %2 : int = prim::Constant[value=1]()
  %3 : Double(2) = aten::add(%0, %1, %2)
  %4 : Double(2) = aten::mul(%3, %3)
  %5 : Double(2) = aten::mul(%4, %3)
  %6 : Double(2) = aten::tanh(%5)
  %7 : Double(2) = aten::add(%6, %6, %2)
  %8 : Double(2) = aten::add(%5, %7, %2)
  return (%8)
```

This is the canonical textual representation of the IR. You should be able to easily find almost all of the elements discussed above:
- `graph` is the `Graph`
- `%x` are `Value`s
- `%x : Double(2)` is a type annotation of `Value` `%x` (see below for a list of supported types).
- `%x : T1, %y : T2 = namespace::name(%z, %w)` is a `Node` which represents the `namespace::name` operator (this name is usually referred to as the `Node`s _kind_). It takes `%z` and `%w` `Value`s as inputs, and returns two outputs (`%x`, `%y`) of types `T1` and `T2` respectively.

Finally, nodes can have extra pieces of information assigned to them, which are called _attributes_. You can see attributes used in the `prim::Constant` node, which returns the `value` attribute when it's called. There's a fixed list of types you can attach:
- `int64_t`
- `double`
- `Tensor`
- `Graph` (useful for e.g. slicing subgraphs that are meant to be fused)
- `std::string`
- and lists of them (not nested)

Graphs in the JIT are in static single-assignment (SSA) form, meaning that each `Value` has precisely one defining `Node` that can be looked up directly from the `Value` (`Node* n = v.node()`).

**Ownership Model** `Blocks`, `Nodes`, and `Values` are _owned_ by the `Graph` they appear in and may only appear in a single `Graph`. This is enforced by assertions. Creation and deletion of `Nodes` is done via `Graph` objects (e.g. `Graph::create`). Creation and deletion of `Blocks` and `Values` is done via `Node` objects (e.g. `Node::addOutput`, `Node::addBlock`). The code also enforces certain consistency properties. For instance, `Node::destroy` removes a `Node`, but it is only valid to call this function if the `Values` produced by this `Node` are no longer used, which can be accomplished using other functions such as `Value::replaceAllUsesWith`.

Because `Graph` owns all its `Nodes`, `Values`, and `Blocks`, these values are always passed around by raw pointer. Generally developers should not write code that holds `Value`, `Node`, or `Block` objects indefinitely without also holding a `shared_ptr` to their owning `Graph`.
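
A quick way to poke at this structure is through the Python IR bindings, which mirror the C++ accessors described above; a small sketch:

```python
import torch

@torch.jit.script
def f(a, b):
    c = a + b
    d = c * c
    return torch.tanh(d * c)

g = f.graph                        # the Graph, owned via shared_ptr
print(g)                           # canonical textual IR, as shown above
for node in g.nodes():             # Nodes of the top-level Block, in order
    ins = [v.debugName() for v in node.inputs()]
    outs = [v.debugName() for v in node.outputs()]
    print(node.kind(), ins, "->", outs)
```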

## Node ##

[ir.h](ir/ir.h)

A `Node` represents a single built-in operator such as a matrix multiply or a convolution. `NodeKind Node::kind()` identifies the operator the `Node` represents. Different operators (e.g. conv vs matrix-multiply) are represented by different kinds rather than via subclassing of `Node`, as one would find in LLVM. A `NodeKind` is a `Symbol` object, which is just an [interned string](https://en.wikipedia.org/wiki/String_interning) inside some namespace. Symbols can be created from strings, e.g. through `Symbol::fromQualString("aten::add")`, so there is not a closed set of `NodeKind` values that a `Node` might have. This design was chosen to allow the open registration of new operators and user-defined operators.

> *Code in the JIT should always assume the universe of valid `Node` kinds is open and subject to be expanded.*

This reflects the reality of the PyTorch operator library where there are already several hundred valid operators.

Nodes produce output `Values` and take input `Values` as arguments. For instance, a matrix-multiply will take two input `Tensors` and produce one output `Tensor`. `Nodes` can produce multiple outputs. For instance `prim::TupleUnpack` splits a tuple into its components, so it has a number of outputs equal to the number of members of the tuple. Though `Nodes` may have multiple outputs, the number of outputs is _statically known_ for each `Node`. Operations which may produce a dynamic number of results, e.g. splitting a `Tensor` into chunks of size 2, will be represented as an operator that returns a list object.

Because `Nodes` are not subclassed per-operator, it is very easy to construct invalid `Nodes`, e.g. by forgetting an input or an output, or by passing `Values` of the wrong Type. To help avoid this, `Graph` provides the method `Graph::insert` for constructing `Nodes` that guarantees `Nodes` have the correct setup. This method uses the database of registered Operators and their FunctionSchema to construct `Nodes` using that schema.


PyTorch IR supports function overloading, which means that a single NodeKind may correspond to multiple operators. For example, the kind `aten::add` has the following overloads (`Scalar` means `float` or `int` in this case):
- `aten::add(Tensor self, Tensor other) -> Tensor`
- `aten::add(Tensor self, Scalar other) -> Tensor`
- `aten::add(int self, int other) -> int`
- `aten::add(float self, float other) -> float`

For `Nodes` representing built-in Operators, the method `Node::schema` can also look up the FunctionSchema registered for that Operator.


Each overload corresponds to a different `FunctionSchema` object. A `Node` can be queried for its schema using the `schema()` method (it will check the argument types, and will try to match one of the options for its `kind()`).

Note that the chosen overload is not shown in any way in the textual output. If you're unsure which function a node resolves to, you might need to check the type annotations of its input values.
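
A small sketch showing two overloads of the same kind appearing in one graph (the function `f` is made up for illustration):

```python
import torch

@torch.jit.script
def f(x: torch.Tensor, y: torch.Tensor, a: int, b: int):
    return x + y, a + b

# Both additions print as aten::add, but they resolve to different overloads:
# one taking Tensors (with an extra alpha argument) and one taking ints.
# The distinction is only visible in the type annotations of the inputs/outputs.
print(f.graph)
```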


Each node also has a set of attributes which are named integers, strings, floats, `Tensors`, subgraphs, or lists of these types. These are used by special primitive operators to encode additional data in the `Node`. For instance `prim::Constant` defines a compile-time constant value. For `Tensor` constants, it will have a single `Tensor` attribute with the name `attr::value` which contains the value of the constant.

Attributes are _rarely used_. Operators like convolution or matrix-multiply have no attributes and take their arguments through the input list. This includes things that might be typically thought of as constants, like the stride of the convolution. In PyTorch, any of this information is potentially a dynamic property of the program so `Nodes` are always encoded in a way that allows these values to be dynamically determined. However, we recognize that many inputs are almost always constants, so we make it easy to quickly check if an input is constant and get its value with `c10::optional<IValue> Node::get(Symbol name)`, which returns an `IValue` (a concrete value for the input) in the case the node is constant and `nullopt` otherwise.
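
For example, a hedged sketch of reading constant values back out of a graph from Python, assuming the `Value.toIValue` binding available in recent builds:

```python
import torch

@torch.jit.script
def f(x: torch.Tensor):
    return x * 3

for n in f.graph.nodes():
    if n.kind() == "prim::Constant":
        # toIValue() materializes the constant as a concrete Python value
        # (here, the integer 3 attached to the prim::Constant node).
        print(n, "->", n.output().toIValue())
```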

## Block ##

[ir.h](ir/ir.h)

Nodes are organized into sequentially executed lists inside a `Block`. A `Node` is a member of precisely one `Block`. The `Graph` itself has a top-level `Graph::block()`, and control-flow nodes (`prim::If` and `prim::Loop`) also have sub-blocks. While it is possible to design a `Graph` representation that does not have a sequential order for nodes (i.e. a sea-of-nodes representation), we find it is much easier to debug and understand `Blocks` when there is a specific canonical order for all of the nodes. This does not preclude optimization passes from changing the order when it would improve performance, and the interpreter is allowed to execute `Nodes` out-of-order if the re-ordering preserves the semantics (much like an out-of-order processor does). Having the ordering ensures that graphs can always be easily printed, and that we can easily step through the execution of a graph.

Values are Block-scoped. A `Value` is in scope for the remainder of the `Block` it is defined in, including in the sub-blocks of any `Node` defined after it. `Values` go out of scope at the end of the block in which they are defined.

When `Nodes` are inserted into a `Graph`, they are inserted at a special "insertion point" that is part of the state of the `Graph`. On construction, this points to the end of the `Graph`.

Each `Block` has two dummy nodes that are not included in the list of nodes in the `Block`. The `prim::Param` node represents the inputs to the `Block` and does not have a `prev()` or `next()` `Node`. The `prim::Return` `Node` represents the outputs of a `Block`.
The list of `Nodes` in a `Block` is implemented as a circular linked list with the `prim::Return` `Node` serving as the beginning/end sentinel. Inserting and deleting at arbitrary places is efficient. Developers may also encounter implementations inside of IR objects that use this fact (e.g. appending to a `Block` is equivalent to putting the node before the `prim::Return` Node).

Iterators for the `Block::nodes()` list are invalidated when the current `Node` they point to is moved or deleted. Otherwise iterators remain valid.

Blocks also contain a list of input and output values. The meaning of these values depends on where the `Block` is used. For the Graph's top-level `Block`, these are inputs and outputs to the `Graph`, and line up with the FunctionSchema associated with the Method that owns the `Graph`.

**Control-flow** is represented using sub-Blocks rather than a control-flow graph representation. A `prim::If` has one `Block` for the true branch and one `Block` for the else. A `prim::Loop` has a block for the loop body (there is no condition block; instead the end of the loop body computes whether to re-enter the loop body). This representation ensures we have structured control-flow. This restriction makes many optimizations easier and holds for the vast majority of networks. A `Node` can look up what `Block` it is in, and a `Block` can look up its parent (either the `Node` that has it as a sub-Block, or `nullptr` for the main Block).

### If ###
If-statement (`prim::If`) `Blocks` have no inputs, and the outputs are the new values of variables in the outer block whose values were altered in the if-statement.
Example IR for an if-statement looks like:
```
%y_1, ..., %y_r = prim::If(%condition)
  block0():  # TRUE BRANCH, never takes arguments, has to return r outputs
    %t_1, ..., %t_k = some::node(%a_value_from_outer_block)
    -> (%t_1, ..., %t_r)
  block1():  # FALSE BRANCH, never takes arguments, has to return r outputs
    %f_1, ..., %f_m = some::node(%a_value_from_outer_block)
    -> (%f_1, ..., %f_r)
```

Values corresponding to `%y_1, ..., %y_r` will become either `%t_1, ..., %t_r`, or `%f_1, ..., %f_r` depending on the value of `%condition` at runtime.

Here's an example translation of a Python program and its corresponding IR:

```python
def f(a, b, c):
    d = a + b
    if c:
        e = d + d
    else:
        e = b + d
    return e
```

```
graph(%a : Dynamic,
      %b : Dynamic,
      %c : Dynamic):
  %2 : int = prim::Constant[value=1]()
  %3 : Dynamic = aten::add(%a, %b, %2)
  %5 : Dynamic = prim::If(%c)
    block0():
      %6 : int = prim::Constant[value=1]()
      %7 : Dynamic = aten::add(%3, %3, %6)
      -> (%7)
    block1():
      %8 : int = prim::Constant[value=1]()
      %9 : Dynamic = aten::add(%b, %3, %8)
      -> (%9)
  return (%5)
```

The outputs of the if-statement serve a role similar to that of a Φ (Phi) function node in traditional [SSA](https://en.wikipedia.org/wiki/Static_single_assignment_form) control-flow graphs.

### Loops ###
Loops are implemented with `prim::Loop` which covers both `while` and `for` loops. A valid instantiation of this node always looks like this:
```
%y_1, ..., %y_r = prim::Loop(%max_trip_count, %initial_condition, %x_1, ..., %x_r)
  block0(%i, %a_1, ..., %a_r):
    %b_1, ..., %b_m = some::node(%a_value_from_outer_block, %a_1)
    %iter_condition = some::other_node(%a_2)
    -> (%iter_condition, %b_1, ..., %b_r)
```

The simplest way to explain the semantics is to consider this Python-like pseudo-code:
```python
y_1, ..., y_r = x_1, ..., x_r
condition = initial_condition
i = 0
while condition and i < max_trip_count:
    a_1, ..., a_r = y_1, ..., y_r

    ############################################################
    # Actual body of the loop
    b_1, ..., b_m = some::node(a_value_from_outside_of_the_loop, a_1)
    iter_condition = some::node(a_2)
    ############################################################

    y_1, ..., y_r = b_1, ..., b_r
    condition = iter_condition
    i += 1
```

> Note that translations of `for` loops simply pass in a constant `true` for both `%initial_condition` and `%iter_condition`. Translations of `while` loops set `%max_trip_count` to the largest value of `int64_t` and do not use `%i`. Those patterns are recognized by our interpreter and optimized accordingly (e.g. `while` loops don't maintain the loop counter).

For example, this program:

```python
def f(x):
    z = x
    for i in range(x.size(0)):
        z = z * z
    return z
```

can be translated as:

```
graph(%z.1 : Dynamic):
  %3 : bool = prim::Constant[value=1]()
  %1 : int = prim::Constant[value=0]()
  %2 : int = aten::size(%z.1, %1)
  %z : Dynamic = prim::Loop(%2, %3, %z.1)
    block0(%i : int, %5 : Dynamic):
      %z.2 : Dynamic = aten::mul(%5, %5)
      -> (%3, %z.2)
  return (%z)
```

### With ###
With-statements are represented in two different ways. For most of the compilation and optimization process, they are represented as a pair of `prim::Enter` and `prim::Exit` nodes that wrap the nodes corresponding to the body of the with-statement. However, with-statements are temporarily represented for the duration of the [`exit_transform` pass](frontend/exit_transforms.cpp) using a block-based representation in which a `prim::With` node is inserted after the `prim::Exit` node, all of the nodes between the `prim::Enter` and `prim::Exit` are moved into the first block of the `prim::With`, and the `prim::Exit` is moved into the second block of the `prim::With`. For example, this program:

```
with c as increment:
  y = x + increment
```

can be translated as:

```
%2 : int = prim::Constant[value=1]()
%increment.1 : int = prim::Enter(%c.1)
%y.1 : Tensor = aten::add(%x.1, %increment.1, %2)
%11 : Tensor = prim::Exit(%c.1)
```

and will temporarily be transformed to:

```
%increment.1 : int = prim::Enter(%c.1)
= prim::With()
  block0():
    %y.1 : Tensor = aten::add(%x.1, %increment.1, %4)
    -> ()
  block1():
    %11 : Tensor = prim::Exit(%c.1)
    -> ()
```

for the duration of the `exit_transform` pass.

## Value ##

[ir.h](ir/ir.h)

A `Value` represents data flowing through the operations in the program, e.g. the output of a matrix-multiply op. `Value` objects are always defined by a single `Node` (`v.node()`) due to static single-assignment form. For inputs to a Block/Graph, this node is a special `prim::Param` node that does not appear anywhere in the block's list of nodes. `Value` objects also have a Type (e.g. is it a Tensor? a list? a tuple?) that provides a static guarantee that its value will be of that Type.

A `Value` object has methods that return its definition (`v.node()`) and all of its uses (`v.uses()`). Each Use has a pointer to the `Node` whose input list includes the value. Be careful when iterating over `v.uses()` while changing how `v` is used because each change to `v` will invalidate the `v.uses()` iterator.

Values are abstract representations of data in the program. When executing, the actual `Tensors`, lists, tuples, etc. are stored in `IValues` (_interpreter_ values), which are tagged unions of all possible value types in TorchScript. In retrospect the name `Value` is a bit confusing because it seems like it should be the tagged union, but it originally came from analogy to `llvm::Value`, which serves the same purpose as `jit::Value`.


## Type ##

[aten/src/ATen/core/jit_type.h](/aten/src/ATen/core/jit_type.h)

TorchScript, unlike Python, is statically typed, so every `Value` has a Type associated with it, and every FunctionSchema has a list of argument types and a return type for a function. Type is the base class of a hierarchy of C++ objects that represent the built-in types of TorchScript. Types provide methods such as `Type::isSubtypeOf` that describe the typing relationships. Common types are:

* TensorType - a `Tensor` with optionally refined information. It may know its device, dtype, requires_grad state, and number of dimensions.
  If it does know the number of dimensions, it may know the size of a particular dimension.
* Tuples - e.g. `Tuple[Tensor, Int]`. Each member of the tuple is statically typed and the length of the tuple is statically known.
* List[T] - e.g. `List[Tensor]`. Mutable lists of a particular type.
* Optional[T] - e.g. `Optional[Tensor]`. Either a value of type T or None.
* Dict[K, V] - e.g. `Dict[String, Tensor]`. Dictionaries.

If type S is a subtype of P, then we can substitute an `IValue` that has type S anywhere something of type P is expected. This means that all subtyping relationships also require the representation of the `IValue` for subtypes to be compatible with the representation for the base type.
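
As a small illustration of subtyping in practice, `float` is a subtype of `Optional[float]`, and an `is None` check refines the type inside a branch (the function here is made up for illustration):

```python
from typing import List, Optional

import torch

@torch.jit.script
def scale_all(xs: List[torch.Tensor], factor: Optional[float]) -> List[torch.Tensor]:
    # float <: Optional[float], so callers may pass either a plain float or None.
    if factor is None:
        factor = 1.0          # after this branch, `factor` is refined to float
    return [x * factor for x in xs]

print(scale_all([torch.ones(2)], None))
print(scale_all([torch.ones(2)], 0.5))
```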


# Generating Programs #

JIT programs are created using either the tracing frontend (`torch.jit.trace`) or the scripting frontend (`torch.jit.script`). In both cases, the result of these frontends is a complete Module that contains all the code in Methods, and all the model weights in the Parameters of the Module. However, each frontend goes through a different pathway for generating those Modules.

## Tracer ##


[tracer.h](frontend/tracer.h)
[tracer_state.h](frontend/tracer_state.h)

The tracer produces graphs by recording what actual operations are done on `Tensors`.
The entry point from Python into C++ for tracing using `torch.jit.trace` is `_create_method_from_trace`.

A thread local instance of the TracingState object maintains a mapping between actual data being computed during the trace (e.g. Tensors) stored in `IValues`, and the abstract `Value` in the `Graph` that would compute each value. The functions `void setValueTrace(const IValue&, Value*)` and `Value* getValueTrace(const IValue&)` are used by the tracer to maintain this mapping.

An initial `IValue` to `Value` mapping is set up between the inputs to the function being traced and symbolic `Value` inputs to the `Graph` being constructed. If we are tracing a `torch.nn.Module`, the tracer also adds Parameters and sub-Modules to the Module being constructed that correspond to the Python `torch.nn.Module` being traced. Mappings for these values are also added so that uses of the Parameters in the trace will create uses of the Parameters in the `Graph`.

As the trace runs, individual operators create `Nodes` in the `Graph` being traced to record what happens. This code is currently generated per operator in [tools/autograd/gen_variable_type.py](/tools/autograd/gen_variable_type.py). It results in code that looks like the following:

```cpp
torch::jit::Node* node = nullptr;
std::shared_ptr<jit::tracer::TracingState> tracer_state;
if (jit::tracer::isTracing()) {
  tracer_state = jit::tracer::getTracingState();
  at::Symbol op_name;
  op_name = jit::Symbol::fromQualString("aten::__ilshift__");
  node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
  jit::tracer::recordSourceLocation(node);
  jit::tracer::addInputs(node, "self", self);
  jit::tracer::addInputs(node, "other", other);
  tracer_state->graph->insertNode(node);

  jit::tracer::setTracingState(nullptr);
}
TypeDefault::__ilshift__(self, other);
if (tracer_state) {
  jit::tracer::setTracingState(std::move(tracer_state));
  jit::tracer::addOutput(node, self);
}
```

The functions `addInputs` and `addOutput` are overloaded to handle the different data types that operators use.

`setValueTrace`/`getValueTrace` only work on `Tensors` and `Futures`. Other types are not natively traced. Instead, aggregates like tuples or lists are often flattened into `Tensors` at the end of a trace and explicitly constructed from individual `Tensors` at the beginning of the trace.

The tracer has special behavior when tracing calls to other TorchScript functions. This behavior is implemented in the GraphExecutor right before a `Graph` is about to be run. If tracing is enabled while running the graph, the GraphExecutor will disable tracing, run the graph as normal, and then inline the `Graph` into the trace. It then hooks up the `IValues` computed by running the `Graph` to the output `Values` of the inlined graph.

> *When a trace calls a TorchScript function, that function is preserved as is, meaning that control-flow is preserved.* This makes it possible to work around tracing issues by generating the subset of the program that cannot be traced using the script frontend and having the trace invoke it.
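
A minimal sketch of this workaround, mixing a scripted function into a trace (the model and function here are made up for illustration):

```python
import torch

@torch.jit.script
def clamp_positive(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent control flow: the tracer alone would bake in one branch.
    if bool(x.sum() < 0):
        return torch.zeros_like(x)
    return x

def model(x):
    y = x * 2                  # recorded by the tracer as straight-line ops
    return clamp_positive(y)   # the scripted call keeps its control flow

traced = torch.jit.trace(model, torch.randn(3))
print(traced.graph)  # the call to clamp_positive preserves its prim::If
```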

The resulting `Graph` created by tracing is installed as the `forward` method of the Module being created. A Module is produced regardless of whether the thing being traced was a function or a `torch.nn.Module`. In the function case, the Module produced will simply have a single `forward` function, no Parameters, and no sub-Modules.

## Script ##

The script frontend directly converts Python syntax into Modules. Like many compilers, this happens in two phases. First, we generate an abstract syntax tree (AST), which is constructed out of Tree objects. The IR emitter then does semantic analysis on the Tree and lowers it into a Module. We can generate Trees in two ways: (1) using frontend.py, which takes the Python AST and transliterates it into Tree objects, or (2) via the Lexer and Parser which parse Python syntax directly. The Lexer+Parser path may seem redundant but it is crucially important. We need to define builtin functions ([frontend/builtin_functions.cpp](frontend/builtin_functions.cpp)) when Python is not linked because we allow users to generate TorchScript programs directly from strings containing Python source code ([api/include/torch/jit.h](/torch/csrc/api/include/torch/jit.h)) without linking a full Python implementation (e.g. CPython). We also use this Python syntax as the serialization format for TorchScript, since it allows us to make changes to our IR without breaking backward compatibility. Furthermore, the Lexer is reused to implement the FunctionSchema parser, which turns FunctionSchema declarations from strings into FunctionSchema objects.
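
From Python, the string pathway can be exercised directly with `torch.jit.CompilationUnit`; a small sketch (the function body is made up for illustration):

```python
import torch

# Compile TorchScript directly from a string of Python source. The same
# Lexer/Parser path is what C++ callers use when no Python is available.
cu = torch.jit.CompilationUnit("""
def double_then_add(x, y):
    return 2 * x + y
""")

print(cu.double_then_add(torch.ones(2), torch.ones(2)))
print(cu.double_then_add.graph)
```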

The following sections look at each of the stages of the script frontend in detail.

## Tree ##

[frontend/tree.h](frontend/tree.h)

Our frontends produce ASTs in the form of Tree objects. Trees are similar to [s-expressions](https://en.wikipedia.org/wiki/S-expression). Leaves (i.e. Atoms) are always strings. Compound trees have a `kind` (e.g. `TK_CONST` or `TK_IDENT` defined in [lexer.h](frontend/lexer.h)) and a list of sub-trees. For instance, the Tree for `z.sigmoid() - (x + y)` is:

```
(-
  (apply
    (.
      (variable (ident z))
      (ident sigmoid))
    (list)
    (list))
  (+
    (variable (ident x))
    (variable (ident y))))
```

This is printed in s-expression style with `(kind ...)` representing compound trees and `string_value` representing strings.

We provide utilities to construct, traverse, and print ASTs without a lot of complicated visitor infrastructure and inheritance.

Each Tree also has a mandatory SourceRange object that describes the range of text that it came from. These are used for error reporting in the rest of the code.

## Tree Views ##

[frontend/tree_views.h](frontend/tree_views.h)

Trees are easy to construct, visualize and traverse, but extracting information from a large compound tree like that of a function definition is unwieldy since it requires numeric indexing. Tree _Views_ are a small layer on top of a tree that make it possible to create and de-structure trees of particular kinds. For example, here is the tree view for the apply node which provides named accessors for its subtrees: the function being called, the inputs, and the attributes (i.e. kwargs):

```cpp
struct Apply : public Expr {
  Expr callee() const {
    return Expr(subtree(0));
  }
  List<Expr> inputs() const {
    return List<Expr>(subtree(1));
  }
  List<Attribute> attributes() const {
    return List<Attribute>(subtree(2));
  }
  ...
};
```

The typical way to traverse a tree is to `switch` on the kind and then construct the appropriate Tree view:

```cpp
switch (tree.kind()) {
  case TK_VAR: {
    auto var = Var(tree); // construct tree view
    return environment_stack->getSugaredVar(var.name());
  }
  case '.': {
    auto select = Select(tree); // construct tree view
    auto sv = emitSugaredExpr(select.value(), 1);
    return sv->attr(select.range(), method, select.selector().name());
  }
  case TK_APPLY: {
    auto apply = Apply(tree); // construct tree view
    return emitApplyExpr(apply, n_binders);
  } break;
  ...
}
```

## frontend.py ##

[torch/jit/frontend.py](../../jit/frontend.py)

One way we construct Tree objects is directly from Python ASTs. This logic is contained inside frontend.py and is intentionally very minimal.

> *We endeavor to keep most of the JIT code written in C++, because most of the JIT functionality still needs to work without Python installed.*

This code simply constructs the Tree, filtering out the AST nodes of Python that we do not support.

## Lexer ##

[frontend/lexer.h](frontend/lexer.h)

When loading TorchScript code directly from a string, we use a standard Lexer+Parser combo. The Lexer takes an initial string and then exposes a stateful interface for walking the Tokens of the string, providing a standard set of functions:

* `next()` advances the lexer, returning the current token
* `cur()` provides the current token
* `lookahead()` provides the token coming after the current token
* `nextIf(int token_kind)` advances the lexer if the current token matches the given token kind.

Similar to Python, the Lexer handles the whitespace-sensitive nature of Python blocks. The Tokens `TK_INDENT`, `TK_DEDENT`, and `TK_NEWLINE` are injected into the token stream when code first becomes indented, when it dedents, and at the end of a statement. For instance, for this input:

```
if
  .
  .
```

We would get a token stream `TK_IF TK_NEWLINE TK_INDENT . TK_NEWLINE . TK_NEWLINE TK_DEDENT`. Unmatched opening brackets disable the injection of these tokens. The result is that the Parser can simply treat `TK_INDENT`, `TK_DEDENT` and `TK_NEWLINE` like C's `{`, `}`, and `;`.

## Tokens ##

[frontend/lexer.h](frontend/lexer.h)

Tokens are either keywords (`def`), operators (`+`), literals (`3.4`), or identifiers (`foo`). A `token_kind` integer identifies what it is and is the exact same type as the `kind` of a Tree. For single-character Tokens (e.g. `+`), the kind is the same as the character, enabling statements like:

```cpp
if (lexer.nextIf('+')) {
  // handle + ...
}
```

Multi-character token kinds are defined in a list, `TC_FORALL_TOKEN_KINDS`. Tokens also have a `text()` field that records the actual string producing the token and is used by identifiers and literals to construct the actual values (e.g. the numeric value of a floating point literal).

## Parser ##

[frontend/parser.h](frontend/parser.h)

The Parser uses the Lexer to build the AST for function definitions. `parseFunction` is the entrypoint for parsing a single `def ...` and will return a `Def` tree view.

The Parser is written as a [top-down precedence parser](https://eli.thegreenplace.net/2010/01/02/top-down-operator-precedence-parsing), or "Pratt" parser. These are simpler and easier to understand than typical parser generators, while still being flexible enough to parse programming languages. For the most part parsing is done by recursive descent. To resolve operator precedence issues, the function to parse an expression is augmented with a precedence _p_ such that calling the function means _parse an expression whose operators all have precedence higher than p_.

## IR Emitter ##

[frontend/ir_emitter.h](frontend/ir_emitter.h)

The file ir_emitter.cpp translates Trees into Modules. The main entrypoint is `defineMethodsInModule`, which takes a list of Def Tree Views representing function definitions and adds them as Methods to the module. During the lowering process, _semantic checking_ occurs. The IR emitter checks that all used variables are defined (sometimes called scope checking), and that all values have compatible types (type-checking). During this process it also emits the graph nodes corresponding to each statement in the Tree and generates a FunctionSchema for the whole definition.
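
A sketch of the kind of error scope checking produces; the exact message and exception type may vary across versions:

```python
import torch

def maybe_undefined(x):
    if bool(x.sum() > 0):
        y = x + 1
    return y  # `y` is only defined on one path through the function

try:
    torch.jit.script(maybe_undefined)
except RuntimeError as e:
    # The IR emitter rejects this at compile time rather than at run time.
    print("script-time error:", type(e).__name__)
```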

A few helper objects exist in the lowering process. SugaredValues are special values that represent objects that can appear during compilation but that are not first class values. For instance, in TorchScript methods `self` refers to the module, and `self.weight` refers to a Parameter of the module. Neither is a first-class Type, and neither has a corresponding `Value` in a graph. Resolver objects are `std::function`s that resolve externally-defined variables to SugaredValues. For instance, the identifier `torch`, which contains most of our built-in ops, is looked up through Resolver objects which interact with the Python state of the program.

The Environment tracks the mapping between variable names and the SugaredValues they refer to.

## SugaredValue ##

[frontend/sugared_value.h](frontend/sugared_value.h)

SugaredValues are how the IR emitter represents non-first class values during `Graph` creation. These values are things like the Module or a Python function call that do not have corresponding `Value` objects in the `Graph`. The IR emitter _desugars_ the SugaredValue objects to instructions in the graph based on how they are used. The SugaredValue class has a number of abstract methods on it such as `attr` or `call`. Consider the expression `self.foo`. For methods, `self` will resolve to a special SugaredValue subclass, ModuleValue. When the emitter sees `self.foo`, it will then call the ModuleValue function `sv.attr("foo")`, asking the ModuleValue how it should desugar itself when the attribute `"foo"` is accessed. If `foo` is a parameter, it would then ensure that the parameter was added to the Method being compiled, and return a `SimpleValue` sugared value that contains the `Value` object representing the parameter as an input. If `foo` were a sub-Module then it would return another ModuleValue. The method `call` is invoked when the emitter sees the value used as a function call.

SugaredValues are also how we interact with the Python runtime during the compilation process. For instance, `math.pi` is resolved to 3.1415... by first resolving `math` to a SugaredValue representing accesses to Python modules (PythonModuleValue) whose `attr` function turns Python numbers into `prim::Constant` `Nodes` in the graph.

Finally, normal `Values` are also represented by the SimpleValue SugaredValue in places where it is valid that either a SugaredValue or a normal `Value` will appear.

## Resolver ##

[frontend/resolver.h](frontend/resolver.h)

Any undefined variable during compilation is resolved with a call to an externally-provided Resolver. When called from Python (e.g. `torch.jit.script`) this resolver interacts with the Python runtime via pybind11 to resolve symbols like `torch` and `math` to their Python equivalents.

*The combination of SugaredValue and Resolver decouples the implementation of the IR emitter from the pybind11 Python bindings that enable its interaction with the Python state.*

This makes it possible to use most of the IR emitter functionality when Python is not present.

## Environment ##

[frontend/ir_emitter.cpp](frontend/ir_emitter.cpp)

The Environment object tracks the assignment of variable names during compilation. It is local to the IR emitter file. A stack of environments exists, with a new environment being created for sub-blocks introduced by control flow. The Environment keeps two tables: one for values which are not first class in the type system (SugaredValues), and a type table for values which are. When first class values are set, we emit a `prim::Store`, and when they are referenced we emit a `prim::Load`. SugaredValues are not re-assignable.

## Conversion To SSA ##

[frontend/convert_to_ssa.cpp](frontend/convert_to_ssa.cpp)

As explained in the [Block](#block) section, the IR is represented as structured control flow composed of ifs & loops. This makes it easier to optimize and lower to other compilers which do not support unstructured control flow. We lower Python control flow (break, continue, return) to this simplified form. We close over any variables in the environment, so we are able to convert all writes and reads of the environment directly to SSA form.

Conversion to SSA works in multiple parts.
- First, we add loads and stores to control flow operators (ifs & loops).
- Then we erase break & continue statements from the graph and replace them with `prim::LoopContinuation(%loop_continue_condition, %loop_carried_vars)`. Break statements have the continue condition set to false, and continue statements inline the loop condition. `%loop_carried_vars` are the loop-carried variables of the innermost loop that contains the break or continue statement; they are added by inserting `prim::Load` calls at the location of the statement.
- Then we inline the loop condition into the graph loops.
- Next we erase loads and stores, removing all stores and replacing all loads with whatever the in-scope value of the variable name is.
- Finally, we remove `prim::LoopContinuation`s and `prim::ReturnStmt`s in the exit_transform pass.

## Exit Transform ##

[frontend/exit_transforms.cpp](frontend/exit_transforms.cpp)

This pass takes in a graph in which `prim::LoopContinuation`s and `prim::ReturnStmt`s exist and erases them, correctly setting block outputs. `prim::LoopContinuation(*vals)` means that the values are targeting the most recent loop block. `prim::ReturnStmt(*vals)` means that the values are targeting the most recent `Closure` or `Graph` `Block`.

If a block has an exit node, no further instructions will be executed until the exit target has been reached. If we encounter a node that contains nested blocks that may have hit an exit node, such as an if statement that exits in one block and does not exit in the other, we use a boolean value to indicate whether the exit has been hit or not. Then, we conditionalize further execution.

Python example:

```python
while i < 5:
  if i == 3:
    i += 1
    continue
  i += 2
```

-> transforms to

```python
continue_loop = i < 5
while continue_loop:
  if i == 3:
    i = i + 1
    continue_loop = i < 5
    did_exit = True
  if did_exit:
    pass
  else:
    i = i + 2
    continue_loop = i < 5
```

The pass also keeps track of nodes or blocks that will always throw Exceptions so that we do not unnecessarily conditionalize execution. In the following example, we can treat the if statement as always Returning and remove the `print` statement.

```python
if i < 0:
  raise Exception("Negative input")
else:
  return math.sqrt(i)
print(i)  # unreachable code
```

In the above example, the if statement will have one output: `math.sqrt(i)` on the false branch, and `prim::Uninitialized` in the true branch. `prim::Uninitialized` is inserted by the compiler when it can prove the value will never be used. It can be introduced by exceptions, breaks, continues, and returns.

We initially considered doing the Transform pass before Loads and Stores were removed from the graph. However, this breaks when a loop-carried variable is captured in a break or continue and then is refined in the rest of the loop body. In the below example, at the point of the `continue`, `x` has type `Optional[int]` but is refined to `int` after the continue statement.

```python
...
if cond:
  if i < 3:
    x = torch.jit.annotate(Optional[int], None)
    continue
  x = 1
else:
  x = 2
print(x)
```
If we were to rearrange the graph before loads & stores were removed:

```python
if cond:
  if i < 3:
    x = torch.jit.annotate(Optional[int], None)
    did_continue = True
    continue
  else:
    did_continue = False
  if not did_continue:
    x = 1
else:
  x = 2
if not did_continue:
  print(x)
```
The type of `x` at the print statement would be `Optional[int]`, which no longer matches the `int` type it has in the original program.

## Python-Compiler Interaction ##

[python/script_init.cpp](python/script_init.cpp)

A set of special SugaredValues are used to translate between objects in the Python environment and `Values` in the `Graph` during the compilation process. The entry-point for this behavior is `toSugaredValue(py::object obj, ...)`, which takes a pybind11 Python value and figures out how to turn it into an appropriate SugaredValue. SugaredValues exist to represent Python functions, Python modules, and `ScriptModule` objects.


# Executing Programs #

TorchScript is executed using an interpreter attached to a JIT-optimizer and compiler. The entry-point for execution is the GraphExecutor object that is created on demand inside a Method when the method is first called. This section first goes over the semantics of graphs, i.e. what it means to execute a graph, and then details how the implementation works.


## Evaluation Semantics ##

TorchScript programs implement a very small subset of Python that is necessary to run models.

TorchScript includes immutable value types:
* `int`
* `float`
* `Tuple[T0, T1, ...]`

As well as mutable reference types:
* `Tensor`
* `List[T]`
* `Dict[K, V]`

A value of a reference type points to an underlying memory location where the data for the reference type is stored, and variable assignment for a reference type can cause multiple values to point to the same underlying data. This is similar to Python's class model.
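
A small TorchScript sketch of these reference semantics for lists (the function name is made up for illustration):

```python
from typing import List

import torch

@torch.jit.script
def alias_demo() -> List[int]:
    a = [1, 2]
    b = a          # `b` and `a` refer to the same underlying list
    b.append(3)
    return a       # [1, 2, 3]: the mutation through `b` is visible through `a`

print(alias_demo())
```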
714
715It is important to remember that TorchScript uses these semantics for `Tensors` so not all computation on `Tensor` is pure. Individual `Tensors` may be *views* of the same underlying data. Views are established by special view creating operations, such as indexing into a Tensor:
716
717```python
718t = torch.rand(3, 4)
719t2 =  t[0] # view of one slice of t
720```
721
722Some builtin operators also mutably write to the underlying `Tensor`. In the standard library these operators are always named with a trailing underscore, or take a named `out` `Tensor` where the result is written:
723
724```python
725t2.relu_() # inplace relu operator, note t is modified as well!
726torch.add(t, t, out=t) # update t, without using temporary memory if possible
727```
728
729The combination of reference semantics and mutable operators can be more difficult to optimize, but it gives program writers powerful control of the memory usage of their programs. For instance, DenseNets use a concat operation instead of the addition found in a ResNet. Rather than compute a concat of existing `Tensors`, many implementations use `Tensor` indexing and `out` keywords to avoid allocating addition memory for the activations. Ideally a compiler would always be able to do these optimizations, but in practice new ideas are tried all the time that exist outside what compiler writers expect.
730
731In addition to being mutable, `Tensors` also have a set of dynamically determined properties (i.e. properties that can vary from run to run) this includes:
732
733* dtype - their data type int, float, double, etc.
734* device - where the `Tensor` lives, e.g. the CPU, or CUDA GPU 0
735* rank - the number of dimensions that the `Tensor` has
736* size - the precise size of the Tensor
737* requires_grad - whether the `Tensor` is recording its gradient with autograd
738
739Changes in these properties change how operators on `Tensors` will evaluate and would make certain optimization invalid. For instance, if we have a fuser capable of generating new CUDA kernels but not CPU kernels, it is only valid to fuse operations where the inputs are known to run only on CUDA devices. The GraphExecutor's job is to still enable optimization even when certain combinations of properties prevent optimizations from occurring.
740
741Nodes in a graph are executed *serially* in the order they appear in a block. `Nodes` may be reordered either during optimization or by the interpreter itself if it can be proven that the new order
742is not distinguishable from the original execution order. These semantics are necessary since the combination of mutable `Tensors` and potential aliases between `Tensors` makes it unsafe to perform arbitrary reordering otherwise. However, the AliasInfo object can accurately track how alias propagate through builtin operators so optimization passes can query when certain reorders or optimizations are safe.
743
744We also provide user-accessible parallel execution through the `fork` and `wait` primitives. The `fork` primitive begins execution of `fn` in parallel with the current thread of execution, immediately returning a Future object that will hold the result of the forked function. The `wait` method of the future then causes the invoking thread to wait for the value being computed by `fn`.
745
746```python
747def fn(arg0, arg1, ...):
748  ...
749  return v
750
751fut = torch.jit.fork(fn, arg0, arg1, ...)
752...
753v = torch.jit.wait(fut)
754
755```
756
757Currently, the user is responsible for avoiding races between threads. We encourage users to not write to `Tensors` visible from other threads, and may enforce this more strictly in the future.
758
759Optimization passes that wish to exploit multi-threaded execution may automatically convert serial `Blocks` into parallel execution by inserting extra fork and wait events. This design enables our users to manually specify parallelism while also allowing optimization passes to exploit it when safe and profitable.
760
761
762## IValue ##
763
764[ivalue.h](/aten/src/ATen/core/ivalue.h)
765
766All evaluation involves computation using `IValues`, 16-byte tagged unions that can hold the concrete representation of any type in TorchScript. TorchScript is statically typed, so it would be possible to operate on unboxed primitive types, but the interface between interpreter, built-in ops and user functions would be significantly more complicated. A single tagged union keeps these interfaces simple and since most objects are `Tensors` anyway, the overhead of storing a tag is small compared to the data stored in the `Tensors`.
767
768IValue contains methods to check the type (e.g. `isTensor()`) and to convert to particular to type (e.g. `toTensor()`). We do not publicly expose the type tag and force clients to use the `isX` methods. This enables us to change the underlying implementation of `IValue` later, e.g. to use an 8-byte value with NaN-boxing. Most operators work on a specific static type, so dynamic dispatch on the tag is not frequently required.
769
770## Operation ##
771
772All builtin operators are represented using a stack machine concept. An operator pops its arguments off the top of the stack and pushes its result to the stack:
773
774```cpp
775using Stack = std::vector<IValue>;
776using Operation = std::function<void(Stack*)>;
777
778// schema: example_add(Tensor a, Tensor b) -> Tensor
779void example_add(Stack* stack) {
780    Tensor a, b;
781    // stack before: ? ? ? a b <- back
782    pop(stack, a, b); // Templated helper function
783                      // that pops a, b and converts them to Tensor
784    push(stack, a + b);
785    // stack after:
786    // ? ? ? c <- back
787}
788```
789
790Most operations, apart from some vararg primitive operators like prim::Unpack, have an associated FunctionSchema that describes how many inputs will be popped and how many will be pushed.
791
792The stack concept makes it easy to define operators with variable numbers of inputs and outputs without the need to allocate vectors of inputs and outputs for each individual operator.
793
794In practice, the interpreter will allocate one Stack, and it will eventually reach a sufficient size such that no more stack-related memory allocations will occur.
795
796## Operator ##
797
798[runtime/operator.h](runtime/operator.h)
799
800The Operator object represents a single registered operator in the system. It combines a FunctionSchema that describes how an Operation executes with a method to look up the corresponding Operation given the Node representing the operator in a Graph.  Most Operators are defined by providing a FunctionSchema and an Operation function. However, primitives like prim::Unpack require knowledge of their Node to know how to operate (e.g. how many elements to unpack). These Operators have a function that takes a `Node*` and returns an operation.
801
802
803## Interpreter ##
804
805[runtime/interpreter.cpp](runtime/interpreter.cpp)
806
807The interpreter is responsible for the straightforward execution of `Graphs` without any optimization. It is composed of two objects: Code and InterpreterState. Code is a linearized representation of the `Graph` into simple stack-machine Instructions. Code is shared among all the executions of the `Graph` and will include caches for certain operations like the generated CUDA code of FusionGroups.
808
809The InterpreterState is unique to each execution of the `Graph`. It holds a list registers with the intermediate `IValues` used in the execution, the Stack being used by each Operation, and the program counter tracking the position in the instructions. The information represents the complete state of the interpreter. `wait` instructions can cause the interpreter to suspend, and the InterpreterState is used to resume execution where the `wait` occurred, potentially on a different thread.
810
811Instructions in the interpreter have three parts: a list of registers from which to gather `IValues` onto the stack before the instruction, the Operation to run, and a list of registers in which to store the results of the Operation. Alternatively, we could have used individual instructions to load/store values from the stack to registers, but this design was easier to implement, requires fewer instructions since each instruction does more things, and has not yet been a performance bottleneck. Each Operation returns a potential relative jump to compute the next program counter.
812
813Unlike typical interpreters, we do not attempt careful register allocation. Since `Tensors` are reference types, saving registers would only save a few hundred bytes of space in typical applications by cutting down on the number of places a reference could be saved. The data in single a `Tensor` is likely significantly bigger than that, so we forgo register allocation to make debugging easier.
814
815However, we do need to ensure that values are destructed immediately after their last use. Because Torch reference counts `Tensors`, they will be deallocated immediately when their last reference is gone. To ensure we use a minimum amount of memory we want to ensure that the interpreter releases the reference as soon as it is no longer used. To do this, each Instruction also has a set of flags which indicate the inputs to the operation which will no longer be used after the operation. For these inputs, the `IValue` is moved rather than copied from the register file, ensuring the reference will go dead as soon as the Operation no longer needs it.  extra instructions may be inserted into the program to explicitly drop references for values whose last use depends on the control flow of the program.
816
817The following is an example program in `Graph` form and its equivalent in interpreter [Instructions](runtime/instruction.h):
818
819```
820graph(%x : Tensor,
821      %hx : Tensor,
822      %cx : Tensor,
823      %w_ih : Tensor,
824      %w_hh : Tensor,
825      %b_ih : Tensor,
826      %b_hh : Tensor):
827  %7 : int = prim::Constant[value=4]()
828  %8 : int = prim::Constant[value=1]()
829  %9 : Tensor = aten::t(%w_ih)
830  %10 : Tensor = aten::mm(%x, %9)
831  %11 : Tensor = aten::t(%w_hh)
832  %12 : Tensor = aten::mm(%hx, %11)
833  %13 : Tensor = aten::add(%10, %12, %8)
834  %14 : Tensor = aten::add(%13, %b_ih, %8)
835  %gates : Tensor = aten::add(%14, %b_hh, %8)
836  %16 : Tensor[] = aten::chunk(%gates, %7, %8)
837  %ingate.1 : Tensor, %forgetgate.1 : Tensor, %cellgate.1 : Tensor, %outgate.1 : Tensor = prim::ListUnpack(%16)
838  %ingate : Tensor = aten::sigmoid(%ingate.1)
839  %forgetgate : Tensor = aten::sigmoid(%forgetgate.1)
840  %cellgate : Tensor = aten::tanh(%cellgate.1)
841  %outgate : Tensor = aten::sigmoid(%outgate.1)
842  %25 : Tensor = aten::mul(%forgetgate, %cx)
843  %26 : Tensor = aten::mul(%ingate, %cellgate)
844  %cy : Tensor = aten::add(%25, %26, %8)
845  %28 : Tensor = aten::tanh(%cy)
846  %hy : Tensor = aten::mul(%outgate, %28)
847  %30 : (Tensor, Tensor) = prim::TupleConstruct(%hy, %cy)
848  return (%30)
849```
850
```
0, 1, 2, 3, 4, 5, 6 = Load
7 = Constant
8 = t move(3)
9 = mm move(0), move(8)
10 = t move(4)
11 = mm move(1), move(10)
12 = add move(9), move(11), 7
13 = add move(12), move(5), 7
14 = add move(13), move(6), 7
15, 16, 17, 18 = ConstantChunk move(14)
19 = sigmoid move(15)
20 = sigmoid move(16)
21 = tanh move(17)
22 = sigmoid move(18)
23 = mul move(20), move(2)
24 = mul move(19), move(21)
25 = add move(23), move(24), move(7)
26 = tanh 25
27 = mul move(22), move(26)
28 = TupleConstruct move(27), move(25)
 = Store move(28)
```
874
875## Graph Executor ##
876
877[runtime/graph_executor.cpp](runtime/graph_executor.cpp)
878
879All program execution starts with a graph executor. It's responsible for running optimizations (potentially involving the JIT-compilation of fused kernel code), and then handing the `Graph` or subcomponents of it off to an interpreter to actually run.
880
881
In this section, we use a running example program that computes one step of an LSTM to show how the graph is transformed:
885
886```python
887@torch.jit.script
888def LSTMCellS(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
889    gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh
890    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
891    ingate = torch.sigmoid(ingate)
892    forgetgate = torch.sigmoid(forgetgate)
893    cellgate = torch.tanh(cellgate)
894    outgate = torch.sigmoid(outgate)
895    cy = (forgetgate * cx) + (ingate * cellgate)
896    hy = outgate * torch.tanh(cy)
897    return hy, cy
898```
899
900After going through the frontend, we start with this unoptimized graph:
901
902```
903graph(%x : Tensor,
904      %hx : Tensor,
905      %cx : Tensor,
906      %w_ih : Tensor,
907      %w_hh : Tensor,
908      %b_ih : Tensor,
909      %b_hh : Tensor):
910  %7 : int = prim::Constant[value=4]()
911  %8 : int = prim::Constant[value=1]()
912  %9 : Tensor = aten::t(%w_ih)
913  %10 : Tensor = aten::mm(%x, %9)
914  %11 : Tensor = aten::t(%w_hh)
915  %12 : Tensor = aten::mm(%hx, %11)
916  %13 : Tensor = aten::add(%10, %12, %8)
917  %14 : Tensor = aten::add(%13, %b_ih, %8)
918  %gates : Tensor = aten::add(%14, %b_hh, %8)
919  %16 : Tensor[] = aten::chunk(%gates, %7, %8)
920  %ingate.1 : Tensor, %forgetgate.1 : Tensor, %cellgate.1 : Tensor, %outgate.1 : Tensor = prim::ListUnpack(%16)
921  %ingate : Tensor = aten::sigmoid(%ingate.1)
922  %forgetgate : Tensor = aten::sigmoid(%forgetgate.1)
923  %cellgate : Tensor = aten::tanh(%cellgate.1)
924  %outgate : Tensor = aten::sigmoid(%outgate.1)
925  %25 : Tensor = aten::mul(%forgetgate, %cx)
926  %26 : Tensor = aten::mul(%ingate, %cellgate)
927  %cy : Tensor = aten::add(%25, %26, %8)
928  %28 : Tensor = aten::tanh(%cy)
929  %hy : Tensor = aten::mul(%outgate, %28)
930  %30 : (Tensor, Tensor) = prim::TupleConstruct(%hy, %cy)
931  return (%30)
932```
933
934Execution starts in `GraphExecutor::run`, which takes a Stack of inputs.
935
936### Specialization ###
937
938The executor *specializes* the `Graph` for the particular set of inputs. Specialization is handled by the `ArgumentSpec` object which extracts a "signature" composed of all the properties being specialized. We only specialize to the properties of `Tensors`. The `ArgumentSpec` only records properties for `Tensors` that either appear directly in the inputs to the graph or inside Tuples that are inputs to the `Graph`. The properties recorded are currently:
939
940* dtype
941* rank, but not size
942* requires_grad
943* device type (CPU, CUDA)
944* defined - whether the `Tensor` exists or is a placeholder
945
946The ArgumentSpec object is used as a key into a cache that holds pre-optimized Code objects (held in an ExecutionPlan object). On a cache hit, an InterpreterState is created and the Code in the cache is run.
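
The effect of this caching can be observed from Python. The sketch below is illustrative only: it relies on the internal debugging helper `torch.jit.last_executed_optimized_graph`, and the exact graphs printed depend on the executor mode (the current default executor specializes via profiling, described in [Profiling Programs](#profiling-programs), rather than via `ArgumentSpec`):

```python
import torch

@torch.jit.script
def f(x, y):
    return x * y + y

# First call: nothing is cached yet, so the executor specializes and
# optimizes a Graph for these particular input properties.
f(torch.rand(2, 3), torch.rand(2, 3))
print(torch.jit.last_executed_optimized_graph())

# Inputs with a different dtype do not match the cached entry, so the
# executor produces (and caches) another specialization.
f(torch.rand(2, 3, dtype=torch.double), torch.rand(2, 3, dtype=torch.double))
print(torch.jit.last_executed_optimized_graph())
```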
947
948### Dynamic Shapes Options ###
949
In the "Specialization" section above, it is mentioned that we specialize on "rank, but not size". This is only partially true: sizes are sometimes specialized on as well, because doing so can produce more efficient code. By default, the executor initially specializes to static shapes; if enough different shapes are observed, it will eventually generate a dynamic-shape version that does not depend on the specific input shapes.
951
952To control these settings, you can use `torch._C._jit_set_fusion_strategy()`; it takes as an argument a list of tuples in the format `(type, number)` where `type` is a string in `{"DYNAMIC" ,"STATIC"}` and `number` is an integer.
953
954For example:
955```
956torch._C._jit_set_fusion_strategy([
957    ("STATIC", 2),
958    ("DYNAMIC", 20),
959])
960```
961
This will make two attempts to generate static-shape graphs, and after that fall back to generating dynamic-shape graphs. If for some reason compilation keeps occurring (even with dynamic-shape graphs, e.g. because ranks or dtypes vary), then after 20 compilation attempts the graph executor will fall back to running the graph without any attempt to compile it.
963
964### Pre-derivative Optimization ###
965
On a code cache miss, we generate a new optimized `Graph` on the fly (`compileSpec`). It starts by creating a copy of the initial `Graph` and setting the input types to the specialized `Tensor` types observed in this specialization. TensorType inputs to the `Graph` will get refined with types that know the device, number of dimensions, and requires_grad state.
967
968```
969# post specialization, inputs are now specialized types
970graph(%x : Float(*, *),
971      %hx : Float(*, *),
972      %cx : Float(*, *),
973      %w_ih : Float(*, *),
974      %w_hh : Float(*, *),
975      %b_ih : Float(*),
976      %b_hh : Float(*)):
977  %7 : int = prim::Constant[value=4]()
978  %8 : int = prim::Constant[value=1]()
979  %9 : Tensor = aten::t(%w_ih)
980  %10 : Tensor = aten::mm(%x, %9)
981  %11 : Tensor = aten::t(%w_hh)
982  %12 : Tensor = aten::mm(%hx, %11)
983  %13 : Tensor = aten::add(%10, %12, %8)
984  %14 : Tensor = aten::add(%13, %b_ih, %8)
985  %gates : Tensor = aten::add(%14, %b_hh, %8)
986  %16 : Tensor[] = aten::chunk(%gates, %7, %8)
987  %ingate.1 : Tensor, %forgetgate.1 : Tensor, %cellgate.1 : Tensor, %outgate.1 : Tensor = prim::ListUnpack(%16)
988  %ingate : Tensor = aten::sigmoid(%ingate.1)
989  %forgetgate : Tensor = aten::sigmoid(%forgetgate.1)
990  %cellgate : Tensor = aten::tanh(%cellgate.1)
991  %outgate : Tensor = aten::sigmoid(%outgate.1)
992  %25 : Tensor = aten::mul(%forgetgate, %cx)
993  %26 : Tensor = aten::mul(%ingate, %cellgate)
994  %cy : Tensor = aten::add(%25, %26, %8)
995  %28 : Tensor = aten::tanh(%cy)
996  %hy : Tensor = aten::mul(%outgate, %28)
997  %30 : (Tensor, Tensor) = prim::TupleConstruct(%hy, %cy)
998  return (%30)
999```
1000
1001### Required Passes ###
1002
It then runs "required passes", which are graph transformations necessary to generate legal graphs for the interpreter. (Some passes, such as differentiation, will introduce `Nodes` that are not defined by operators and require further passes to clean up. The combination of `specializeUndef` and `LowerGradOf` cleans up these operations.) These passes also remove broadcasting "expand" nodes that get implicitly inserted by the tracer but are not valid for all sizes.
1004
1005It then runs inference passes to calculate properties of the graph given this particular specialization:
1006
1007* It propagates constants, pre-computing as much as possible
1008* It propagates the input ranks, dtypes, devices, and requires_grad information to the rest of the graph where possible.
1009
1010```
1011graph(%x : Float(*, *),
1012      %hx : Float(*, *),
1013      %cx : Float(*, *),
1014      %w_ih : Float(*, *),
1015      %w_hh : Float(*, *),
1016      %b_ih : Float(*),
1017      %b_hh : Float(*)):
1018  %8 : int = prim::Constant[value=1]()
1019  %9 : Float(*, *) = aten::t(%w_ih)
1020  %10 : Float(*, *) = aten::mm(%x, %9)
1021  %11 : Float(*, *) = aten::t(%w_hh)
1022  %12 : Float(*, *) = aten::mm(%hx, %11)
1023  %13 : Float(*, *) = aten::add(%10, %12, %8)
1024  %14 : Float(*, *) = aten::add(%13, %b_ih, %8)
1025  %gates : Float(*, *) = aten::add(%14, %b_hh, %8)
1026  %31 : Float(*, *), %32 : Float(*, *), %33 : Float(*, *), %34 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%gates)
1027  %ingate : Float(*, *) = aten::sigmoid(%31)
1028  %forgetgate : Float(*, *) = aten::sigmoid(%32)
1029  %cellgate : Float(*, *) = aten::tanh(%33)
1030  %outgate : Float(*, *) = aten::sigmoid(%34)
1031  %25 : Float(*, *) = aten::mul(%forgetgate, %cx)
1032  %26 : Float(*, *) = aten::mul(%ingate, %cellgate)
1033  %cy : Float(*, *) = aten::add(%25, %26, %8)
1034  %28 : Float(*, *) = aten::tanh(%cy)
1035  %hy : Float(*, *) = aten::mul(%outgate, %28)
1036  %30 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy)
1037  return (%30)
1038```
1039
1040### Derivative Preserving Optimization ###
1041
1042It then runs a number of *derivative preserving* optimization passes. If a computation involves `Tensors` that have `requires_grad` and it is valid to compute its derivative, then these passes are only allowed to replace that computation with another computation that is also differentiable. In other words, these passes cannot break autograd. Algebraic rewrites and peephole optimizations are generally derivative preserving but something that generates code, like pointwise fusion, is not.
1043
1044Current derivative preserving passes:
1045
1046* Eliminating dead code
1047* Eliminating common subexpressions
1048* Pooling redundant constants into single values
1049* Peephole optimizations, including some algebraic rewrites into simpler operations
1050* Unrolling small loops
1051* Batching matrix multiplications that result from unrolling loops
1052
1053```
1054graph(%x : Float(*, *),
1055      %hx : Float(*, *),
1056      %cx : Float(*, *),
1057      %w_ih : Float(*, *),
1058      %w_hh : Float(*, *),
1059      %b_ih : Float(*),
1060      %b_hh : Float(*)):
1061  %8 : int = prim::Constant[value=1]()
1062  %9 : Float(*, *) = aten::t(%w_ih)
1063  %10 : Float(*, *) = aten::mm(%x, %9)
1064  %11 : Float(*, *) = aten::t(%w_hh)
1065  %12 : Float(*, *) = aten::mm(%hx, %11)
1066  %13 : Float(*, *) = aten::add(%10, %12, %8)
1067  %14 : Float(*, *) = aten::add(%13, %b_ih, %8)
1068  %gates : Float(*, *) = aten::add(%14, %b_hh, %8)
1069  %31 : Float(*, *), %32 : Float(*, *), %33 : Float(*, *), %34 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%gates)
1070  %ingate : Float(*, *) = aten::sigmoid(%31)
1071  %forgetgate : Float(*, *) = aten::sigmoid(%32)
1072  %cellgate : Float(*, *) = aten::tanh(%33)
1073  %outgate : Float(*, *) = aten::sigmoid(%34)
1074  %25 : Float(*, *) = aten::mul(%forgetgate, %cx)
1075  %26 : Float(*, *) = aten::mul(%ingate, %cellgate)
1076  %cy : Float(*, *) = aten::add(%25, %26, %8)
1077  %28 : Float(*, *) = aten::tanh(%cy)
1078  %hy : Float(*, *) = aten::mul(%outgate, %28)
1079  %30 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy)
1080  return (%30)
1081```
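
Many of these passes are exposed individually through internal Python bindings, which can be handy for seeing what a single pass does to a `Graph` in isolation. A sketch (the `torch._C._jit_pass_*` helpers are internal and their signatures may change between releases):

```python
import torch

@torch.jit.script
def f(x):
    y = x + 1
    z = x + 1          # common subexpression
    unused = x * 42    # dead code
    return y * z

g = f.graph.copy()
torch._C._jit_pass_constant_propagation(g)  # fold expressions that only involve constants
torch._C._jit_pass_cse(g)                   # merge the duplicated x + 1
torch._C._jit_pass_dce(g)                   # drop the unused x * 42
torch._C._jit_pass_peephole(g)              # local algebraic simplifications
print(g)
```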
1082
1083### Post-derivative optimization ###
1084
The next optimization depends on whether any part of the graph actually requires a gradient to be calculated, which is determined by `needsGradient`. In the case where no gradients are required (i.e. for inference graphs), we can directly apply optimizations that generate graphs which may not have valid gradients defined. For now this is the `FuseGraph` pass, which looks for adjacent point-wise operations, along with operations such as `split` and `concat` that it knows how to handle, and creates `prim::FusionGroup` `Nodes` in the graph to replace these operations. The Operator registered to execute `prim::FusionGroup` nodes will generate a new CUDA kernel for each unique `Node`, which replaces the original separate execution.
1086
1087Note the two phases for compilation of fusion groups: First, the `FuseGraph` pass splits the `Graph` into fusible sub-Graphs and returns the resulting `Graph` to the graph executor. Second, when the `Graph` is turned into Code, the Operation for the FusionGroup node will be looked up and a new CUDA kernel generated for the body. Other compilers should work in a similar way by first introducing a new operator into the `Graph` where the compiled code should run, and then registering an Operator that implements that `Node` which performs the actual compilation.
1088
1089In the case where no gradients are required, the optimization process is finished, a Code object is constructed from the `Graph`, it is added to the code cache, and then an InterpreterState is constructed and run.
1090
1091```
1092graph(%x : Float(*, *),
1093      %hx : Float(*, *),
1094      %cx : Float(*, *),
1095      %w_ih : Float(*, *),
1096      %w_hh : Float(*, *),
1097      %b_ih : Float(*),
1098      %b_hh : Float(*)):
1099  %9 : Float(*, *) = aten::t(%w_ih)
1100  %10 : Float(*, *) = aten::mm(%x, %9)
1101  %11 : Float(*, *) = aten::t(%w_hh)
1102  %12 : Float(*, *) = aten::mm(%hx, %11)
1103  %77 : Tensor[] = prim::ListConstruct(%b_hh, %b_ih, %10, %12)
1104  %78 : Tensor[] = aten::broadcast_tensors(%77)
1105  %79 : Tensor, %80 : Tensor, %81 : Tensor, %82 : Tensor = prim::ListUnpack(%78)
1106  %hy : Float(*, *), %cy : Float(*, *) = prim::FusionGroup_0(%cx, %82, %81, %80, %79)
1107  %30 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy)
1108  return (%30);
1109
1110with prim::FusionGroup_0 = graph(%13 : Float(*, *),
1111      %71 : Tensor,
1112      %76 : Tensor,
1113      %81 : Tensor,
1114      %86 : Tensor):
1115  %87 : Float(*, *), %88 : Float(*, *), %89 : Float(*, *), %90 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%86)
1116  %82 : Float(*, *), %83 : Float(*, *), %84 : Float(*, *), %85 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%81)
1117  %77 : Float(*, *), %78 : Float(*, *), %79 : Float(*, *), %80 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%76)
1118  %72 : Float(*, *), %73 : Float(*, *), %74 : Float(*, *), %75 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%71)
1119  %69 : int = prim::Constant[value=1]()
1120  %70 : Float(*, *) = aten::add(%77, %72, %69)
1121  %66 : Float(*, *) = aten::add(%78, %73, %69)
1122  %62 : Float(*, *) = aten::add(%79, %74, %69)
1123  %58 : Float(*, *) = aten::add(%80, %75, %69)
1124  %54 : Float(*, *) = aten::add(%70, %82, %69)
1125  %50 : Float(*, *) = aten::add(%66, %83, %69)
1126  %46 : Float(*, *) = aten::add(%62, %84, %69)
1127  %42 : Float(*, *) = aten::add(%58, %85, %69)
1128  %38 : Float(*, *) = aten::add(%54, %87, %69)
1129  %34 : Float(*, *) = aten::add(%50, %88, %69)
1130  %30 : Float(*, *) = aten::add(%46, %89, %69)
1131  %26 : Float(*, *) = aten::add(%42, %90, %69)
1132  %ingate : Float(*, *) = aten::sigmoid(%38)
1133  %forgetgate : Float(*, *) = aten::sigmoid(%34)
1134  %cellgate : Float(*, *) = aten::tanh(%30)
1135  %outgate : Float(*, *) = aten::sigmoid(%26)
1136  %14 : Float(*, *) = aten::mul(%forgetgate, %13)
1137  %11 : Float(*, *) = aten::mul(%ingate, %cellgate)
1138  %cy : Float(*, *) = aten::add(%14, %11, %69)
1139  %4 : Float(*, *) = aten::tanh(%cy)
1140  %hy : Float(*, *) = aten::mul(%outgate, %4)
1141  return (%hy, %cy)
1142```
1143
1144### Derivate Splitting ###
1145
Many `Graphs` will require gradients (i.e. one of the inputs will have a `requires_grad` property set). In this case, it is unsafe to run post-derivative optimizations directly on the `Graph`. Instead, our approach is to first *split* the `Graph` into sub-Graphs where symbolic gradient formulas are known, producing an explicit `Graph` for the forward pass along with a complementary `Graph` that implements the backward pass using some of the values computed in the forward pass. We can then apply post-derivative optimization to the forward graph. The "gradOutputs" for the backward graph are only known when the backward pass runs, so we cannot fully optimize it at this time. For instance, we do not know if some of those gradOutputs will also have `requires_grad` set, meaning that a gradient-of-gradient situation exists. Instead, the backward pass will use a new GraphExecutor object to run and optimize its execution. In this way, we can handle an indefinite number of recursive gradient calculations.
1147
The creation of derivative subgraphs uses a similar approach to finding fusion groups: adjacent operations with known gradient formulas are grouped together into `prim::DifferentiableGraph` nodes. We only generate these nodes if we can find a large enough subgraph that optimization is likely to be profitable, since there is some overhead involved in entering and exiting a differentiable subgraph.
1149
1150```
1151graph(%x : Float(*, *),
1152      %hx : Float(*, *),
1153      %cx : Float(*, *),
1154      %w_ih : Float(*, *),
1155      %w_hh : Float(*, *),
1156      %b_ih : Float(*),
1157      %b_hh : Float(*)):
1158  %8 : int = prim::Constant[value=1]()
1159  %hy : Float(*, *), %cy : Float(*, *) = prim::DifferentiableGraph_0(%cx, %b_hh, %b_ih, %hx, %w_hh, %x, %w_ih)
1160  %30 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy)
1161  return (%30)
1162with prim::DifferentiableGraph_0 = graph(%13 : Float(*, *),
1163      %29 : Float(*),
1164      %33 : Float(*),
1165      %40 : Float(*, *),
1166      %43 : Float(*, *),
1167      %45 : Float(*, *),
1168      %48 : Float(*, *)):
1169  %49 : Float(*, *) = aten::t(%48)
1170  %47 : Float(*, *) = aten::mm(%45, %49)
1171  %44 : Float(*, *) = aten::t(%43)
1172  %42 : Float(*, *) = aten::mm(%40, %44)
1173  %38 : int = prim::Constant[value=1]()
1174  %39 : Float(*, *) = aten::add(%47, %42, %38)
1175  %35 : Float(*, *) = aten::add(%39, %33, %38)
1176  %gates : Float(*, *) = aten::add(%35, %29, %38)
1177  %24 : Float(*, *), %25 : Float(*, *), %26 : Float(*, *), %27 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%gates)
1178  %ingate : Float(*, *) = aten::sigmoid(%24)
1179  %forgetgate : Float(*, *) = aten::sigmoid(%25)
1180  %cellgate : Float(*, *) = aten::tanh(%26)
1181  %outgate : Float(*, *) = aten::sigmoid(%27)
1182  %14 : Float(*, *) = aten::mul(%forgetgate, %13)
1183  %11 : Float(*, *) = aten::mul(%ingate, %cellgate)
1184  %cy : Float(*, *) = aten::add(%14, %11, %38)
1185  %4 : Float(*, *) = aten::tanh(%cy)
1186  %hy : Float(*, *) = aten::mul(%outgate, %4)
1187  return (%hy, %cy)
1188```
1189
1190### Fusers ###
1191
1192As mentioned in the [Post-derivative optimization](#post-derivative-optimization) section, one of the
1193available optimizations is _fusion_, which merges operator kernels and compiles new kernels. Fusion
1194has two benefits: first, it reduces dispatcher overhead by combining multiple operator calls into a
1195single call to the fused kernel; and second, on GPU it can reduce the number of reads and writes to
1196global GPU memory, which can be a significant portion of the runtime for pointwise operators.
1197
The current default fuser is [NNC](https://github.com/pytorch/pytorch/tree/master/torch/csrc/jit/tensorexpr).
1200
Since fusers rely on specialized information that is only available at runtime - such as dtype, device, and shape - they are only applied after the first invocation of a TorchScript function or module. As a result, the first invocation of a TorchScript function can sometimes behave slightly differently from subsequent invocations.
1205
1206To enable/disable different fusers, refer to the settings below. These settings apply globally in
1207the process in which they are set. Different fusers may excel in different scenarios, and disabling
1208or switching the fuser could also provide a temporary fix in case of bugs.
1209
1210**Python APIs:**
1211
1212
1213| Feature | Python API |
1214|---|---|
1215| NNC enable/disable | `torch._C._jit_set_texpr_fuser_enabled()` |
1216| NNC on CPU | `torch._C._jit_override_can_fuse_on_cpu()` |
1217| NNC on GPU | `torch._C._jit_override_can_fuse_on_gpu()` |
1218| NNC context manager | `with torch.jit.fuser("fuser1"):` |
1219| NVFuser enable/disable (deprecated) | `torch._C._jit_set_nvfuser_enabled()` |
1220| NVFuser context manager (deprecated) | `with torch.jit.fuser("fuser2")` |
1221| oneDNN Graph on CPU | `torch._C._jit_set_llga_enabled(True)` |
1222| oneDNN Graph context manager | `with torch.jit.fuser("fuser3"):` |
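
For example, a typical way to experiment with the fusers from Python (a sketch; these settings are global to the process):

```python
import torch

def gated(x, w):
    return torch.tanh(x @ w) * torch.sigmoid(x @ w)

scripted = torch.jit.script(gated)

# Scope NNC ("fuser1") to a block; the previous fuser settings are restored on exit.
with torch.jit.fuser("fuser1"):
    x, w = torch.rand(4, 8), torch.rand(8, 8)
    for _ in range(3):   # warm-up runs let the executor specialize and fuse
        scripted(x, w)

# Or flip the global switches directly:
torch._C._jit_set_texpr_fuser_enabled(True)   # enable/disable NNC
torch._C._jit_override_can_fuse_on_cpu(True)  # allow NNC fusion on CPU
```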
1223
1224**C++ APIs:**
1225
1226| Feature | C++ API | Header file |
1227|---|---|---|
1228| NNC enable/disable | `torch::jit::setTensorExprFuserEnabled(bool);` | [here](https://github.com/pytorch/pytorch/blob/1a7e560adecb0192f69f4d05b990800b60dc380b/torch/csrc/jit/passes/tensorexpr_fuser.h#L22) |
1229| NNC on CPU | `torch::jit::overrideCanFuseOnCPU(bool);` | [here](https://github.com/pytorch/pytorch/blob/1a7e560adecb0192f69f4d05b990800b60dc380b/torch/csrc/jit/codegen/fuser/interface.h#L28-L29) |
1230| NNC on GPU | `torch::jit::overrideCanFuseOnGPU(bool);` | [here](https://github.com/pytorch/pytorch/blob/1a7e560adecb0192f69f4d05b990800b60dc380b/torch/csrc/jit/codegen/fuser/interface.h#L28-L29) |
1231| NVFuser enable/disable (deprecated) | `torch::jit::fuser::cuda::setEnabled(bool);` | [here](https://github.com/pytorch/pytorch/blob/1a7e560adecb0192f69f4d05b990800b60dc380b/torch/csrc/jit/codegen/cuda/interface.h#L56) |
1232
1233### Disabling Optimizations ###
1234
To disable the runtime optimizations and only run the minimum necessary, the following commands can be used to globally (within a process) turn off the majority of runtime optimizations. This disables JIT autodiff (the default eager-mode autograd implementation is used instead) as well as the fusers and some other runtime optimizations.
1239
1240* Python: `torch._C._get_graph_executor_optimize(False)`
1241* C++: `torch::jit::setGraphExecutorOptimize(false);`
1242* C++ header: [here](https://github.com/pytorch/pytorch/blob/1a7e560adecb0192f69f4d05b990800b60dc380b/torch/csrc/jit/python/update_graph_executor_opt.h#L5)
1243
1244## JIT Logging ##
1245
1246[jit_log.h](jit_log.h)
1247
Logging is a very useful debugging technique, especially in the context of compilers. Compilers perform a series of passes and analyses, and logging can help to trace issues such as wrong results or segmentation faults all the way back to the original erroneous transformation.
1250
`TorchScript` offers a simple logging facility that can be enabled by setting the environment variable `PYTORCH_JIT_LOG_LEVEL`.
1252
1253Logging is enabled on a per file basis. To enable logging in `dead_code_elimination.cpp`, `PYTORCH_JIT_LOG_LEVEL` should be
1254set to `dead_code_elimination.cpp` or, simply, to `dead_code_elimination` (i.e. `PYTORCH_JIT_LOG_LEVEL=dead_code_elimination`).
1255
Multiple files can be logged by separating each file name with a colon `:`, as in the following example: `PYTORCH_JIT_LOG_LEVEL=dead_code_elimination:guard_elimination`.
1257
There are three logging levels available, ordered by detail level from lowest to highest:
1259
1260* `GRAPH_DUMP` should be used for printing entire graphs after optimization passes
1261* `GRAPH_UPDATE` should be used for reporting graph transformations (i.e. node deletion, constant folding, etc)
1262* `GRAPH_DEBUG` should be used for providing information useful for debugging
1263  the internals of a particular optimization pass or analysis
1264
The default logging level is `GRAPH_UPDATE`, meaning that both `GRAPH_DUMP` and `GRAPH_UPDATE` will be enabled when one specifies one or more files in `PYTORCH_JIT_LOG_LEVEL`.
1267
`GRAPH_DEBUG` can be enabled by prefixing a file name with a `>`, as in `>alias_analysis`. `>>` and `>>>` are also valid and **currently** are equivalent to `GRAPH_DEBUG`, as there is no logging level higher than `GRAPH_DEBUG`.
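
Because the variable is read inside the PyTorch process, the most reliable approach is to set it in the shell before launching Python; setting it via `os.environ` at the very top of a script, before `torch` is imported and before anything is compiled, typically works as well. A sketch:

```python
import os
# Must be set before the logged passes run in this process.
os.environ["PYTORCH_JIT_LOG_LEVEL"] = ">graph_fuser:dead_code_elimination"

import torch

@torch.jit.script
def f(x):
    return torch.sigmoid(x) * torch.tanh(x)

f(torch.rand(4, 4))  # pass logs are emitted (to stderr by default) as the graph is optimized
```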
1271
1272By default, types in the graph are printed with maximum verbosity.  The verbosity level can be controlled via the environment variable `PYTORCH_JIT_TYPE_VERBOSITY`.  The available settings are:
1273
1274* `0`: No type information
1275* `1`: Types and shapes only
1276* `2`: Also print strides
1277* `3`: Also print device type and whether gradient is required
1278
1279## JIT Optimization Limiter ##
1280
1281[jit_opt_limit.h](jit_opt_limit.h)
1282
Oftentimes we need to limit the number of optimizations a lowering pass performs, for debugging purposes.
1284
`TorchScript` offers a simple optimization limit checker that can be configured through the environment variable `PYTORCH_JIT_OPT_LIMIT`. Its purpose is to limit how many optimizations a pass can make, which is useful for debugging any pass.
1286
The opt limit checker is enabled on a per-file basis (hence per pass). For example, for `constant_propagation.cpp`, `PYTORCH_JIT_OPT_LIMIT` should be set to `constant_propagation=<opt_limit>`, where `<opt_limit>` is the number of optimizations you want to allow for the pass (i.e. `PYTORCH_JIT_OPT_LIMIT="constant_propagation=<opt_limit>"`).
1289
1290Multiple files can be configured by separating each file name with a colon
1291`:` as in the following example,
1292`PYTORCH_JIT_OPT_LIMIT="constant_propagation=<opt_limit>:dead_code_elimination=<opt_limit>"`
1293
You can invoke the opt limiter through the `JIT_OPT_ALLOWED` macro. It returns true if we haven't reached the optimization limit yet. Typical usage:
1296
1297```cpp
1298if (!JIT_OPT_ALLOWED) {
1299    GRAPH_DUMP(...); //supplied from jit_log
1300    return;
1301}
1302```
1303
1304## DifferentiableGraphOp ##
1305
1306[runtime/graph_executor.cpp](runtime/graph_executor.cpp)
1307
1308
A DifferentiableGraphOp combines an explicit forward `Graph` `f` with a paired backward graph `df`. When it runs, the input `Tensors` to `f` are detached from autograd, the body of `f` is run, and then the autograd graph for the outputs of `f` is hooked up to the `df` function. The `df` function's outputs are also hooked up to the autograd graph.
1310
1311## Handling Mutability ##
1312### Aliasing and mutation in the PyTorch API
1313In PyTorch, `Tensors` are reference types. Operators can return "views" of the input `Tensor`, creating a new `Tensor` object that shares the same underlying storage as the original:
1314```python
1315a = torch.rand(2, 3)
1316b = a
1317# At this point, `a` and `b` share their storage.
1318c = b[0]
1319# `c` shares storage with `a` and `b`, but only sees a slice of the allocated memory.
1320```
1321
1322Some operators will *mutate* one or more of their operands in-place. These are typically denoted with a trailing underscore, or by taking an `out` argument as input:
1323```python
1324a = torch.zeros(2, 3)
1325b = torch.ones(2, 3)
1326a.add_(b)  # in-place add, so `a` is modified.
1327torch.add(a, b, out=a) # another way to express the same thing
1328```
1329
1330### Aliasing and mutation annotations in FunctionSchema
1331The JIT's `FunctionSchema`  allows operator writers to add annotations specifying the aliasing and mutation behavior of an operator. Optimization passes will use this information to determine whether transformations are semantics-preserving. This section provides a description of the alias annotation language, assuming that the reader already knows what `FunctionSchema` looks like.
1332
1333First, here is a pure function which always returns new memory:
1334```
1335add(Tensor a, Tensor b) -> Tensor
1336```
The type `Tensor` with no annotations is sugar for "fresh, read-only `Tensor`". Since there are no annotations on anything, we know that this operator creates no aliases and mutates no inputs.
1338
Next, a function that returns an alias to one of the inputs:
1340```
1341view(Tensor(a) self, int[] size) -> Tensor(a)
1342```
The shared `(a)` annotation on `self` and the output signifies that the `Tensors` will share the same storage. Another way to say this is that `self` and the output belong to the same "alias set" `a`.
1344
1345Now a function that writes in-place to one of the inputs (note the trailing underscore):
1346```
1347add_(Tensor(a!) self, Tensor other) -> Tensor(a!)
1348```
1349The `!` annotation means that this operator writes to the specified alias set (in this case `a`).
1350
1351Sometimes we don't have enough information to provide an exact alias annotation. For example, here is the operator to extract an element from a list:
1352```
1353list_select(Tensor[] list, int idx) -> Tensor(*)
1354```
Note the alias set `*`. This is the **wildcard set**; these are values which we analyze conservatively. Containers (such as lists and dictionaries), Graph inputs, and class attributes are conservatively analyzed as all aliasing each other. In most cases, people shouldn't be writing operators with wildcard annotations. They are used as a temporary workaround for when our alias analysis isn't sophisticated enough to understand something yet, but we don't want to block feature development.
1356
Similarly, we have operators which result in `Tensors` being contained in a list. In this case, to preserve the relationship between the output list and the input, we annotate that the input enters the wildcard set with the `(a -> *)` syntax.
1358```
1359func: chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]
1360```
1361
1362This annotation language is consumed by the `FunctionSchema` parser, which produces `AliasInfo` objects summarizing the aliasing relationships for each schema `Argument`.
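
These parsed schemas, including their alias annotations, are visible from Python, which can be a quick way to check what the analysis will assume about an operator. A sketch, assuming the internal `torch._C._jit_get_schemas_for_operator` binding and the `alias_info` accessor on schema arguments are available:

```python
import torch

for schema in torch._C._jit_get_schemas_for_operator("aten::add_"):
    print(schema)  # the printed schema shows the (a!) annotations directly
    info = schema.arguments[0].alias_info  # AliasInfo parsed from `Tensor(a!) self`
    if info is not None:
        print(info.is_write)               # True: this overload writes to alias set `a`
```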
1363
1364### Marking custom ops as side-effectful
1365
1366Sometimes, one will register a custom op that is side-effectful. For example, an op that does logging might take in a tensor (or other input), but not return anything. Without further annotation, these types of ops will often be dead-code-eliminated by TorchScript.
1367
To mark a custom op as side-effectful, or to otherwise have it handled conservatively by the alias analysis, register it with `c10::AliasAnalysisKind::CONSERVATIVE`:
1369
```c++
TORCH_LIBRARY(my_library, m) {
  m.def(torch::schema(
    "my_logging_op(Tensor data) -> ()",
    c10::AliasAnalysisKind::CONSERVATIVE));
}
```
1377
1378### Alias Analysis in the IR
1379
1380[ir/alias_analysis.h](ir/alias_analysis.h)
1381
1382An alias analysis pass consumes the per-operator aliasing information to construct a database of aliasing and mutation relationships in a graph, called `AliasDb`. This section focuses on the alias analysis pass; the public interface to `AliasDb` will be described later.
1383
1384The core data structure in the AliasDb is called `MemoryDAG`, which is a DAG where the edges are "may point to" relationships and the  vertices are aliasing `Element`s. The most common kind of `Element` is an IR `Value`, but there are other kinds of things that can alias that aren't first-class `Value`s in the IR, like wildcards or contained types (such as in a list or tuple).
1385
1386The alias analysis pass walks through the nodes in a graph, examining schema `AliasInfo`  objects and adding edges in the `MemoryDAG` accordingly. For example, for the node:
1387```
1388%output : Tensor = aten::view(%self, %size)
1389```
1390the analyzer will examine the schema for `view()`:
1391```
1392view(Tensor(a) self, int[] size) -> Tensor(a)
1393```
1394and add an edge from `%output` to `%self`. The alias analysis pass is flow-insensitive, as we are only adding "points-to" edges when processing a node.
1395
As a more involved example, consider the following TorchScript snippet:
1397```python
1398@torch.jit.script
1399def foo(a : Tensor, b : Tensor):
1400  c = 2 * b
1401  a += 1
1402  if a.max() > 4:
1403    r = a[0]
1404  else:
1405    r = b[0]
1406  return c, r
1407```
This will produce a graph like this:
1409
1410![AliasTracker graph](/docs/source/_static/img/aliastracker_graph.png)
1411
1412A few things to note:
1413- "Graph Input Element" is an example of an `Element` that isn't a first-class `Value`. Alias analysis happens on a per-function level, so we don't necessarily know the aliasing relationships of the inputs. The only safe assumption is that `a` and `b` may alias each other, so they point to a special `Element` that describes "the world outside of this function".
1414- `r` may point to either `a` or `b`, depending on the runtime value of `a.max()`.  A given `Element` may point to multiple other `Element`s. This can happen if there is branching control flow (like in this example), or with certain ops like `contiguous()`, which either returns an alias to the input or a fresh `Tensor`, depending on the runtime characteristics of the input.
1415- `c` is a fresh `Tensor` (i.e. it doesn't point to anything) since it was created using the pure operation `2 * b`.
1416
1417The last point demonstrates a key concept: *leaf elements uniquely describe memory locations*. Since a leaf element doesn't point to anything, the memory that backs it must have been freshly allocated by some op. Thus we can use leaf elements to represent disjoint memory locations.
1418
So to determine whether `a` and `b` may alias, we traverse the `MemoryDAG` and figure out if `a` and `b` share any leaf nodes. If they do, then we know `a` and `b` might point to the same memory location, i.e. `a` and `b` may alias. This kind of query is common enough that `MemoryDAG` does path compression to speed up leaf-finding, so that aliasing queries can be serviced in amortized constant time.
1420
1421### Writing optimization passes with `AliasDb`
1422`AliasDb` provides a high-level interface to help people write mutability-safe optimization passes.
1423
1424In particular, `moveAfterTopologicallyValid()` (and its `moveBefore` variant) will reorder nodes in a way that preserves data dependencies and avoids any data hazards.  The rules for this are that all mutable *writes* to a given memory location must occur in the same order (avoid WAW hazards), and that no reads can be reordered before or after any write (WAR, RAW hazards).
1425
However, reordering of reads across writes *is allowed* if we can prove that the read cannot alias the thing being written. This commonly happens when a `Tensor` comes from an operation that produces fresh results inside the function. It also happens whenever the creation of the mutable `Tensor` is seen in the function (so it gets assigned a fresh variable), and all of its writes occur in that function.
1427
1428The intention is that if you only mutate the graph through `AliasDb`, you don't have to think about mutability/aliasing at all in your pass. As we write more passes, the interface to `AliasDb` will get richer (one example is transforming an in-place operation to its pure equivalent if we can prove it's safe).
1429
1430`AliasDb` also provides lower level APIs that users of LLVM's alias analysis pass would be familiar with, such as querying whether any two `Value`s may alias.
1431
1432TODO: differentiation, symbolic autograd, fusion, operators
1433
1434# Profiling Programs
1435
1436`prim::profile` nodes are inserted on every **use** of a value by `ProfilingRecord::instrumentBlock`. Every `prim::profile` node runs a lambda that uses a captured, initial type value and the type of an incoming `Tensor` and merges the two into a refined `TensorType`.
1437
1438`prim::profile` nodes are replaced with `prim::Guard` nodes by `InsertGuards`. `prim::Guard` nodes are inserted to guarantee that beyond the guard a guarded `Tensor` will always be of the profiled shape. This guarantee will enable optimizations and code generators to generate more efficient code.
1439
1440We attempt to reduce the number of `prim::Guard` nodes as these nodes may interfere with optimizations.
1441* First, `GuardElimination::moveGuardsToDefs` tries to move `prim::Guards` to their definitions, so the guards guarding the same `Tensor` follow the definition directly or another guard on the same `Tensor`.
1442* This ordering allows us to **coalesce** (done in `GuardElimination::coalesceGuards`) multiple guards into a single one.
* After guards are **coalesced**, `GuardElimination::eliminateGuards` attempts to eliminate more guards as follows: it inspects each operation and its inputs. It checks whether the inputs to the operation are guarded and whether the operation produces consistent shapes given the guarded inputs. For example, if two inputs to `add` are guaranteed to be of shape `(2, 3)`, the output shape will also always be `(2, 3)`. If this property holds, we are allowed to remove the guard guarding the operation's output.
1444
Lastly, we need to handle cases where the assumptions about `Tensor` shapes fail at runtime. To handle guard failures, we need to be able to run the original code, i.e. the code that doesn't rely on assumptions about shapes. Since guards can be inserted and moved (by the optimizer) at/to arbitrary points in a computational graph, we need to be able to resume execution starting from those arbitrary points onward.
1446
`InsertBailoutNodes` builds deoptimized versions of the original computational graph that contain the rest of the computation starting from their corresponding guard failure points, and it also captures the live values needed to execute those deoptimized graphs. In other words, the pass replaces `prim::Guard` nodes with `prim::BailOut` nodes which have their `attr::Subgraph` attribute set to the deoptimized version of the remaining computation at their corresponding `prim::Guard`s.
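
The result of profiling can be observed from Python by looking at the graph the executor actually ran after a couple of warm-up calls. This is a sketch; `graph_for` is an internal debugging helper, and whether you see `prim::BailOut`, type checks, or fusion groups depends on the executor and fusion settings:

```python
import torch

@torch.jit.script
def f(x):
    return torch.sigmoid(x) + x

x = torch.rand(4, 4)
for _ in range(3):
    f(x)  # warm-up: profile types, insert guards/bailouts, optimize

# Tensor values in the printed graph now carry profiled TensorTypes, and the
# shape assumptions are protected by guard/bailout-style nodes.
print(f.graph_for(x))
```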
1448
1449# Saving Programs
1450
1451See [the serialization docs](docs/serialization.md).
1452
1453# Testing Programs
1454## Testing Autodiff ##
1455
1456[runtime/symbolic_script.cpp](runtime/symbolic_script.cpp)
1457
1458When differentiating a graph, each node that has a symbolic gradient will be included in a `prim::DifferentiableGraph`. We fall back to using autograd for the node if there isn't a gradient formula for it.
Adding or updating symbolic gradient functions must be tested carefully, as it's easy to keep CI green by comparing the autograd result with itself while silently causing an autodiff support regression.
1460
1461If your PR adds/updates a gradient formula for `torch`/`nn` functions, you **MUST** enable/update the corresponding tests in
1462- `torch` functions: `method_tests` in [common_method_tests.py](../../../test/common_method_tests.py)
1463- `nn` functions: `nn_functional_tests` in [test_jit.py](../../../test/test_jit.py)
1464
1465To turn on autodiff check, you can add an optional `check_ad(should_autodiff_node[bool], nonfusible_nodes[str|list[str]], fusible_nodes[str|list[str]])` tuple after the optional test variant name field.
1466If `should_autodiff_node=True`, the differentiated traced/script forward graph must have a `prim::DifferentiableGraph`.
1467
1468All nodes in `nonfusible_nodes` should show up at least once in `prim::DifferentiableGraph` subgraphs.
1469When fusion is enabled, all nodes in `fusible_nodes` should show up in one of `prim::FusionGroup` graphs attached to `prim::DifferentiableGraph`,
1470otherwise they're checked as `nonfusible_nodes` as well.
1471On the other hand, if `should_autodiff_node=False`, the graph can still have `prim::DifferentiableGraph` with other nodes, but not `nonfusible_nodes` and `fusible_nodes`.
1472
To make writing tests easier, you only need to write out node names if they differ from the function name. Below are a few examples:
1474```python
1475('conv1d', ...), # No symbolic gradient formula
1476('avg_pool2d', ..., (True,)), # Has symbolic gradient formula, only has one nonfusible node aten::avg_pool2d
1477('nll_loss', ..., (True, 'aten::nll_loss_forward')), # Is replaced by a different node in its symbolic gradient formula
1478('dropout', ..., (True, ['prim::is_CUDA', 'aten::bernoulli_'], ['aten::rand_like', ..., 'aten::div'])), # Some ops are fused when fusion is enabled
1479```
1480
Note that even for the same function, different tests could trigger different function schemas (e.g. `aten::add`), while only a few of them have symbolic gradient formulas.
1482You should only turn on autodiff checks in tests that have symbolic gradients. If you are not sure, uncomment the debugging line in [runtime/symbolic_script.cpp](runtime/symbolic_script.cpp)
1483to check which function schema the test triggers.
1484
1485# Python Printer
1486
1487[serialization/python_print.cpp](serialization/python_print.cpp)
1488[serialization/import_source.cpp](serialization/import_source.cpp)
1489
1490The Python Printer takes a `Graph` and produces Python-like code that represents the same graph. Using some special values in [serialization/import_source.cpp](serialization/import_source.cpp), this code can be read back in by the compiler to produce the same `Graph`. In Python a `ScriptModule`'s `code` property shows the Python Printed graph.
1491
Below are the graph and code for this small `ScriptModule`:
1493```python
1494class M(torch.jit.ScriptModule):
1495    @torch.jit.script_method
1496    def forward(self, x, y, z):
1497        # type: (Tensor, int, float) -> Tensor
1498        if y > 2:
1499            x = x + z
1500        else:
1501            x = x + y
1502        return x
1503
1504m = M()
1505```
1506
1507`m.graph`
1508```
1509graph(%x.1 : Tensor,
1510      %y : int,
1511      %z : float):
1512  %5 : int = prim::Constant[value=1]()
1513  %3 : int = prim::Constant[value=2]()
1514  %4 : bool = aten::gt(%y, %3)
1515  %x : Tensor = prim::If(%4)
1516    block0():
1517      %x.2 : Tensor = aten::add(%x.1, %z, %5)
1518      -> (%x.2)
1519    block1():
1520      %x.3 : Tensor = aten::add(%x.1, %y, %5)
1521      -> (%x.3)
1522  return (%x)
1523```
1524
1525`m.code`
1526```python
1527def forward(self,
1528    x: Tensor,
1529    y: int,
1530    z: float) -> Tensor:
1531  if torch.gt(y, 2):
1532    x0 = torch.add(x, z, 1)
1533  else:
1534    x0 = torch.add(x, y, 1)
1535  return x0
1536```
1537
1538# Python Bindings
1539
1540TODO: Script Module, torch.jit.trace, __constant__ handling, weak script modules
1541
1542## Graph Manipulation
1543
Python bindings for manipulating TorchScript IR exist in [python_ir.cpp](https://github.com/pytorch/pytorch/blob/58e7ec5843e63ee044e0a4f5aa2583a056a64078/torch/csrc/jit/python/python_ir.cpp#L4). In general, graph structures should look the same as the representation described above in [Core Program Representation](#core-program-representation).
1545
1546Things to watch out for:
1547* You may need to first inline your graph (`torch._C._jit_pass_inline`) or recursively traverse CallFunction nodes (`for x in graph.findAllNodes("prim::CallFunction")`) if you want to recursively modify your graph and the functions it calls
1548* To insert a graph after node n, use the context manager `with graph.insert_point_guard(new_node)`
1549
See more examples in [test_python_ir.py](https://github.com/pytorch/pytorch/blob/main/test/jit/test_python_ir.py).
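
For instance, a minimal sketch of inlining and traversing a graph through these bindings (pass names and binding signatures are internal and may change):

```python
import torch

def helper(x):
    return x * 2

@torch.jit.script
def f(x):
    return helper(x) + 1

g = f.graph.copy()
torch._C._jit_pass_inline(g)              # inline prim::CallFunction nodes such as helper
for node in g.findAllNodes("aten::mul"):  # visible only after inlining
    print(node.kind(), [i.debugName() for i in node.inputs()])
```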
1551