xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/instruction.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <cstdint>
#include <iosfwd>
#include <typeinfo>
#include <unordered_set>
6 
7 namespace torch::jit {
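// Instructions look like:
// op_code X, N
// The meaning of X and N depends on the op. The descriptor string given for
// each op in FORALL_OPCODES uses these letters for its operands:
// O - index into operator table
// R - index into register table
// I - literal integer
// C - index into constant table
// P - jump offset relative to beginning of current instruction
// F - index into function table
// T - index into the type table, used for guard instructions
// S - index into object slots
// C - index into code table
// NOTE: "C" is listed twice above (constant table vs. code table); which one
// applies depends on the op. Some descriptors also use "N", "E" and "X"
// (e.g. TYPECHECK, FORK, AWAITABLE, ENTER, EXIT), which are not listed here.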
20 
// X-macro listing every interpreter opcode as _(NAME, "operand letters").
// The operand letters follow the legend in the comment above; the comment
// after each entry describes what the op does with its X and N operands.
// Entry order matters: it determines the numeric value of each OpCode.
#define FORALL_OPCODES(_)                                                      \
  _(OP, "O") /* invoke operator X */                                           \
  _(OPN, "OI") /* invoke vararg operator X with N arguments */                 \
  _(LOAD, "R") /* push a value from a register X */                            \
  _(MOVE, "R") /* push a value from register X, clearing the register */       \
  _(STOREN, "RI") /* store N values to registers [X, X+N) */                   \
  _(STORE, "R") /* store 1 value to register X */                              \
  _(DROP, "") /* drop 1 value from the top of the stack */                     \
  _(DROPR, "R") /* clear register X */                                         \
  _(LOADC, "C") /* push the constant X */                                      \
  _(JF, "P") /* pop the top of the stack, if false, branch to X */             \
  _(JMP, "P") /* unconditional branch to X */                                  \
  _(LOOP, "PI") /* perform a loop, X is where to branch if cond is false */    \
  _(RET, "") /* exit execution */                                              \
  _(WAIT, "") /* wait for a future to be complete */                           \
  _(CALL, "F") /* call function X */                                           \
  _(GUARD, "T") /* check a guard against type_table, true if passes */         \
  _(TYPECHECK, "TN") /* check each type of input[i] against type_table[X+N] */ \
  _(FAIL_GUARD, "T") /* fail a guard, patch back to GUARD */                   \
  _(PROFILE_OP, "F") /* get a callback from profile_function_table at X */     \
  _(TAIL_CALL, "F") /* replace current frame with function X */                \
  _(INTERFACE_CALL, "CI") /* call method X on the first argument (of N) */     \
  _(GET_ATTR, "S") /* get attribute from slot X in an Object */                \
  _(SET_ATTR, "S") /* set attribute to slot X in an Object */                  \
  _(LIST_UNPACK, "I") /* unpack list expecting length X */                     \
  _(TUPLE_CONSTRUCT, "I") /* construct a tuple using X inputs */               \
  _(NAMED_TUPLE_CONSTRUCT,                                                     \
    "TI") /* construct a tuple of type X, using N inputs */                    \
  _(LIST_CONSTRUCT, "TI") /* construct a list of type X, using N inputs */     \
  _(DICT_CONSTRUCT, "TI") /* construct a dict of type X, using N inputs */     \
  _(CREATE_OBJECT, "T") /* create an object of type X */                       \
  _(ISINSTANCE, "TI") /* check object is one of types[X:X+N] */                \
  _(TUPLE_SLICE, "II") /* slice tup[X:(X+N)] */                                \
  _(TUPLE_INDEX, "") /* get the value from a tuple at that index */            \
  _(RAISE_EXCEPTION, "") /* throws the exception from Python */                \
  _(DICT_INDEX, "") /* gets the value from the dict for given key */           \
  _(UNCHECKED_CAST, "") /* perform an unchecked cast operation */              \
  _(__IS__, "") /* performs `is` operator from Python */                       \
  _(UN_INITIALIZED,                                                            \
    "") /* sets default values to variables that are uninitialized */          \
  _(__ISNOT__, "") /* performs `is not` operator from Python */                \
  _(FORMAT, "I") /* performs string format function `f strings` or `{}.format` \
                     the number of inputs is stored in X */                    \
  _(DEVICE, "") /* invokes aten::device for a Tensor */                        \
  _(DTYPE, "") /* invokes aten::dtype for a Tensor */                          \
  _(DIM, "") /* invokes aten::dim for a Tensor */                              \
  _(__NOT__, "") /* performs `not` operator from Python */                     \
  _(TO_LIST, "") /* convert the input to a list */                             \
  _(NUM_TO_TENSOR,                                                             \
    "") /* performs the conversion of a number/scalar to Tensor */             \
  _(IS_CUDA, "") /* invokes aten::is_cuda for a Tensor */                      \
  _(FORK, "CN") /* launch a thread to run code entry X with N inputs */        \
  _(WARN, "I") /* emit a warning with line information */                      \
  _(ENTER, "EN") /* enter scope of a contextmanager */                         \
  _(EXIT, "EX") /* exit the last entered contextmanager */                     \
  _(AWAITABLE, "CN") /* initialize await for code entry X with N inputs */
77 
78 enum OpCode : uint8_t {
79 #define DEFINE_OP(op, _) op,
80   FORALL_OPCODES(DEFINE_OP)
81 #undef DEFINE_OP
82 };
83 
84 struct Instruction {
85   OpCode op;
86   uint8_t unused;
87   uint16_t N;
88   int32_t X;
89   // TODO: check for overflow
InstructionInstruction90   Instruction(OpCode op, int32_t X, uint16_t N)
91       : op(op), unused(0), N(N), X(X) {}
92 };
93 std::ostream& operator<<(std::ostream& out, Instruction inst);
94 
95 bool isOpSupportedInMobile(OpCode op);
96 char const* toString(OpCode op);
97 OpCode parseOpCode(const char* str);
98 std::ostream& operator<<(std::ostream& out, Instruction inst);
99 
100 } // namespace torch::jit
101