#include <torch/csrc/jit/codegen/fuser/executor.h>

#include <ATen/ATen.h>
#include <ATen/ExpandUtils.h>
#include <ATen/core/functional.h>
#include <ATen/core/stack.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/codegen/fuser/compiler.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/codegen/fuser/kernel_cache.h>
#include <torch/csrc/jit/codegen/fuser/kernel_spec.h>
#include <torch/csrc/jit/codegen/fuser/tensor_info.h>
#include <torch/csrc/jit/passes/graph_fuser.h>
#include <optional>

#include <algorithm>
#include <vector>

namespace torch::jit::fuser {

// Returns the "map size" for this run, which is the common size for all
// intermediate tensors.
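// For example, given inputs of sizes [5, 1] and [1, 4], the map size is
// [5, 4]; a chunked input of size [10, 4] split into 2 sub-tensors along
// dim 0 contributes its per-chunk size [5, 4].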
static std::optional<std::vector<int64_t>> getMapSize(
    const KernelSpec& spec,
    at::TensorList args,
    at::IntArrayRef arg_subset) {
  // TODO: this keeps reallocating map_size at every iteration, but we know
  // exactly how much storage we need, so this could be fixed in-place at
  // every step. We're just missing a few functions for ATen, but the fix
  // should be straightforward.
  // Note: map_size is left empty, since an empty shape is broadcastable to
  // any shape.
  std::vector<int64_t> map_size;
  map_size.reserve(8);
  for (const auto arg_idx : arg_subset) {
    auto& arg = args.at(arg_idx);
    auto& chunk_desc = spec.inputChunks().at(arg_idx);
    if (chunk_desc.nSubTensors() == 1) {
      try {
        map_size = at::infer_size(map_size, arg.sizes());
      } catch (...) {
        return std::nullopt;
      }
    } else {
      auto tensor_sizes = arg.sizes().vec();
      const auto num_chunks = chunk_desc.nSubTensors();
      const auto dim =
          at::maybe_wrap_dim(chunk_desc.dim(), tensor_sizes.size());
      if (tensor_sizes[dim] % num_chunks != 0) {
        return std::nullopt;
      }
      tensor_sizes[dim] /= num_chunks;
      try {
        map_size = at::infer_size(map_size, tensor_sizes);
      } catch (...) {
        return std::nullopt;
      }
    }
  }

  return {map_size};
}

// Tries to determine a map size for the instantiated kernel (see above)
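// Every broadcast group of inputs must broadcast to one common shape, and all
// groups must agree on that shape; otherwise std::nullopt is returned so the
// caller falls back to the unfused graph.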
static std::optional<std::vector<int64_t>> canRunKernel(
    const KernelSpec& spec,
    at::TensorList args) {
  // Short-circuits on size mismatch
  TORCH_CHECK(
      args.size() == spec.inputChunks().size(),
      "Expected ",
      spec.inputChunks().size(),
      " arguments, but got ",
      args.size());

  std::optional<std::vector<int64_t>> map_size;
  for (const auto& broadcast_group : spec.inputBroadcastGroups()) {
    if (!map_size) {
      map_size = getMapSize(spec, args, broadcast_group);
      if (!map_size)
        return std::nullopt;
    } else {
      const auto group_map_size = getMapSize(spec, args, broadcast_group);
      // Note: this checks that group_map_size is defined AND equal to map_size
      if (map_size != group_map_size)
        return std::nullopt;
    }
  }

  return map_size;
}

// Arguments are expanded to a common shape, referred to as the "map size"
// (see above).
// Note: Arguments are mutated by this call, although map_size is restored
// to its original value.
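// When dry_run is true no tensors are modified; the return value only reports
// whether any argument would require broadcasting to reach the map size.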
static bool expandArgs(
    const KernelSpec& spec,
    std::vector<at::Tensor>& args,
    std::vector<int64_t>& map_size,
    bool dry_run) {
  bool has_broadcast = false;
  for (size_t i = 0; i < args.size(); ++i) {
    auto& arg = args[i];
    const auto& pdesc = spec.inputChunks()[i];
    if (pdesc.nSubTensors() == 1) {
      if (arg.sizes().equals(map_size))
        continue;
      if (!dry_run) {
        arg = arg.expand(map_size);
        has_broadcast = true;
      } else {
        return true;
      }
    } else {
      map_size.at(pdesc.dim()) *= pdesc.nSubTensors();
      if (!arg.sizes().equals(map_size)) {
        if (!dry_run) {
          arg = arg.expand(map_size);
          has_broadcast = true;
        } else {
          return true;
        }
      }
      map_size.at(pdesc.dim()) /= pdesc.nSubTensors();
    }
  }
  return has_broadcast;
}
129 
shouldExpandArgs(const KernelSpec & spec,std::vector<at::Tensor> & args,std::vector<int64_t> & map_size)130 static bool shouldExpandArgs(
131     const KernelSpec& spec,
132     std::vector<at::Tensor>& args,
133     std::vector<int64_t>& map_size) {
134   return expandArgs(spec, args, map_size, /*dry_run=*/true);
135 }
136 
137 // Note: assumes that inputs are 32-bit addressable
computeNumel(const at::ArrayRef<int64_t> sizes)138 static uint32_t computeNumel(const at::ArrayRef<int64_t> sizes) {
139   uint32_t result = 1;
140 
141   for (const auto& size : sizes)
142     result *= size;
143 
144   return result;
145 }
146 
147 // Note: Assumes that after at::chunk, all inputs are the same size
computeMapSize(const at::Tensor & tensor,const PartitionDesc & chunkDesc)148 static std::vector<int64_t> computeMapSize(
149     const at::Tensor& tensor,
150     const PartitionDesc& chunkDesc) {
151   std::vector<int64_t> sizes(tensor.sizes().begin(), tensor.sizes().end());
152   AT_ASSERT(sizes[chunkDesc.dim()] % chunkDesc.nSubTensors() == 0);
153   sizes[chunkDesc.dim()] /= chunkDesc.nSubTensors();
154   return sizes;
155 }
156 
// Tries to compress sizes and strides according to cont. Emits the result to
// c_sizes and c_strides, and throws an error on failure (if it can't compress).
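// For example, a fully contiguous tensor of sizes [4, 3, 2] and strides
// [6, 2, 1] (with cont all true) compresses to a single dimension with
// c_sizes = [24] and c_strides = [1].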
static void compressContiguous(
    const at::IntArrayRef& sizes,
    const at::IntArrayRef& strides,
    const std::vector<bool>& cont,
    uint32_t* c_sizes,
    uint32_t* c_strides) {
  size_t compressed_dims = 0;
  size_t cur = 0;
  size_t ndim = sizes.size();
  while (cur < ndim) {
    size_t total_size = sizes[cur];
    cur++;
    while (cont[cur - 1] && cur < ndim) {
      AT_ASSERT(strides[cur - 1] == sizes[cur] * strides[cur]);
      total_size *= sizes[cur];
      cur++;
    }
    c_sizes[compressed_dims] = total_size;
    c_strides[compressed_dims] = strides[cur - 1];
    compressed_dims++;
  }

  if (ndim > 0)
    AT_ASSERT(!cont.back() || strides.back() == 1);
}

// Launches the requested fusion on the given device with the given inputs.
// Output pointers are stored in outputs (to be put on the stack later).
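// The launch proceeds as follows: compute numel and the map size from the
// first input, pack a TensorInfo (data pointer plus compressed sizes and
// strides) for every flattened input into a preallocated byte buffer, append
// the scalar (double) inputs, allocate the outputs and pack their TensorInfos
// the same way, then hand the collected argument pointers to launch_raw.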
static void launchFusion(
    const FusedKernel& fusion,
    const at::Device device,
    const at::ArrayRef<at::Tensor>& inputs,
    const at::ArrayRef<IValue>& all_inputs,
    std::vector<at::Tensor>& outputs) {
  // Fails if fusion and given inputs disagree
  AT_ASSERT(inputs.size() == fusion.inputDesc().size());

  // Computes number of flattened inputs and outputs
  size_t flat_inputs_size = 0;
  size_t flat_outputs_size = 0;
  for (const auto& c : fusion.chunkDesc())
    flat_inputs_size += c.nSubTensors();
  for (const auto& c : fusion.concatDesc())
    flat_outputs_size += c.nSubTensors();

  // Fails if the number of elements of the first (any) tensor is not
  // expressible as a 32-bit integer.
  // Note: this code assumes that inputs are 32-bit addressable
  // Note: this code assumes that all inputs are of the same size
  AT_ASSERT(inputs[0].numel() <= std::numeric_limits<uint32_t>::max());

  // Computes map_size, numel from the first input
  at::IntArrayRef map_size;
  uint32_t numel = 0;
  std::vector<int64_t> keep_alive_size;
  if (fusion.chunkDesc()[0].isNoop()) {
    map_size = inputs[0].sizes();
    numel = inputs[0].numel();
  } else {
    keep_alive_size = computeMapSize(inputs[0], fusion.chunkDesc()[0]);
    map_size = keep_alive_size;
    numel = computeNumel(map_size);
  }

  // Collects the scalar inputs, converting them to float
  std::vector<double> scalar_inputs;
  scalar_inputs.reserve(all_inputs.size());
  for (auto const& input : all_inputs) {
    if (input.isDouble())
      scalar_inputs.push_back(input.to<float>());
  }

  // Computes the storage needed to store TensorInfo structs for inputs and
  // outputs.
  size_t uncompressedDim = fusion.inputDesc().at(0).contiguity.size();
  size_t maxPossibleTensorInfoSize =
      sizeof(TensorInfo) + 2 * sizeof(uint32_t) * uncompressedDim;
  size_t maxPossibleBufferSize =
      maxPossibleTensorInfoSize * (flat_inputs_size + flat_outputs_size);
  std::vector<char> buffer(maxPossibleBufferSize);
  char* buffer_next = buffer.data();

  // A vector of arguments to the kernel
  // (numel, *input_desc_s, *scalar_inputs, *output_desc_s)
  std::vector<void*> arguments;
  arguments.reserve(
      3 + scalar_inputs.size() + flat_inputs_size + flat_outputs_size);
  arguments.push_back(&numel);

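  // addTensorInfoRaw packs one TensorInfo (data pointer followed by the
  // compressed sizes and strides) into the preallocated buffer and records
  // its address in the kernel argument list.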
  auto addTensorInfoRaw = [&](const TensorDesc& desc,
                              void* data_ptr,
                              at::IntArrayRef sizes,
                              at::IntArrayRef strides) {
    const auto nDim = desc.nDim(); // NOTE: this is the compressed dim
    AT_ASSERT(nDim <= uncompressedDim); // We'd overflow the space otherwise
    auto ti = reinterpret_cast<TensorInfo*>(buffer_next);
    ti->data = data_ptr;
    compressContiguous(
        sizes, strides, desc.contiguity, ti->sizes(nDim), ti->strides(nDim));
    buffer_next += maxPossibleTensorInfoSize;
    arguments.push_back(ti);
  };

  // Asserts that t's dims can be compressed in the same way as in desc
  // (that's what the kernel assumes), and appends it to the arguments vector.
  auto addTensorInfo = [&](const TensorDesc& desc, const at::Tensor& t) {
    addTensorInfoRaw(desc, t.data_ptr(), t.sizes(), t.strides());
  };

  // Adds (flattened) input arguments
  for (size_t i = 0; i < fusion.inputDesc().size(); ++i) {
    const auto& chunk = fusion.chunkDesc()[i];
    const at::Tensor& tensor = inputs[i];
    if (chunk.isNoop()) {
      addTensorInfo(fusion.inputDesc()[i], tensor);
    } else {
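      // Each of the nSubTensors sub-tensors shares the original tensor's
      // strides; its data pointer simply starts chunk_offset bytes after the
      // previous sub-tensor's.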
      size_t chunk_offset = map_size[chunk.dim()] * tensor.stride(chunk.dim()) *
          elementSize(tensor.scalar_type());
      char* data_ptr = reinterpret_cast<char*>(tensor.data_ptr());
      for (size_t chunks = 0; chunks < chunk.nSubTensors(); ++chunks) {
        addTensorInfoRaw(
            *chunk.subTensorDesc(), data_ptr, map_size, tensor.strides());
        data_ptr += chunk_offset;
      }
    }
  }
  // Adds scalar arguments
  for (double& s : scalar_inputs) {
    arguments.push_back(&s);
  }

  // Adds (flattened) output arguments
  outputs.reserve(fusion.outputDesc().size());
  const auto& ref_options = inputs[0].options();
  for (size_t i = 0; i < fusion.outputDesc().size(); ++i) {
    const auto& c = fusion.concatDesc()[i];
    if (c.isNoop()) {
      outputs.push_back(at::empty(
          map_size, ref_options.dtype(fusion.outputDesc()[i].scalar_type)));
      addTensorInfo(fusion.outputDesc()[i], outputs[i]);
    } else {
      size_t small_size = map_size[c.dim()];
      std::vector<int64_t> concat_size(map_size.begin(), map_size.end());
      concat_size[c.dim()] = small_size * c.nSubTensors();
      outputs.push_back(at::empty(concat_size, ref_options));
      const auto& o = outputs[i];
      size_t offset = 0;
      for (size_t j = 0; j < c.nSubTensors(); ++j) {
        // Because the concatenated output stays live, the underlying data
        // in this view remains live through the end of this function,
        // so there is no need to hold onto this tensor.
        const auto view = o.narrow(c.dim(), offset, small_size);
        addTensorInfo(*c.subTensorDesc(), view);
        offset += small_size;
      }
    }
  }
  // Skips the launch entirely for zero-element inputs; the (empty) outputs
  // allocated above are returned as-is.
  if (numel > 0) {
    fusion.launch_raw(numel, arguments);
  }
}

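// Runs the fusion registered under `key`: validates and (if needed) expands
// the tensor inputs, compiles or fetches a cached kernel for the argument
// shapes, launches it, and replaces the inputs on the stack with the outputs.
// Returns false whenever the fused kernel cannot be used, signalling the
// caller to fall back to the unfused graph.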
bool runFusion(const int64_t key, Stack& stack, std::string* code_out) {
  // Short-circuits if fusion isn't enabled
  if (!canFuseOnCPULegacy() && !canFuseOnGPU())
    return false;

  // Acquires the KernelSpec
  auto maybe_spec = retrieve(key);
  AT_ASSERT(maybe_spec);
  auto& spec = *(*maybe_spec);
  // Acquires inputs from stack
  auto all_inputs = last(stack, spec.nInputs());
  std::vector<at::Tensor> inputs;
  inputs.reserve(spec.nTensorInputs());
  // we know that tensor inputs are first
  for (const auto i : c10::irange(spec.nTensorInputs())) {
    inputs.emplace_back(all_inputs[i].toTensor());
  }

  if (!inputs.at(0).defined()) {
    return false;
  }

  // Determines device to dispatch to.
  at::Device device = inputs.at(0).device();
  // If there's a device mismatch in the inputs or if one of the inputs is a
  // sparse tensor, we use the fallback (which should give a nice error
  // message).
  for (const auto& t : at::TensorList(inputs).slice(1)) {
    // Sparse tensors are not supported by CUDA fusion, so we bail out.
    if (t.device() != device || t.is_sparse()) {
      return false;
    }
  }

  // Attempts to run fallback if device fusion is disabled
  if (device.is_cuda() && !canFuseOnGPU())
    return false;
  if (device.is_cpu() && !canFuseOnCPULegacy())
    return false;
  if (device.is_xpu())
    return false;

  // Validates sizes and expands inputs as needed
  auto maybe_map_size = canRunKernel(spec, inputs);

  // Tries to run fallback if map size can't be computed
  if (!maybe_map_size)
    return false;
  if (spec.hasRandom()) {
    bool hasBroadcast = shouldExpandArgs(spec, inputs, *maybe_map_size);
    if (hasBroadcast)
      return false;
  }
  expandArgs(spec, inputs, *maybe_map_size, /*dry_run=*/false);

  // Retrieves the kernel, compiling (and caching) if necessary
  ArgSpec arg_spec{inputs, device.index()};
  auto maybe_kernel = spec.findKernel(arg_spec);
  if (!maybe_kernel) {
    const auto kernel = compileKernel(spec, arg_spec, *maybe_map_size, device);
    spec.cacheKernel(arg_spec, kernel);
  }
  maybe_kernel = spec.findKernel(arg_spec);
  AT_ASSERT(maybe_kernel);

  if (code_out) {
    *code_out = maybe_kernel.value()->code();
  }

  // Launches fusion
  std::vector<at::Tensor> outputs;
  launchFusion(*(*maybe_kernel), device, inputs, all_inputs, outputs);

  // Updates stack
  drop(stack, spec.nInputs());
  stack.insert(
      stack.end(),
      std::make_move_iterator(outputs.begin()),
      std::make_move_iterator(outputs.end()));

  return true;
}

} // namespace torch::jit::fuser