1 #pragma once 2 // This file is temporary until native_functions.yaml and derivatives.yaml are 3 // merged. Ideally this should all go into native_functions.yaml 4 5 #include <torch/csrc/Export.h> 6 #include <torch/csrc/jit/ir/ir.h> 7 8 namespace torch::jit { 9 10 TORCH_API std::optional<std::shared_ptr<Graph>> GetDecomposition( 11 const FunctionSchema& schema); 12 13 TORCH_API void RegisterDecomposition( 14 const FunctionSchema& schema, 15 std::shared_ptr<Graph> g); 16 17 TORCH_API void RunDecompositions(std::shared_ptr<Graph> g); 18 19 TORCH_API std::optional<GraphFunction*> GetDecompositionFunction( 20 const FunctionSchema& schema); 21 22 // For invocation in C++, recommended is to assign to static local variable 23 TORCH_API Function* GetDecompositionExecutor(const char* schema_literal); 24 25 TORCH_API Function* GetDecompositionExecutor(const FunctionSchema& schema); 26 27 TORCH_API void run_jit_decomposition( 28 const c10::OperatorHandle& op, 29 torch::jit::Stack* stack); 30 31 TORCH_API bool has_jit_decomposition(const FunctionSchema& schema); 32 33 } // namespace torch::jit 34