1 #pragma once 2 3 #include <vector> 4 5 #include <ATen/core/function.h> 6 #include <ATen/core/function_schema.h> 7 #include <ATen/core/ivalue.h> 8 #include <torch/csrc/jit/mobile/code.h> 9 10 namespace torch::jit { 11 enum OpCode : uint8_t; 12 struct Instruction; 13 struct OperatorString; 14 15 namespace mobile { 16 17 class TORCH_API Function : public torch::jit::Function { 18 public: 19 explicit Function(c10::QualifiedName name); 20 Function( 21 c10::QualifiedName name, 22 Code code, 23 std::optional<c10::FunctionSchema> schema); 24 void run(Stack& stack) override; 25 at::IValue operator()(Stack& stack); ensure_defined()26 void ensure_defined() override {} 27 size_t num_inputs() const override; 28 const c10::QualifiedName& qualname() const override; 29 bool call(Stack&, c10::function_ref<void(const mobile::Code&)>) override; 30 31 // NOTE: the APIs below is dangerous: if you call append_instruction with 32 // dbg_handle and then call it without; then the dbg_handle will become 33 // misaligned. Therefore only use ONE variant at time. 34 void append_instruction(OpCode op, int64_t X, int64_t N, int64_t dbg_handle); 35 void append_instruction(OpCode op, int64_t X, int64_t N); 36 void append_operator( 37 const std::string& name, 38 const std::string& overload_name, 39 const std::optional<int>& num_specified_args); 40 void append_constant(const c10::IValue& constant); 41 void append_type(const c10::TypePtr& type); 42 void append_function(mobile::Function& func); 43 44 void set_register_size(size_t size); 45 46 int64_t get_debug_handle(size_t pc) const; 47 const Code& get_code() const; 48 Code& get_code(); 49 50 torch::jit::Function& setSchema(c10::FunctionSchema schema) override; 51 bool hasSchema() const; 52 const c10::FunctionSchema& getSchema() const override; 53 54 // Returns the debug handle corresponding to where the execution 55 // is halted due to exception. 56 // If no corresponding debug handle is found then -1 is returned. 57 const std::vector<int64_t>& getExceptionDebugHandles() const; 58 static Function& registerFunc( 59 const std::string& qualified_name, 60 const std::vector<Instruction>& instructions, 61 const std::vector<c10::IValue>& constants, 62 const std::vector<c10::TypePtr>& types, 63 const size_t register_size); 64 65 // if not initialize, initialize by loading operators. 66 // return true of all op loaded, return false if some op is not found 67 // in the current runtime. Then, the ops that did not found will be filled 68 // in unsupported_op_names 69 bool initialize_operators(bool should_check_operators); 70 71 private: 72 c10::QualifiedName name_; 73 Code code_; 74 std::optional<c10::FunctionSchema> schema_; // (byte-code version 4+) 75 }; 76 77 std::optional<std::function<void(Stack&)>> makeOperatorFunction( 78 const c10::OperatorName& opname, 79 std::optional<int> num_specified_args); 80 81 TORCH_API std::string operator_str(const c10::OperatorName& opname); 82 83 } // namespace mobile 84 } // namespace torch::jit 85