xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/api/method.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/function.h>
4 #include <ATen/core/ivalue.h>
5 #include <ATen/core/stack.h>
6 #include <torch/csrc/api/include/torch/imethod.h>
7 #include <torch/csrc/jit/api/function_impl.h>
8 
9 namespace torch::jit {
10 
11 using ObjectPtr = c10::intrusive_ptr<c10::ivalue::Object>;
12 
13 // A method in a module, e.g. f in:
14 //
15 // class M(ScriptModule):
16 //   @script_method
17 //   def f(self, x):
18 //     ...
19 // Note: because Method/Module are exposed to python these
20 // classes use python method naming conventions
21 struct TORCH_API Method : public torch::IMethod {
22   Method(ObjectPtr owner, Function* function);
23 
24   // the module that contains this method.
25   Module owner() const;
26   // the raw objectptr that owns this method, for when the method is owned by a
27   // torchbind object.
28   ObjectPtr raw_owner() const;
29   void run(Stack& stack);
runMethod30   void run(Stack&& stack) {
31     run(stack);
32   }
33 
34   c10::IValue operator()(
35       std::vector<c10::IValue> stack,
36       const Kwargs& kwargs = Kwargs()) const override;
37 
38   // Run method async. Invocation on this function would invokes a JIT
39   // interpreter that executes ops inline, one by one, on caller's thread. A
40   // model can utilize async op, i.e. `fork`, to launch an asynchronous task
41   // which will be launched on provided `taskLauncher`.
42   c10::intrusive_ptr<c10::ivalue::Future> run_async(
43       std::vector<c10::IValue> stack,
44       const Kwargs& kwargs = Kwargs(),
45       TaskLauncher taskLauncher = at::launch);
46 
graphMethod47   std::shared_ptr<Graph> graph() const {
48     return toGraphFunction(*function_).graph();
49   }
50 
nameMethod51   const std::string& name() const override {
52     return function_->name();
53   }
54 
num_inputsMethod55   size_t num_inputs() const {
56     return function_->num_inputs();
57   }
58 
get_executorMethod59   GraphExecutor& get_executor() {
60     return toGraphFunction(*function_).get_executor();
61   }
62 
functionMethod63   Function& function() const {
64     return *function_;
65   }
66 
67  private:
68   void setArgumentNames(std::vector<std::string>&) const override;
69 
70   // Methods are uniqued onwed by a single module. This raw pointer allows
71   // looking up the module.
72   ObjectPtr owner_;
73 
74   // Underlying unbound function
75   Function* function_;
76 };
77 
78 namespace script {
79 // We once had a `script::` namespace that was deleted. This is for backcompat
80 // of the public API; new code should not use this type alias.
81 using Method = ::torch::jit::Method;
82 } // namespace script
83 
84 } // namespace torch::jit
85