1 #pragma once 2 #include <memory> 3 #include <optional> 4 #include <vector> 5 6 #include <ATen/ThreadLocalState.h> 7 #include <ATen/core/ivalue.h> 8 #include <ATen/core/jit_type.h> 9 #include <torch/csrc/Export.h> 10 #include <torch/csrc/jit/frontend/source_range.h> 11 12 C10_DECLARE_bool(torch_jit_disable_warning_prints); 13 C10_DECLARE_bool(torch_jit_enable_rethrow_caught_exception); 14 15 namespace at { 16 class Tensor; 17 TORCH_API void launch(std::function<void()> func); 18 } // namespace at 19 namespace c10 { 20 struct IValue; 21 struct OperatorName; 22 } // namespace c10 23 24 namespace torch::jit { 25 26 // The interpreter run Graphs with Tensor inputs and Tensor outputs 27 // a separate component in the autograd handles unwrapping and wrapping 28 // variable objects for use in the interpreter. 29 namespace interpreter { 30 struct CodeImpl; 31 } 32 33 struct Node; 34 struct GraphExecutor; 35 struct InterpreterStateImpl; 36 struct Graph; 37 struct Node; 38 struct Instruction; 39 using Stack = std::vector<c10::IValue>; 40 using c10::ivalue::Future; 41 using TaskLauncher = std::function<void(std::function<void()>)>; 42 43 struct TORCH_API Code { 44 Code() = default; 45 explicit Code(interpreter::CodeImpl* pImpl); 46 // remaining_bailout_depth is irrelevant in a `Code` object unless the `Code` 47 // is directly created by `GraphExecutor` in which case it's likely to contain 48 // `prim::BailOut`s to control the maximum depth of bailout chains 49 explicit Code( 50 const std::shared_ptr<Graph>& graph, 51 std::string function_name, 52 size_t remaining_bailout_depth = 0); 53 54 const std::vector<GraphExecutor*>& grad_executors(); 55 const std::vector<GraphExecutor*>& diff_graph_op_executors(); 56 57 explicit operator bool() const { 58 return pImpl != nullptr; 59 } 60 size_t num_inputs() const; 61 size_t num_outputs() const; 62 size_t num_bailouts() const; 63 const std::vector<c10::IValue>& constant_table() const; 64 const std::vector<c10::TypePtr>& type_table() const; 65 const std::vector<Instruction>& instructions() const; 66 const std::unordered_map<std::string, size_t>& op_to_num_specified_args() 67 const; 68 const std::vector<Node*>& instructions_source() const; 69 void request_bailout(size_t index); 70 size_t register_size() const; 71 std::shared_ptr<Graph> graph() const; 72 73 private: 74 std::shared_ptr<interpreter::CodeImpl> pImpl; 75 friend struct InterpreterStateImpl; 76 friend std::ostream& operator<<(std::ostream& out, const Code& code); 77 }; 78 79 struct TORCH_API MobileCode : Code { 80 explicit MobileCode( 81 const std::shared_ptr<Graph>& graph, 82 std::string function_name, 83 bool emit_default_input_instructions = true, 84 bool support_default_args_before_out = true, 85 bool emit_promoted_ops = true, 86 size_t remaining_bailout_depth = 0); 87 }; 88 89 struct InterpreterState { 90 TORCH_API InterpreterState( 91 const Code& code, 92 TaskLauncher taskLauncher = at::launch); 93 TORCH_API void run(Stack& stack); 94 TORCH_API c10::intrusive_ptr<Future> runAsync(Stack& stack); 95 c10::intrusive_ptr<Future> getFuture(); 96 97 private: 98 InterpreterState(c10::intrusive_ptr<c10::intrusive_ptr_target> pImpl); 99 // Ideally we should use c10::intrusive_ptr<InterpreterStateImpl> for pImpl; 100 // but intrusive_ptr requires full definition of InterpreterStateImpl, 101 // which we need to hide in the header. 102 c10::intrusive_ptr<c10::intrusive_ptr_target> pImpl; 103 friend struct InterpreterStateImpl; 104 }; 105 106 // Created by wait() 107 struct Suspend : public std::exception { whatSuspend108 const char* what() const noexcept override { 109 return "Suspend"; 110 } 111 112 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) SuspendSuspend113 explicit Suspend(c10::intrusive_ptr<Future> future_) 114 : future(std::move(future_)) {} 115 116 c10::intrusive_ptr<Future> future; 117 }; 118 119 // InterpreterContinuation propagates dist_autograd_context_id 120 // through (and only through) the forward pass manually, other 121 // thread local settings are propagated with ThreadLocalState 122 struct InterpreterContinuation { 123 InterpreterContinuation( 124 InterpreterState state_, 125 Stack stack_, 126 int64_t dist_autograd_context_id = 0, 127 std::optional<at::ThreadLocalState> tls_state = std::nullopt) stateInterpreterContinuation128 : state(std::move(state_)), 129 stack(std::move(stack_)), 130 tls_state_(std::move(tls_state)) 131 #ifdef USE_DISTRIBUTED 132 , 133 dist_autograd_context_id_(dist_autograd_context_id) 134 #endif 135 { 136 } 137 138 void operator()(); 139 140 private: 141 InterpreterState state; 142 Stack stack; 143 std::optional<at::ThreadLocalState> tls_state_ = std::nullopt; 144 #ifdef USE_DISTRIBUTED 145 int64_t dist_autograd_context_id_; 146 #endif 147 }; 148 149 // what is the tensors type, including state from the current execution context 150 // that modifies how the tensor behaves. For instance if no_grad is enabled 151 // this will cause the TensorType to have requires_grad=False. 152 TORCH_API at::TensorTypePtr tensorTypeInCurrentExecutionContext( 153 const at::Tensor& t); 154 155 // current (TLS) TorchScript interpreter callstack 156 TORCH_API std::vector<StackEntry> currentCallstack(); 157 TORCH_API std::vector<std::string> currentModuleHierarchy(); 158 159 } // namespace torch::jit 160