1 #pragma once 2 3 #include <c10/macros/Export.h> 4 #include <torch/csrc/jit/ir/ir.h> 5 #include <torch/csrc/jit/passes/pass_manager.h> 6 #include <torch/csrc/jit/runtime/profiling_record.h> 7 8 /* 9 * This file contains APIs for cuda fuser; 10 * 11 * We use an empty static struct to hold the function pointers, which are 12 * registered separately. This is to support cpu-only compilation. 13 * Registration is done in torch/csrc/jit/codegen/cuda/register_interface.cpp 14 */ 15 16 namespace torch { 17 namespace jit { 18 namespace fuser { 19 namespace cuda { 20 21 TORCH_API std::atomic<bool>& getCudaFusionGuardMode(); 22 23 TORCH_API bool getSingletonFusion(); 24 TORCH_API bool setSingletonFusion(bool value); 25 TORCH_API bool getHorizontalFusion(); 26 TORCH_API bool setHorizontalFusion(bool value); 27 28 // dummy struct to allow API registration 29 struct CudaFuserInterface { 30 void (*fn_compile_n)(Node*) = nullptr; 31 void (*fn_run_n_s)(const Node*, Stack&) = nullptr; 32 void (*fn_fuse_graph)(std::shared_ptr<Graph>&) = nullptr; 33 bool (*fn_can_fuse_n)(const Node*) = nullptr; 34 void (*fn_insert_profile_inodes)(ProfilingRecord* pr) = nullptr; 35 bool (*fn_profile_n)(const Node*) = nullptr; 36 bool (*fn_skip_n)(const std::string&, bool flip) = nullptr; 37 }; 38 39 // Get interface, this is used by registration and user facing API internally 40 TORCH_API CudaFuserInterface* getFuserInterface(); 41 42 TORCH_API void compileFusionGroup(Node* fusion_node); 43 TORCH_API void runFusionGroup(const Node* fusion_node, Stack& stack); 44 TORCH_API void fuseGraph(std::shared_ptr<Graph>&); 45 TORCH_API bool canFuseNode(const Node* node); 46 TORCH_API void InsertProfileNodesForCUDAFuser(ProfilingRecord* pr); 47 TORCH_API bool profileNode(const Node* node); 48 49 TORCH_API bool skipNode(const std::string& symbol_str, bool flip = true); 50 51 TORCH_API bool isEnabled(); 52 TORCH_API bool setEnabled(bool is_enabled); 53 TORCH_API bool canBeEnabled(); 54 55 } // namespace cuda 56 } // namespace fuser 57 } // namespace jit 58 } // namespace torch 59