xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/codegen/cuda/interface.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/codegen/cuda/interface.h>
2 
3 #include <ATen/DynamicLibrary.h>
4 #include <ATen/core/dispatch/OperatorOptions.h>
5 #include <ATen/native/NonSymbolicBC.h>
6 #include <ATen/native/TensorShape.h>
7 #include <c10/util/CallOnce.h>
8 #include <c10/util/irange.h>
9 #include <torch/csrc/jit/runtime/custom_operator.h>
10 #include <torch/csrc/jit/runtime/register_ops_utils.h>
11 
12 namespace torch {
13 namespace jit {
14 namespace fuser {
15 namespace cuda {
16 
// File-local flag handed out by reference from getCudaFusionGuardMode();
// it starts out as true.
static std::atomic<bool> cuda_fusion_guard_mode(true);
18 
isEnabled()19 bool isEnabled() {
20   TORCH_WARN_ONCE("torch::jit::fuser::cuda::isEnabled() is deprecated");
21   return false;
22 }
23 
setEnabled(bool is_enabled)24 bool setEnabled(bool is_enabled) {
25   TORCH_WARN_ONCE("torch::jit::fuser::cuda::setEnabled() is deprecated");
26   TORCH_INTERNAL_ASSERT(
27       !is_enabled,
28       "nvfuser support in torchscript is removed and cannot be enabled!");
29   return false;
30 }
31 
canBeEnabled()32 bool canBeEnabled() {
33   TORCH_WARN_ONCE(
34       "torch::jit::fuser::cuda::nvfuserCanBeEnabled() is deprecated");
35   return false;
36 }
37 
getSingletonFusion()38 bool getSingletonFusion() {
39   TORCH_WARN_ONCE(
40       "torch::jit::fuser::cuda::getSingletonFusion() is deprecated");
41   return false;
42 }
43 
setSingletonFusion(bool value)44 bool setSingletonFusion(bool value) {
45   TORCH_WARN_ONCE(
46       "torch::jit::fuser::cuda::setSingletonFusion() is deprecated");
47   TORCH_INTERNAL_ASSERT(
48       !value,
49       "nvfuser support in torchscript is removed and singleton fusion cannot be enabled!");
50   return false;
51 }
52 
getHorizontalFusion()53 bool getHorizontalFusion() {
54   TORCH_WARN_ONCE(
55       "torch::jit::fuser::cuda::getHorizontalFusion() is deprecated");
56   return false;
57 }
58 
setHorizontalFusion(bool value)59 bool setHorizontalFusion(bool value) {
60   TORCH_WARN_ONCE(
61       "torch::jit::fuser::cuda::setHorizontalFusion() is deprecated");
62   TORCH_INTERNAL_ASSERT(
63       !value,
64       "nvfuser support in torchscript is removed and horizontal fusion cannot be enabled!");
65   return false;
66 }
67 
getCudaFusionGuardMode()68 std::atomic<bool>& getCudaFusionGuardMode() {
69   TORCH_WARN_ONCE(
70       "torch::jit::fuser::cuda::getCudaFusionGuardMode() is deprecated");
71   return cuda_fusion_guard_mode;
72 }
73 
getFuserInterface()74 CudaFuserInterface* getFuserInterface() {
75   static CudaFuserInterface fuser_interface_;
76   return &fuser_interface_;
77 }
78 
compileFusionGroup(Node * fusion_node)79 void compileFusionGroup(Node* fusion_node) {
80   TORCH_WARN_ONCE(
81       "torch::jit::fuser::cuda::compileFusionGroup() is deprecated");
82   TORCH_CHECK(
83       getFuserInterface()->fn_compile_n != nullptr,
84       "Running the CUDA fuser requires a CUDA build.");
85   getFuserInterface()->fn_compile_n(fusion_node);
86 }
87 
runFusionGroup(const Node * fusion_node,Stack & stack)88 void runFusionGroup(const Node* fusion_node, Stack& stack) {
89   TORCH_WARN_ONCE("torch::jit::fuser::cuda::runFusionGroup() is deprecated");
90   TORCH_CHECK(
91       getFuserInterface()->fn_run_n_s != nullptr,
92       "Running the CUDA fuser requires a CUDA build.");
93   getFuserInterface()->fn_run_n_s(fusion_node, stack);
94 }
95 
fuseGraph(std::shared_ptr<Graph> & graph)96 void fuseGraph(std::shared_ptr<Graph>& graph) {
97   if (!isEnabled()) {
98     return;
99   }
100 
101   TORCH_WARN_ONCE("nvfuser integration in TorchScript is deprecated.");
102   TORCH_CHECK(
103       getFuserInterface()->fn_fuse_graph != nullptr,
104       "Running the CUDA fuser requires a CUDA build.");
105   getFuserInterface()->fn_fuse_graph(graph);
106 }
107 
canFuseNode(const Node * node)108 bool canFuseNode(const Node* node) {
109   TORCH_WARN_ONCE("torch::jit::fuser::cuda::canFuseNode() is deprecated");
110   return getFuserInterface()->fn_can_fuse_n != nullptr &&
111       getFuserInterface()->fn_can_fuse_n(node);
112 }
113 
InsertProfileNodesForCUDAFuser(ProfilingRecord * pr)114 void InsertProfileNodesForCUDAFuser(ProfilingRecord* pr) {
115   TORCH_WARN_ONCE(
116       "torch::jit::fuser::cuda::InsertProfileNodesForCUDAFuser() is deprecated");
117   if (getFuserInterface()->fn_insert_profile_inodes) {
118     getFuserInterface()->fn_insert_profile_inodes(pr);
119   }
120 }
121 
profileNode(const Node * node)122 bool profileNode(const Node* node) {
123   TORCH_WARN_ONCE("torch::jit::fuser::cuda::profileNode() is deprecated");
124   return getFuserInterface()->fn_profile_n != nullptr &&
125       getFuserInterface()->fn_profile_n(node);
126 }
127 
skipNode(const std::string & symbol_str,bool flip)128 bool skipNode(const std::string& symbol_str, bool flip) {
129   TORCH_WARN_ONCE("torch::jit::fuser::cuda::skipNode() is deprecated");
130   return getFuserInterface()->fn_skip_n != nullptr &&
131       getFuserInterface()->fn_skip_n(symbol_str, flip);
132 }
133 
134 } // namespace cuda
135 } // namespace fuser
136 } // namespace jit
137 } // namespace torch
138