1 #pragma once 2 3 #include <torch/csrc/jit/ir/ir.h> 4 5 namespace torch::jit { 6 7 TORCH_API bool canFuseOnCPULegacy(); 8 TORCH_API void overrideCanFuseOnCPULegacy(bool value); 9 10 // NB: Be sure to run DCE before fusion, because dead instructions 11 // can prevent fusion opportunities from being exploited. 12 // On Windows will noop, NYI 13 TORCH_API void FuseGraph( 14 std::shared_ptr<Graph>& graph, 15 bool strict_fuser_check = false); 16 17 // \brief Custom fusion pass using a node-level callback to 18 // determine the inclusion of nodes in a subgraph. 19 // 20 // This helper omits aliased inputs and fusion across control flow 21 // boundaries. 22 // 23 // \arg graph The graph to be modified in-place 24 // \arg is_fusable A callback run on each fusable node in the graph. 25 // \arg kind The label given to the resultant fused subgraph 26 // \arg arg_limit The maximum number of args the resultant fused subgraph 27 // should have. Note: This will likely develop into a general 28 // post condition on the fused subgraph. 29 TORCH_API void CustomFuseGraph( 30 std::shared_ptr<Graph>& graph, 31 const std::function<bool(Node*)>& is_fusable, 32 Symbol kind, 33 size_t arg_limit = std::numeric_limits<size_t>::max()); 34 35 } // namespace torch::jit 36