// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/codegen/fuser/interface.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/jit/codegen/fuser/interface.h>

#include <torch/csrc/jit/codegen/fuser/compiler.h>
#include <torch/csrc/jit/codegen/fuser/executor.h>
#include <torch/csrc/jit/codegen/fuser/fallback.h>
#include <torch/csrc/jit/codegen/fuser/kernel_cache.h>

#include <c10/util/Flags.h>

#include <stdexcept>
#include <string>
#include <vector>
11 namespace torch::jit {
12 
13 namespace detail {
14 
15 #ifdef TORCH_ENABLE_LLVM
16 bool cpu_fuser_enabled = true;
17 #else
18 bool cpu_fuser_enabled = false;
19 #endif
20 
21 // note: this doesn't necessarily enable NNC because NVFuser might override it
22 bool gpu_fuser_enabled = true;
23 
24 } // namespace detail
25 
registerFusion(const Node * fusion_group)26 int64_t registerFusion(const Node* fusion_group) {
27   return fuser::registerFusion(fusion_group);
28 }
29 
runFusion(const int64_t key,Stack & stack)30 void runFusion(const int64_t key, Stack& stack) {
31   const auto result = fuser::runFusion(key, stack);
32   if (!result)
33     fuser::runFallback(key, stack);
34 }
35 
canFuseOnCPU()36 bool canFuseOnCPU() {
37   return fuser::hasFusionBackend(DeviceType::CPU) && detail::cpu_fuser_enabled;
38 }
39 
canFuseOnGPU()40 bool canFuseOnGPU() {
41   return fuser::hasFusionBackend(DeviceType::CUDA) && detail::gpu_fuser_enabled;
42 }
43 
overrideCanFuseOnCPU(bool value)44 void overrideCanFuseOnCPU(bool value) {
45   detail::cpu_fuser_enabled = value;
46 }
47 
overrideCanFuseOnGPU(bool value)48 void overrideCanFuseOnGPU(bool value) {
49   detail::gpu_fuser_enabled = value;
50 }
51 
52 // Uses the above interface by stuffing the graph into a node and treating that
53 // node as a fusion group.
debugLaunchGraph(Graph & graph,at::ArrayRef<at::Tensor> inputs)54 std::vector<at::Tensor> debugLaunchGraph(
55     Graph& graph,
56     at::ArrayRef<at::Tensor> inputs) {
57   // Creates a fusion group node
58   auto wrapper_graph = std::make_shared<Graph>();
59   Node* fusion_group = wrapper_graph->insertNode(
60       wrapper_graph->createWithSubgraph(prim::FusionGroup));
61   fusion_group->g_(attr::Subgraph, graph.copy());
62   for (size_t i = 0; i < graph.inputs().size(); ++i) {
63     fusion_group->addInput(wrapper_graph->addInput());
64   }
65   for (size_t i = 0; i < graph.outputs().size(); ++i) {
66     wrapper_graph->registerOutput(fusion_group->addOutput());
67   }
68 
69   // Creates the stack, registers and runs the fusion
70   Stack stack = fmap<IValue>(inputs);
71   const auto key = fuser::registerFusion(fusion_group);
72   fuser::runFusion(key, stack);
73   return fmap(stack, [](const IValue& iv) { return iv.toTensor(); });
74 }
75 
debugGetFusedKernelCode(Graph & graph,at::ArrayRef<at::Tensor> inputs)76 std::string debugGetFusedKernelCode(
77     Graph& graph,
78     at::ArrayRef<at::Tensor> inputs) {
79   // Creates a fusion group node
80   auto wrapper_graph = std::make_shared<Graph>();
81   Node* fusion_group = wrapper_graph->insertNode(
82       wrapper_graph->createWithSubgraph(prim::FusionGroup));
83   fusion_group->g_(attr::Subgraph, graph.copy());
84   for (size_t i = 0; i < graph.inputs().size(); ++i) {
85     fusion_group->addInput(wrapper_graph->addInput());
86   }
87   for (size_t i = 0; i < graph.outputs().size(); ++i) {
88     wrapper_graph->registerOutput(fusion_group->addOutput());
89   }
90 
91   // Creates the stack, registers and runs the fusion
92   Stack stack = fmap<IValue>(inputs);
93   const auto key = fuser::registerFusion(fusion_group);
94 
95   std::string code;
96   if (!fuser::runFusion(key, stack, &code)) {
97     throw std::runtime_error("Could not run fusion for graph");
98   }
99 
100   return code;
101 }
102 
nCompiledKernels()103 size_t nCompiledKernels() {
104   return fuser::nCompiledKernels();
105 }
106 
107 } // namespace torch::jit
108