// Source: /aosp_15_r20/external/pytorch/torch/csrc/jit/python/python_ir.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/jit/ir/ir.h>
4 #include <torch/csrc/utils/object_ptr.h>
5 
6 namespace torch::jit {
7 
8 void initPythonIRBindings(PyObject* module);
9 
10 // execute a Python function, used for Ops we can't optimize but that we want to
11 // optimize around
12 struct ConcretePythonOp : public PythonOp {
13   static Symbol Kind;
14 
ConcretePythonOpConcretePythonOp15   ConcretePythonOp(Graph* graph) : PythonOp(graph, ::c10::prim::PythonOp) {}
initConcretePythonOp16   ConcretePythonOp* init(
17       THPObjectPtr&& pyobj,
18       const std::string& cconv,
19       pyobj_list&& scalar_args) {
20     this->pyobj = std::move(pyobj);
21     this->scalar_args = std::move(scalar_args);
22     this->cconv = cconv;
23     return this;
24   }
25   // The Python object which contains the implementation of this function.
26   // This is either a class (non-legacy) or an object (legacy).  See
27   // TraceInterpreterState for execution semantics.
28   THPObjectPtr pyobj;
29   // The calling convention for the Python function.
30   // 'c' -- constant argument
31   // 'd' -- dynamic argument
32   std::string cconv;
33   // Scalar arguments to the Python function.  Not necessarily passed to
34   // the function in this order; see cconv for the correct order.
35   std::vector<THPObjectPtr> scalar_args;
36 
37   std::string name() const override;
38   void cloneFrom(Node* other_) override;
allocNewInstanceConcretePythonOp39   Node* allocNewInstance(Graph* g) override {
40     return new ConcretePythonOp(g);
41   }
42   // recover the autograd.Function instance, if this PythonOp's function
43   // was originally SomeFunction.apply
44   // used in ONNX for discovering symbolics
45   std::optional<THPObjectPtr> autogradFunction() const override;
46   void writeScalars(std::ostream& out) const override;
47   void lint_python() const override;
48 };
49 
50 } // namespace torch::jit
51