xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/codegen/cuda/interface.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/macros/Export.h>
4 #include <torch/csrc/jit/ir/ir.h>
5 #include <torch/csrc/jit/passes/pass_manager.h>
6 #include <torch/csrc/jit/runtime/profiling_record.h>
7 
/*
 * This file declares the APIs for the CUDA fuser.
 *
 * We use a static struct of function pointers, all of which start out null
 * and are registered separately. This is to support CPU-only compilation.
 * Registration is done in torch/csrc/jit/codegen/cuda/register_interface.cpp
 */
15 
16 namespace torch {
17 namespace jit {
18 namespace fuser {
19 namespace cuda {
20 
21 TORCH_API std::atomic<bool>& getCudaFusionGuardMode();
22 
23 TORCH_API bool getSingletonFusion();
24 TORCH_API bool setSingletonFusion(bool value);
25 TORCH_API bool getHorizontalFusion();
26 TORCH_API bool setHorizontalFusion(bool value);
27 
28 // dummy struct to allow API registration
29 struct CudaFuserInterface {
30   void (*fn_compile_n)(Node*) = nullptr;
31   void (*fn_run_n_s)(const Node*, Stack&) = nullptr;
32   void (*fn_fuse_graph)(std::shared_ptr<Graph>&) = nullptr;
33   bool (*fn_can_fuse_n)(const Node*) = nullptr;
34   void (*fn_insert_profile_inodes)(ProfilingRecord* pr) = nullptr;
35   bool (*fn_profile_n)(const Node*) = nullptr;
36   bool (*fn_skip_n)(const std::string&, bool flip) = nullptr;
37 };
38 
// Get the interface; it is used by registration and, internally, by the
// user-facing API.
40 TORCH_API CudaFuserInterface* getFuserInterface();
41 
42 TORCH_API void compileFusionGroup(Node* fusion_node);
43 TORCH_API void runFusionGroup(const Node* fusion_node, Stack& stack);
44 TORCH_API void fuseGraph(std::shared_ptr<Graph>&);
45 TORCH_API bool canFuseNode(const Node* node);
46 TORCH_API void InsertProfileNodesForCUDAFuser(ProfilingRecord* pr);
47 TORCH_API bool profileNode(const Node* node);
48 
49 TORCH_API bool skipNode(const std::string& symbol_str, bool flip = true);
50 
51 TORCH_API bool isEnabled();
52 TORCH_API bool setEnabled(bool is_enabled);
53 TORCH_API bool canBeEnabled();
54 
55 } // namespace cuda
56 } // namespace fuser
57 } // namespace jit
58 } // namespace torch
59