1 #include <torch/csrc/jit/codegen/cuda/interface.h>
2
3 #include <ATen/DynamicLibrary.h>
4 #include <ATen/core/dispatch/OperatorOptions.h>
5 #include <ATen/native/NonSymbolicBC.h>
6 #include <ATen/native/TensorShape.h>
7 #include <c10/util/CallOnce.h>
8 #include <c10/util/irange.h>
9 #include <torch/csrc/jit/runtime/custom_operator.h>
10 #include <torch/csrc/jit/runtime/register_ops_utils.h>
11
12 namespace torch {
13 namespace jit {
14 namespace fuser {
15 namespace cuda {
16
// Backing flag whose reference is handed out by getCudaFusionGuardMode().
// It starts out true; nothing else in this file reads or modifies it.
static std::atomic<bool> cuda_fusion_guard_mode(true);
18
isEnabled()19 bool isEnabled() {
20 TORCH_WARN_ONCE("torch::jit::fuser::cuda::isEnabled() is deprecated");
21 return false;
22 }
23
setEnabled(bool is_enabled)24 bool setEnabled(bool is_enabled) {
25 TORCH_WARN_ONCE("torch::jit::fuser::cuda::setEnabled() is deprecated");
26 TORCH_INTERNAL_ASSERT(
27 !is_enabled,
28 "nvfuser support in torchscript is removed and cannot be enabled!");
29 return false;
30 }
31
canBeEnabled()32 bool canBeEnabled() {
33 TORCH_WARN_ONCE(
34 "torch::jit::fuser::cuda::nvfuserCanBeEnabled() is deprecated");
35 return false;
36 }
37
getSingletonFusion()38 bool getSingletonFusion() {
39 TORCH_WARN_ONCE(
40 "torch::jit::fuser::cuda::getSingletonFusion() is deprecated");
41 return false;
42 }
43
setSingletonFusion(bool value)44 bool setSingletonFusion(bool value) {
45 TORCH_WARN_ONCE(
46 "torch::jit::fuser::cuda::setSingletonFusion() is deprecated");
47 TORCH_INTERNAL_ASSERT(
48 !value,
49 "nvfuser support in torchscript is removed and singleton fusion cannot be enabled!");
50 return false;
51 }
52
getHorizontalFusion()53 bool getHorizontalFusion() {
54 TORCH_WARN_ONCE(
55 "torch::jit::fuser::cuda::getHorizontalFusion() is deprecated");
56 return false;
57 }
58
setHorizontalFusion(bool value)59 bool setHorizontalFusion(bool value) {
60 TORCH_WARN_ONCE(
61 "torch::jit::fuser::cuda::setHorizontalFusion() is deprecated");
62 TORCH_INTERNAL_ASSERT(
63 !value,
64 "nvfuser support in torchscript is removed and horizontal fusion cannot be enabled!");
65 return false;
66 }
67
getCudaFusionGuardMode()68 std::atomic<bool>& getCudaFusionGuardMode() {
69 TORCH_WARN_ONCE(
70 "torch::jit::fuser::cuda::getCudaFusionGuardMode() is deprecated");
71 return cuda_fusion_guard_mode;
72 }
73
getFuserInterface()74 CudaFuserInterface* getFuserInterface() {
75 static CudaFuserInterface fuser_interface_;
76 return &fuser_interface_;
77 }
78
compileFusionGroup(Node * fusion_node)79 void compileFusionGroup(Node* fusion_node) {
80 TORCH_WARN_ONCE(
81 "torch::jit::fuser::cuda::compileFusionGroup() is deprecated");
82 TORCH_CHECK(
83 getFuserInterface()->fn_compile_n != nullptr,
84 "Running the CUDA fuser requires a CUDA build.");
85 getFuserInterface()->fn_compile_n(fusion_node);
86 }
87
runFusionGroup(const Node * fusion_node,Stack & stack)88 void runFusionGroup(const Node* fusion_node, Stack& stack) {
89 TORCH_WARN_ONCE("torch::jit::fuser::cuda::runFusionGroup() is deprecated");
90 TORCH_CHECK(
91 getFuserInterface()->fn_run_n_s != nullptr,
92 "Running the CUDA fuser requires a CUDA build.");
93 getFuserInterface()->fn_run_n_s(fusion_node, stack);
94 }
95
// Deprecated no-op: the graph is left unmodified.
//
// nvfuser support in TorchScript has been removed. isEnabled() always returns
// false and setEnabled(true) is rejected, so there is no configuration in
// which a graph is handed to the CUDA fuser.
//
// This deliberately does not call isEnabled(). Doing so would emit the
// "isEnabled() is deprecated" warning on behalf of callers that never
// invoked isEnabled() themselves, and would gate only unreachable code.
void fuseGraph(std::shared_ptr<Graph>& graph) {
  (void)graph; // intentionally unused
}
107
canFuseNode(const Node * node)108 bool canFuseNode(const Node* node) {
109 TORCH_WARN_ONCE("torch::jit::fuser::cuda::canFuseNode() is deprecated");
110 return getFuserInterface()->fn_can_fuse_n != nullptr &&
111 getFuserInterface()->fn_can_fuse_n(node);
112 }
113
InsertProfileNodesForCUDAFuser(ProfilingRecord * pr)114 void InsertProfileNodesForCUDAFuser(ProfilingRecord* pr) {
115 TORCH_WARN_ONCE(
116 "torch::jit::fuser::cuda::InsertProfileNodesForCUDAFuser() is deprecated");
117 if (getFuserInterface()->fn_insert_profile_inodes) {
118 getFuserInterface()->fn_insert_profile_inodes(pr);
119 }
120 }
121
profileNode(const Node * node)122 bool profileNode(const Node* node) {
123 TORCH_WARN_ONCE("torch::jit::fuser::cuda::profileNode() is deprecated");
124 return getFuserInterface()->fn_profile_n != nullptr &&
125 getFuserInterface()->fn_profile_n(node);
126 }
127
skipNode(const std::string & symbol_str,bool flip)128 bool skipNode(const std::string& symbol_str, bool flip) {
129 TORCH_WARN_ONCE("torch::jit::fuser::cuda::skipNode() is deprecated");
130 return getFuserInterface()->fn_skip_n != nullptr &&
131 getFuserInterface()->fn_skip_n(symbol_str, flip);
132 }
133
134 } // namespace cuda
135 } // namespace fuser
136 } // namespace jit
137 } // namespace torch
138