1 #pragma once 2 3 #include <torch/csrc/jit/ir/ir.h> 4 #include <torch/csrc/utils/object_ptr.h> 5 6 namespace torch::jit { 7 8 void initPythonIRBindings(PyObject* module); 9 10 // execute a Python function, used for Ops we can't optimize but that we want to 11 // optimize around 12 struct ConcretePythonOp : public PythonOp { 13 static Symbol Kind; 14 ConcretePythonOpConcretePythonOp15 ConcretePythonOp(Graph* graph) : PythonOp(graph, ::c10::prim::PythonOp) {} initConcretePythonOp16 ConcretePythonOp* init( 17 THPObjectPtr&& pyobj, 18 const std::string& cconv, 19 pyobj_list&& scalar_args) { 20 this->pyobj = std::move(pyobj); 21 this->scalar_args = std::move(scalar_args); 22 this->cconv = cconv; 23 return this; 24 } 25 // The Python object which contains the implementation of this function. 26 // This is either a class (non-legacy) or an object (legacy). See 27 // TraceInterpreterState for execution semantics. 28 THPObjectPtr pyobj; 29 // The calling convention for the Python function. 30 // 'c' -- constant argument 31 // 'd' -- dynamic argument 32 std::string cconv; 33 // Scalar arguments to the Python function. Not necessarily passed to 34 // the function in this order; see cconv for the correct order. 35 std::vector<THPObjectPtr> scalar_args; 36 37 std::string name() const override; 38 void cloneFrom(Node* other_) override; allocNewInstanceConcretePythonOp39 Node* allocNewInstance(Graph* g) override { 40 return new ConcretePythonOp(g); 41 } 42 // recover the autograd.Function instance, if this PythonOp's function 43 // was originally SomeFunction.apply 44 // used in ONNX for discovering symbolics 45 std::optional<THPObjectPtr> autogradFunction() const override; 46 void writeScalars(std::ostream& out) const override; 47 void lint_python() const override; 48 }; 49 50 } // namespace torch::jit 51