// torch/csrc/jit/codegen/fuser/compiler.cpp
#include <torch/csrc/jit/codegen/fuser/compiler.h>

#include <ATen/ATen.h>
#include <ATen/core/jit_type.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/codegen/fuser/codegen.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/codegen/fuser/kernel_cache.h>
#include <torch/csrc/jit/codegen/fuser/tensor_desc.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/runtime/operator.h>

#include <atomic>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_set>
#include <utility>

namespace {
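// Guards the fusion backend registry below; a function-local static so the
// mutex is initialized on first use.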
std::mutex& fusionBackendLock() {
  static std::mutex fusion_backends_lock_{};
  return fusion_backends_lock_;
}
} // namespace

namespace torch::jit::fuser {
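// Registry mapping a device type to the constructor used to build fused
// kernels for that backend.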
static std::unordered_map<at::Device::Type, FusedKernelConstructor>&
getFusionBackends() {
  static std::unordered_map<at::Device::Type, FusedKernelConstructor>
      fusion_backends;
  return fusion_backends;
}
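// Registers (or replaces) the fused-kernel constructor for a device type.
// A backend typically calls this once at static-initialization time with a
// callable matching FusedKernelConstructor; compileKernel() below invokes
// that callable as kernel_ctor(device_index, name, code, input_desc,
// output_desc, chunk_desc, concat_desc, has_random).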
void registerFusionBackend(
    at::Device::Type backend_type,
    FusedKernelConstructor ctor) {
  std::lock_guard<std::mutex> guard(fusionBackendLock());
  getFusionBackends()[backend_type] = std::move(ctor);
}
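// Returns true if a fused-kernel constructor has been registered for the
// given device type.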
bool hasFusionBackend(at::Device::Type backend_type) {
  std::lock_guard<std::mutex> guard(fusionBackendLock());
  return getFusionBackends().count(backend_type);
}
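// Looks up the registered constructor for a backend; throws
// std::out_of_range if no constructor was registered for it.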
static const FusedKernelConstructor& getConstructor(
    at::Device::Type backend_type) {
  std::lock_guard<std::mutex> guard(fusionBackendLock());
  return getFusionBackends().at(backend_type);
}

// Counter for number of kernels compiled, used for debugging and
// creating arbitrary kernel names.
static std::atomic<size_t> next_kernel_id{0};
static int debug_fusion{-1};
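// Number of kernels compiled so far (the current value of next_kernel_id).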
size_t nCompiledKernels() {
  return next_kernel_id.load();
}
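// Returns the fuser debug level, read once from the PYTORCH_FUSION_DEBUG
// environment variable (0 when unset) and cached afterwards.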
int debugFuser() {
  if (debug_fusion < 0) {
    const char* debug_env = getenv("PYTORCH_FUSION_DEBUG");
    debug_fusion = debug_env ? atoi(debug_env) : 0;
  }
  return debug_fusion;
}
// If the given value has exactly one use and that use is a
// prim::ConstantChunk node, returns that node. Returns nullptr otherwise.
static const Node* usedInFusedChunk(const Value* input) {
  const auto& uses = input->uses();
  if (uses.size() == 1) {
    const Node* user = uses[0].user;
    if (user->kind() == prim::ConstantChunk) {
      return user;
    }
  }
  return nullptr;
}
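// Records, for each tensor input of the fusion group, how it is chunked:
// (chunks, dim) when it feeds a prim::ConstantChunk node, (1, 0) otherwise.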
static void setInputChunkDescriptors(KernelSpec& spec) {
  // We only have as many chunk descriptors as tensor inputs;
  // furthermore, we know that the tensor inputs are at the
  // beginning of the fusion group's inputs.
  spec.inputChunks().reserve(spec.nTensorInputs());
  for (const auto i : c10::irange(spec.nTensorInputs())) {
    const Value* input = spec.graph()->inputs()[i];
    if (const Node* chunk = usedInFusedChunk(input)) {
      spec.inputChunks().emplace_back(
          chunk->i(attr::chunks), chunk->i(attr::dim));
    } else {
      spec.inputChunks().emplace_back(1, 0);
    }
  }
}
// Run a DFS traversal to find all inputs that affect a given output value
static std::vector<int64_t> getInputDependencies(const Value* output) {
  std::vector<const Value*> queue{output};
  std::unordered_set<const Value*> inputs;
  std::unordered_set<const Value*> seen;
  while (!queue.empty()) {
    const Value* val = queue.back();
    queue.pop_back();
    const Node* producer = val->node();
    // Here we assume that only tensor inputs are used in
    // the computation of the outputs.
    // This is currently true: besides tensors, the only inputs are
    // sizes (for _grad_sum_to_size as the derivative of broadcasts),
    // and those are only used after the fusion kernel runs.
    // This needs to be revisited once other inputs, e.g. non-constant
    // scalars, are allowed.
    if (producer->kind() == prim::Param &&
        val->type()->isSubtypeOf(*TensorType::get())) {
      inputs.insert(val);
      continue;
    }
    for (const Value* input : producer->inputs()) {
      if (/*bool inserted = */ seen.insert(input).second) {
        queue.push_back(input);
      }
    }
  }

  // Convert Value* into offsets into the graph's input list
  std::vector<int64_t> offsets;
  offsets.reserve(inputs.size());
  for (const Value* input : inputs) {
    offsets.push_back(input->offset());
  }

  std::sort(offsets.begin(), offsets.end());
  return offsets;
}
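// Groups the inputs that must broadcast to a common shape: every output (or,
// for prim::FusedConcat outputs, every concat operand) contributes one group
// made up of the inputs it depends on.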
static void setInputBroadcastGroups(KernelSpec& spec) {
  std::unordered_set<std::vector<int64_t>, c10::hash<std::vector<int64_t>>>
      broadcast_groups;
  for (const Value* output : (spec.graph())->outputs()) {
    if (output->node()->kind() == prim::FusedConcat) {
      for (const Value* concat_input : output->node()->inputs()) {
        broadcast_groups.insert(getInputDependencies(concat_input));
      }
    } else {
      broadcast_groups.insert(getInputDependencies(output));
    }
  }
  std::copy(
      broadcast_groups.begin(),
      broadcast_groups.end(),
      std::back_inserter(spec.inputBroadcastGroups()));
}
// Performs "upfront" compilation where storage is known but shapes are not.
// Currently identifies how to expand all tensors so that all intermediate
// tensors are the same shape, simplifying code generation.
// Broadcast groups and chunks are identified without shape information
// using logical properties of how each works. In particular, tensors
// are always expandable to the outputs of pointwise operations they
// or their descendants are involved in, which means that in a DAG of
// pointwise operations all tensors are expandable to the (single) output.
// Note: The logic is slightly complicated by concatenation and chunking.
static void upfrontCompilation(KernelSpec& spec) {
  setInputBroadcastGroups(spec);
  setInputChunkDescriptors(spec);
}
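// Normalizes the fusion group's subgraph and returns the key of its
// KernelSpec in the kernel cache, registering (and upfront-compiling) a new
// spec if none exists yet.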
int64_t registerFusion(const Node* fusion_group) {
  auto graph = normalizeGraphForCache(fusion_group->g(attr::Subgraph));

  // Don't re-register the fusion if we can use a pre-existing one
  const auto maybe_spec = lookupGraph(graph);
  if (maybe_spec) {
    return (*maybe_spec)->key();
  }

  // Unconditionally create and register the fusion
  // This is necessary to support our global disable fusions flag: if someone
  // runs some code under no-fusions mode and then runs some code with fusions
  // enabled, the second time around the returned spec from the cache should
  // be a valid spec (must have had upfrontCompilation run on it).
  const auto key = store(graph);
  const auto maybe_retrieved_spec = retrieve(key);
  AT_ASSERT(maybe_retrieved_spec);
  upfrontCompilation(**maybe_retrieved_spec);

  return key;
}
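// Compiles a kernel for the given spec, specialized to the argument
// descriptors in arg_spec and to map_size (the common pointwise output
// shape), and dispatches to the CPU or CUDA backend based on the device.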
std::shared_ptr<FusedKernel> compileKernel(
    const KernelSpec& spec,
    const ArgSpec& arg_spec,
    const std::vector<int64_t>& map_size,
    const at::Device device) {
  const std::vector<TensorDesc>& input_desc = arg_spec.descs();

  auto graph = spec.graph()->copy();

  for (const auto i : c10::irange(input_desc.size())) {
    const auto& desc = input_desc[i];

    // TODO: can't get rid of this use of TensorType
    // until we switch to ProfilingGraphExecutor, so we don't have to
    // run PropagateInputShapes below
    graph->inputs()[i]->setType(TensorType::create(
        desc.scalar_type,
        device,
        {desc.nDim()},
        false)); // TODO: nDim is bad, as it is collapsed
  }

  PropagateInputShapes(graph);

  // Creates chunk and flattened input descriptions
  std::vector<PartitionDesc> chunk_desc;
  std::vector<std::pair<const Value*, const std::optional<TensorDesc>>>
      flat_inputs;
  {
    size_t input_index = 0;
    for (const auto& p : graph->inputs()) {
      if (p->type()->isSubtypeOf(*FloatType::get())) {
        flat_inputs.emplace_back(p, std::nullopt);
      }
      if (!p->type()->isSubtypeOf(*TensorType::get())) {
        continue;
      }
      if (const Node* chunk = usedInFusedChunk(p)) {
        int64_t dim = chunk->i(attr::dim);
        int64_t chunks = chunk->i(attr::chunks);
        chunk_desc.emplace_back(input_desc[input_index++], chunks, dim);
        for (const auto* o : chunk->outputs()) {
          flat_inputs.emplace_back(o, *chunk_desc.back().subTensorDesc());
        }
      } else {
        chunk_desc.emplace_back();
        flat_inputs.emplace_back(p, input_desc[input_index++]);
      }
    }
  }

  // Creates output, concat, and flattened output descriptions
  std::vector<TensorDesc> output_desc;
  std::vector<PartitionDesc> concat_desc;
  std::vector<std::pair<const Value*, const TensorDesc>> flat_outputs;
  for (const Value* o : graph->outputs()) {
    // Creates output description
    std::vector<int64_t> sizes = map_size;
    if (o->node()->kind() == prim::FusedConcat) {
      sizes.at(o->node()->i(attr::dim)) *= o->node()->inputs().size();
    }

    auto scalar_type = o->type()->expectRef<TensorType>().scalarType();
    TORCH_INTERNAL_ASSERT(scalar_type);
    auto type = TensorType::createContiguous(*scalar_type, device, sizes);
    output_desc.emplace_back(type);
    const auto& desc = output_desc.back();

    // Creates concat and flattened output descriptions (relies on output desc)
    if (o->node()->kind() != prim::FusedConcat) {
      concat_desc.emplace_back();
      flat_outputs.emplace_back(o, desc);
    } else {
      const auto cat = o->node();
      concat_desc.emplace_back(desc, cat->inputs().size(), cat->i(attr::dim));
      for (const auto& c : cat->inputs()) {
        flat_outputs.emplace_back(c, *concat_desc.back().subTensorDesc());
      }
    }
  }

  const bool use_cuda = device.is_cuda();
  const std::string name = "kernel_" + std::to_string(next_kernel_id++);
  std::string code =
      generateKernel(name, *graph, flat_inputs, flat_outputs, use_cuda);
  const FusedKernelConstructor& kernel_ctor =
      getConstructor(use_cuda ? DeviceType::CUDA : DeviceType::CPU);
  return kernel_ctor(
      device.index(),
      name,
      code,
      input_desc,
      output_desc,
      chunk_desc,
      concat_desc,
      spec.hasRandom());
}
} // namespace torch::jit::fuser