#pragma once

#include <ATen/ATen.h>
#include <ATen/core/stack.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/ir/ir.h>

#include <cstdint>
#include <memory>
#include <vector>

namespace torch::jit {

// Sentinel device index denoting the CPU (accelerator device indices are
// non-negative; see debugLaunchGraph below, which takes a device).
constexpr int kCPUDevice = -1;

// Assigns a "key" to the given fusion_group that it can use to run its
// fusion later (via runFusion() below).
TORCH_API int64_t registerFusion(const Node* fusion_group);

// Runs the fusion corresponding to the given key on the inputs
// found on the stack. Outputs are placed on the same stack.
// In some cases a fusion cannot be run and a fallback path where
// PyTorch's interpreter runs the graph instead is attempted.
TORCH_API void runFusion(const int64_t key, Stack& stack);

// True if the respective devices can fuse, false otherwise
TORCH_API bool canFuseOnCPU();
TORCH_API bool canFuseOnGPU();

// Sets whether fusion on the CPU is allowed (disabled by default due to
// flakiness)
TORCH_API void overrideCanFuseOnCPU(bool value);

// Sets whether fusion on CPU must use LLVM Codegen and not SimpleIREval
TORCH_API void overrideMustUseLLVMOnCPU(bool value);

// Sets whether fusion on the GPU is allowed (enabled by default)
TORCH_API void overrideCanFuseOnGPU(bool value);

// Treats the given graph as a fusion group and launches it on the
// specified device with the given inputs.
// Returns the outputs.
TORCH_API std::vector<at::Tensor> debugLaunchGraph(
    Graph& graph,
    at::ArrayRef<at::Tensor> inputs);

// Treats the given graph as a fusion group and returns the generated code.
TORCH_API std::string debugGetFusedKernelCode(
    Graph& graph,
    at::ArrayRef<at::Tensor> inputs);

// Returns the number of kernels compiled so far — presumably a counter used
// by tests/debugging to verify that fusion actually occurred (TODO confirm
// against the implementation).
TORCH_API size_t nCompiledKernels();

} // namespace torch::jit