#include <torch/csrc/jit/codegen/fuser/interface.h>

#include <torch/csrc/jit/codegen/fuser/compiler.h>
#include <torch/csrc/jit/codegen/fuser/executor.h>
#include <torch/csrc/jit/codegen/fuser/fallback.h>
#include <torch/csrc/jit/codegen/fuser/kernel_cache.h>

#include <c10/util/Flags.h>
#include <stdexcept>

namespace torch::jit {

namespace detail {

#ifdef TORCH_ENABLE_LLVM
bool cpu_fuser_enabled = true;
#else
bool cpu_fuser_enabled = false;
#endif

// note: this doesn't necessarily enable NNC because NVFuser might override it
bool gpu_fuser_enabled = true;

} // namespace detail

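// A sketch of how these flags are consumed: the TORCH_ENABLE_LLVM gate above
// means CPU fusion defaults to on only when LLVM codegen is compiled in, and
// callers toggle the flags through the override functions defined below
// rather than writing to detail:: directly, e.g.
//
//   torch::jit::overrideCanFuseOnCPU(true); // sets detail::cpu_fuser_enabled
//   bool ok = torch::jit::canFuseOnCPU();   // also requires a registered CPU backend
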
int64_t registerFusion(const Node* fusion_group) {
  return fuser::registerFusion(fusion_group);
}

void runFusion(const int64_t key, Stack& stack) {
  const auto result = fuser::runFusion(key, stack);
  if (!result)
    fuser::runFallback(key, stack);
}

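// Example register-and-run round trip (a sketch; `fusion_group_node` is a
// hypothetical prim::FusionGroup node and `input_tensors` hypothetical inputs):
//
//   const int64_t key = registerFusion(fusion_group_node);
//   Stack stack = fmap<IValue>(input_tensors);
//   runFusion(key, stack); // silently runs the interpreter fallback on failure
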
bool canFuseOnCPU() {
  return fuser::hasFusionBackend(DeviceType::CPU) && detail::cpu_fuser_enabled;
}

bool canFuseOnGPU() {
  return fuser::hasFusionBackend(DeviceType::CUDA) && detail::gpu_fuser_enabled;
}

void overrideCanFuseOnCPU(bool value) {
  detail::cpu_fuser_enabled = value;
}

void overrideCanFuseOnGPU(bool value) {
  detail::gpu_fuser_enabled = value;
}

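// One way callers can scope an override (a sketch): save the current value,
// force CPU fusion on for the code under test, then restore.
//
//   const bool prev = canFuseOnCPU();
//   overrideCanFuseOnCPU(true);
//   // ... run the graph under test ...
//   overrideCanFuseOnCPU(prev);
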
// Uses the above interface by stuffing the graph into a node and treating that
// node as a fusion group.
std::vector<at::Tensor> debugLaunchGraph(
    Graph& graph,
    at::ArrayRef<at::Tensor> inputs) {
  // Creates a fusion group node
  auto wrapper_graph = std::make_shared<Graph>();
  Node* fusion_group = wrapper_graph->insertNode(
      wrapper_graph->createWithSubgraph(prim::FusionGroup));
  fusion_group->g_(attr::Subgraph, graph.copy());
  for (size_t i = 0; i < graph.inputs().size(); ++i) {
    fusion_group->addInput(wrapper_graph->addInput());
  }
  for (size_t i = 0; i < graph.outputs().size(); ++i) {
    wrapper_graph->registerOutput(fusion_group->addOutput());
  }

  // Creates the stack, registers and runs the fusion
  Stack stack = fmap<IValue>(inputs);
  const auto key = fuser::registerFusion(fusion_group);
  fuser::runFusion(key, stack);
  return fmap(stack, [](const IValue& iv) { return iv.toTensor(); });
}

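// Example (a sketch; assumes `g` is a fusible pointwise graph and `a`, `b`
// are tensors matching its two inputs):
//
//   std::vector<at::Tensor> outputs = debugLaunchGraph(*g, {a, b});
//   // outputs[i] corresponds to g->outputs()[i]
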
std::string debugGetFusedKernelCode(
    Graph& graph,
    at::ArrayRef<at::Tensor> inputs) {
  // Creates a fusion group node
  auto wrapper_graph = std::make_shared<Graph>();
  Node* fusion_group = wrapper_graph->insertNode(
      wrapper_graph->createWithSubgraph(prim::FusionGroup));
  fusion_group->g_(attr::Subgraph, graph.copy());
  for (size_t i = 0; i < graph.inputs().size(); ++i) {
    fusion_group->addInput(wrapper_graph->addInput());
  }
  for (size_t i = 0; i < graph.outputs().size(); ++i) {
    wrapper_graph->registerOutput(fusion_group->addOutput());
  }

  // Creates the stack, registers and runs the fusion
  Stack stack = fmap<IValue>(inputs);
  const auto key = fuser::registerFusion(fusion_group);

  std::string code;
  if (!fuser::runFusion(key, stack, &code)) {
    throw std::runtime_error("Could not run fusion for graph");
  }

  return code;
}

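// Example (a sketch; same assumptions as for debugLaunchGraph above). Unlike
// debugLaunchGraph, this throws if fusion fails rather than falling back:
//
//   std::string src = debugGetFusedKernelCode(*g, {a, b});
//   // inspect `src` to see the generated kernel source
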
size_t nCompiledKernels() {
  return fuser::nCompiledKernels();
}

} // namespace torch::jit