xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/codegen/fuser/interface.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/ATen.h>
4 #include <ATen/core/stack.h>
5 #include <torch/csrc/Export.h>
6 #include <torch/csrc/jit/ir/ir.h>
7 
8 #include <cstdint>
9 #include <memory>
10 #include <vector>
11 
12 namespace torch::jit {
13 
14 constexpr int kCPUDevice = -1; // NOTE(review): presumably the device index that denotes the CPU; confirm against callers
15 
16 // Assigns a "key" to the given fusion_group and returns it. The fusion group
17 // can later use that key to run its fusion (via runFusion() below).
18 TORCH_API int64_t registerFusion(const Node* fusion_group);
19 
20 // Runs the fusion corresponding to the given key (see registerFusion() above)
21 // on the inputs found on the stack. Outputs are placed on the same stack.
22 // In some cases a fusion cannot be run; a fallback path in which PyTorch's
23 // interpreter runs the graph instead is then attempted.
24 TORCH_API void runFusion(const int64_t key, Stack& stack);
25 
26 // Each returns true if the respective device (CPU / GPU) can fuse, false otherwise.
27 TORCH_API bool canFuseOnCPU();
28 TORCH_API bool canFuseOnGPU();
29 
30 // Sets whether fusion on the CPU is allowed. CPU fusion is disabled by
31 // default due to flakiness.
32 TORCH_API void overrideCanFuseOnCPU(bool value);
33 
34 // Sets whether fusion on the CPU must use LLVM codegen and not SimpleIREval.
35 TORCH_API void overrideMustUseLLVMOnCPU(bool value);
36 
37 // Sets whether fusion on the GPU is allowed. GPU fusion is enabled by default.
38 TORCH_API void overrideCanFuseOnGPU(bool value);
39 
40 // Treats the given graph as a fusion group and launches it with the given
41 // inputs. NOTE(review): no device parameter is declared here; presumably the
42 // device is derived from the inputs -- TODO confirm. Returns the outputs.
43 TORCH_API std::vector<at::Tensor> debugLaunchGraph(
44     Graph& graph,
45     at::ArrayRef<at::Tensor> inputs);
46 
47 // Treats the given graph as a fusion group and returns the generated code as a string.
48 TORCH_API std::string debugGetFusedKernelCode(
49     Graph& graph,
50     at::ArrayRef<at::Tensor> inputs);
51 
52 TORCH_API size_t nCompiledKernels(); // NOTE(review): presumably the number of kernels compiled so far; confirm against the implementation
53 
54 } // namespace torch::jit
55