xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/function.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <vector>
4 
5 #include <ATen/core/function.h>
6 #include <ATen/core/function_schema.h>
7 #include <ATen/core/ivalue.h>
8 #include <torch/csrc/jit/mobile/code.h>
9 
10 namespace torch::jit {
// Forward declarations: these types are only named in signatures in this
// header, so their full definitions are not required here.
enum OpCode : uint8_t; // opaque enum declaration with a fixed underlying type
struct Instruction;
struct OperatorString;
14 
15 namespace mobile {
16 
17 class TORCH_API Function : public torch::jit::Function {
18  public:
19   explicit Function(c10::QualifiedName name);
20   Function(
21       c10::QualifiedName name,
22       Code code,
23       std::optional<c10::FunctionSchema> schema);
24   void run(Stack& stack) override;
25   at::IValue operator()(Stack& stack);
ensure_defined()26   void ensure_defined() override {}
27   size_t num_inputs() const override;
28   const c10::QualifiedName& qualname() const override;
29   bool call(Stack&, c10::function_ref<void(const mobile::Code&)>) override;
30 
31   // NOTE: the APIs below is dangerous: if you call append_instruction with
32   // dbg_handle and then call it without; then the dbg_handle will become
33   // misaligned. Therefore only use ONE variant at time.
34   void append_instruction(OpCode op, int64_t X, int64_t N, int64_t dbg_handle);
35   void append_instruction(OpCode op, int64_t X, int64_t N);
36   void append_operator(
37       const std::string& name,
38       const std::string& overload_name,
39       const std::optional<int>& num_specified_args);
40   void append_constant(const c10::IValue& constant);
41   void append_type(const c10::TypePtr& type);
42   void append_function(mobile::Function& func);
43 
44   void set_register_size(size_t size);
45 
46   int64_t get_debug_handle(size_t pc) const;
47   const Code& get_code() const;
48   Code& get_code();
49 
50   torch::jit::Function& setSchema(c10::FunctionSchema schema) override;
51   bool hasSchema() const;
52   const c10::FunctionSchema& getSchema() const override;
53 
54   // Returns the debug handle corresponding to where the execution
55   // is halted due to exception.
56   // If no corresponding debug handle is found then -1 is returned.
57   const std::vector<int64_t>& getExceptionDebugHandles() const;
58   static Function& registerFunc(
59       const std::string& qualified_name,
60       const std::vector<Instruction>& instructions,
61       const std::vector<c10::IValue>& constants,
62       const std::vector<c10::TypePtr>& types,
63       const size_t register_size);
64 
65   // if not initialize, initialize by loading operators.
66   // return true of all op loaded, return false if some op is not found
67   // in the current runtime. Then, the ops that did not found will be filled
68   // in unsupported_op_names
69   bool initialize_operators(bool should_check_operators);
70 
71  private:
72   c10::QualifiedName name_;
73   Code code_;
74   std::optional<c10::FunctionSchema> schema_; // (byte-code version 4+)
75 };
76 
77 std::optional<std::function<void(Stack&)>> makeOperatorFunction(
78     const c10::OperatorName& opname,
79     std::optional<int> num_specified_args);
80 
81 TORCH_API std::string operator_str(const c10::OperatorName& opname);
82 
83 } // namespace mobile
84 } // namespace torch::jit
85