#include <torch/csrc/jit/codegen/fuser/executor.h>

#include <ATen/ATen.h>
#include <ATen/ExpandUtils.h>
#include <ATen/core/functional.h>
#include <ATen/core/stack.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/codegen/fuser/compiler.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/codegen/fuser/kernel_cache.h>
#include <torch/csrc/jit/codegen/fuser/kernel_spec.h>
#include <torch/csrc/jit/codegen/fuser/tensor_info.h>
#include <torch/csrc/jit/passes/graph_fuser.h>
#include <optional>

#include <algorithm>
#include <vector>

namespace torch::jit::fuser {

// Returns the "map size" for this run, which is the common size for all
// intermediate tensors.
static std::optional<std::vector<int64_t>> getMapSize(
    const KernelSpec& spec,
    at::TensorList args,
    at::IntArrayRef arg_subset) {
  // TODO: this keeps reallocating map_size at every iteration, but we know
  // exactly how much storage we need, so this could be fixed in-place at
  // every step. We're just missing a few functions for ATen, but the fix
  // should be straightforward.
  // Note: left empty, since an empty shape is broadcastable to any shape
  std::vector<int64_t> map_size;
  map_size.reserve(8);
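  // Each argument's shape (per sub-tensor, if the input is chunked) is folded
  // into map_size via at::infer_size; any broadcasting incompatibility means
  // this spec cannot run on these inputs, so std::nullopt is returned.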
  for (const auto arg_idx : arg_subset) {
    auto& arg = args.at(arg_idx);
    auto& chunk_desc = spec.inputChunks().at(arg_idx);
    if (chunk_desc.nSubTensors() == 1) {
      try {
        map_size = at::infer_size(map_size, arg.sizes());
      } catch (...) {
        return std::nullopt;
      }
    } else {
      auto tensor_sizes = arg.sizes().vec();
      const auto num_chunks = chunk_desc.nSubTensors();
      const auto dim =
          at::maybe_wrap_dim(chunk_desc.dim(), tensor_sizes.size());
      if (tensor_sizes[dim] % num_chunks != 0) {
        return std::nullopt;
      }
      tensor_sizes[dim] /= num_chunks;
      try {
        map_size = at::infer_size(map_size, tensor_sizes);
      } catch (...) {
        return std::nullopt;
      }
    }
  }

  return {map_size};
}

// Tries to determine a map size for the instantiated kernel (see above)
static std::optional<std::vector<int64_t>> canRunKernel(
    const KernelSpec& spec,
    at::TensorList args) {
  // Short-circuits on size mismatch
  TORCH_CHECK(
      args.size() == spec.inputChunks().size(),
      "Expected ",
      spec.inputChunks().size(),
      " arguments, but got ",
      args.size());

  std::optional<std::vector<int64_t>> map_size;
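  // Each broadcast group is a subset of the inputs that must broadcast to the
  // same map size; every group has to agree on that size for the kernel to be
  // runnable.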
  for (const auto& broadcast_group : spec.inputBroadcastGroups()) {
    if (!map_size) {
      map_size = getMapSize(spec, args, broadcast_group);
      if (!map_size)
        return std::nullopt;
    } else {
      const auto group_map_size = getMapSize(spec, args, broadcast_group);
      // Note: this checks that group_map_size is defined AND equal to map_size
      if (map_size != group_map_size)
        return std::nullopt;
    }
  }

  return map_size;
}

// Arguments are expanded to a common shape, referred to as the "map size"
// (see above).
// Note: Arguments are mutated by this call, although map_size is restored
// to its original value.
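// When dry_run is true, no argument is actually expanded; the function only
// reports whether any argument would need broadcasting.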
static bool expandArgs(
    const KernelSpec& spec,
    std::vector<at::Tensor>& args,
    std::vector<int64_t>& map_size,
    bool dry_run) {
  bool has_broadcast = false;
  for (size_t i = 0; i < args.size(); ++i) {
    auto& arg = args[i];
    const auto& pdesc = spec.inputChunks()[i];
    if (pdesc.nSubTensors() == 1) {
      if (arg.sizes().equals(map_size))
        continue;
      if (!dry_run) {
        arg = arg.expand(map_size);
        has_broadcast = true;
      } else {
        return true;
      }
    } else {
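      // Chunked inputs are compared against the full (pre-chunk) shape, so the
      // chunk dimension of map_size is scaled up temporarily and restored
      // below.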
      map_size.at(pdesc.dim()) *= pdesc.nSubTensors();
      if (!arg.sizes().equals(map_size)) {
        if (!dry_run) {
          arg = arg.expand(map_size);
          has_broadcast = true;
        } else {
          return true;
        }
      }
      map_size.at(pdesc.dim()) /= pdesc.nSubTensors();
    }
  }
  return has_broadcast;
}

static bool shouldExpandArgs(
    const KernelSpec& spec,
    std::vector<at::Tensor>& args,
    std::vector<int64_t>& map_size) {
  return expandArgs(spec, args, map_size, /*dry_run=*/true);
}

// Note: assumes that inputs are 32-bit addressable
static uint32_t computeNumel(const at::ArrayRef<int64_t> sizes) {
  uint32_t result = 1;

  for (const auto& size : sizes)
    result *= size;

  return result;
}

// Note: Assumes that after at::chunk, all inputs are the same size
static std::vector<int64_t> computeMapSize(
    const at::Tensor& tensor,
    const PartitionDesc& chunkDesc) {
  std::vector<int64_t> sizes(tensor.sizes().begin(), tensor.sizes().end());
  AT_ASSERT(sizes[chunkDesc.dim()] % chunkDesc.nSubTensors() == 0);
  sizes[chunkDesc.dim()] /= chunkDesc.nSubTensors();
  return sizes;
}

// Tries to compress sizes and strides according to cont. Emits the result to
// c_sizes and c_strides, and throws an error on failure (if it can't compress).
static void compressContiguous(
    const at::IntArrayRef& sizes,
    const at::IntArrayRef& strides,
    const std::vector<bool>& cont,
    uint32_t* c_sizes,
    uint32_t* c_strides) {
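  // Walks the dims left to right, merging each maximal run of dims marked
  // contiguous in `cont` into a single (size, stride) pair; the stride kept
  // for a run is that of its innermost dim.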
  size_t compressed_dims = 0;
  size_t cur = 0;
  size_t ndim = sizes.size();
  while (cur < ndim) {
    size_t total_size = sizes[cur];
    cur++;
    while (cont[cur - 1] && cur < ndim) {
      AT_ASSERT(strides[cur - 1] == sizes[cur] * strides[cur]);
      total_size *= sizes[cur];
      cur++;
    }
    c_sizes[compressed_dims] = total_size;
    c_strides[compressed_dims] = strides[cur - 1];
    compressed_dims++;
  }

  if (ndim > 0)
    AT_ASSERT(!cont.back() || strides.back() == 1);
}

// Launches the requested fusion on the given device with the given inputs.
// Output pointers are stored in outputs (to be put on the stack later).
static void launchFusion(
    const FusedKernel& fusion,
    const at::Device device,
    const at::ArrayRef<at::Tensor>& inputs,
    const at::ArrayRef<IValue>& all_inputs,
    std::vector<at::Tensor>& outputs) {
  // Fails if fusion and given inputs disagree
  AT_ASSERT(inputs.size() == fusion.inputDesc().size());

  // Computes number of flattened inputs and outputs
  size_t flat_inputs_size = 0;
  size_t flat_outputs_size = 0;
  for (const auto& c : fusion.chunkDesc())
    flat_inputs_size += c.nSubTensors();
  for (const auto& c : fusion.concatDesc())
    flat_outputs_size += c.nSubTensors();

  // Fails if the number of elements of the first (any) tensor is not
  // expressible as a 32-bit integer.
  // Note: this code assumes that inputs are 32-bit addressable
  // Note: this code assumes that all inputs are of the same size
  AT_ASSERT(inputs[0].numel() <= std::numeric_limits<uint32_t>::max());

  // Computes map_size, numel from the first input
  at::IntArrayRef map_size;
  uint32_t numel = 0;
  std::vector<int64_t> keep_alive_size;
  if (fusion.chunkDesc()[0].isNoop()) {
    map_size = inputs[0].sizes();
    numel = inputs[0].numel();
  } else {
    keep_alive_size = computeMapSize(inputs[0], fusion.chunkDesc()[0]);
    map_size = keep_alive_size;
    numel = computeNumel(map_size);
  }

  // compute number of scalar inputs and convert them to float
  std::vector<double> scalar_inputs;
  scalar_inputs.reserve(all_inputs.size());
  for (auto const& input : all_inputs) {
    if (input.isDouble())
      scalar_inputs.push_back(input.to<float>());
  }
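  // These values are later passed to the kernel by address (see the
  // scalar-argument loop below).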

  // Computes the storage needed to store TensorInfo structs for inputs and
  // outputs.
  size_t uncompressedDim = fusion.inputDesc().at(0).contiguity.size();
  size_t maxPossibleTensorInfoSize =
      sizeof(TensorInfo) + 2 * sizeof(uint32_t) * uncompressedDim;
  size_t maxPossibleBufferSize =
      maxPossibleTensorInfoSize * (flat_inputs_size + flat_outputs_size);
  std::vector<char> buffer(maxPossibleBufferSize);
  char* buffer_next = buffer.data();
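  // Each TensorInfo is a variable-length struct (its sizes/strides arrays
  // follow it in memory); every slot is padded to maxPossibleTensorInfoSize so
  // buffer_next can advance by a fixed stride.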

  // A vector of arguments to the kernel (numel, *input_desc_s, *output_desc_s)
  std::vector<void*> arguments;
  arguments.reserve(
      3 + scalar_inputs.size() + flat_inputs_size + flat_outputs_size);
  arguments.push_back(&numel);

  auto addTensorInfoRaw = [&](const TensorDesc& desc,
                              void* data_ptr,
                              at::IntArrayRef sizes,
                              at::IntArrayRef strides) {
    const auto nDim = desc.nDim(); // NOTE: this is the compressed dim
    AT_ASSERT(nDim <= uncompressedDim); // We'd overflow the space otherwise
    auto ti = reinterpret_cast<TensorInfo*>(buffer_next);
    ti->data = data_ptr;
    compressContiguous(
        sizes, strides, desc.contiguity, ti->sizes(nDim), ti->strides(nDim));
    buffer_next += maxPossibleTensorInfoSize;
    arguments.push_back(ti);
  };

  // Asserts that t's dims can be compressed in the same way as in desc
  // (that's what the kernel assumes), and appends it to the arguments vector.
  auto addTensorInfo = [&](const TensorDesc& desc, const at::Tensor& t) {
    addTensorInfoRaw(desc, t.data_ptr(), t.sizes(), t.strides());
  };

  // Adds (flattened) input arguments
  for (size_t i = 0; i < fusion.inputDesc().size(); ++i) {
    const auto& chunk = fusion.chunkDesc()[i];
    const at::Tensor& tensor = inputs[i];
    if (chunk.isNoop()) {
      addTensorInfo(fusion.inputDesc()[i], tensor);
    } else {
      size_t chunk_offset = map_size[chunk.dim()] * tensor.stride(chunk.dim()) *
          elementSize(tensor.scalar_type());
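      // chunk_offset is the byte distance between consecutive sub-tensors
      // along the chunk dimension, so each sub-tensor can reuse the parent's
      // strides with a shifted data pointer.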
      char* data_ptr = reinterpret_cast<char*>(tensor.data_ptr());
      for (size_t chunks = 0; chunks < chunk.nSubTensors(); ++chunks) {
        addTensorInfoRaw(
            *chunk.subTensorDesc(), data_ptr, map_size, tensor.strides());
        data_ptr += chunk_offset;
      }
    }
  }
  // Adds scalar arguments
  for (double& s : scalar_inputs) {
    arguments.push_back(&s);
  }

  // Adds (flattened) output arguments
  outputs.reserve(fusion.outputDesc().size());
  const auto& ref_options = inputs[0].options();
  for (size_t i = 0; i < fusion.outputDesc().size(); ++i) {
    const auto& c = fusion.concatDesc()[i];
    if (c.isNoop()) {
      outputs.push_back(at::empty(
          map_size, ref_options.dtype(fusion.outputDesc()[i].scalar_type)));
      addTensorInfo(fusion.outputDesc()[i], outputs[i]);
    } else {
      size_t small_size = map_size[c.dim()];
      std::vector<int64_t> concat_size(map_size.begin(), map_size.end());
      concat_size[c.dim()] = small_size * c.nSubTensors();
      outputs.push_back(at::empty(concat_size, ref_options));
      const auto& o = outputs[i];
      size_t offset = 0;
      for (size_t j = 0; j < c.nSubTensors(); ++j) {
        // Because the concatenated output stays live, the underlying data
        // in this view remains live through the end of this function,
        // so there is no need to hold onto this tensor.
        const auto view = o.narrow(c.dim(), offset, small_size);
        addTensorInfo(*c.subTensorDesc(), view);
        offset += small_size;
      }
    }
  }
  // Skips the kernel launch for zero-element inputs; the (empty) zero-sized
  // outputs allocated above are still returned.
  if (numel > 0) {
    fusion.launch_raw(numel, arguments);
  }
}

bool runFusion(const int64_t key, Stack& stack, std::string* code_out) {
  // Short-circuits if fusion isn't enabled
  if (!canFuseOnCPULegacy() && !canFuseOnGPU())
    return false;

  // Acquires the FusionSpec
  auto maybe_spec = retrieve(key);
  AT_ASSERT(maybe_spec);
  auto& spec = *(*maybe_spec);
  // Acquires inputs from stack
  auto all_inputs = last(stack, spec.nInputs());
  std::vector<at::Tensor> inputs;
  inputs.reserve(spec.nTensorInputs());
  // we know that tensor inputs are first
  for (const auto i : c10::irange(spec.nTensorInputs())) {
    inputs.emplace_back(all_inputs[i].toTensor());
  }

  if (!inputs.at(0).defined()) {
    return false;
  }

  // Determines device to dispatch to.
  at::Device device = inputs.at(0).device();
  // If there's a device mismatch in the inputs or if one of the inputs is a
  // sparse tensor, we use the fallback (which should give a nice error
  // message).
  for (const auto& t : at::TensorList(inputs).slice(1)) {
    // Sparse tensors are not supported by CUDA fusion, so we bail out.
    if (t.device() != device || t.is_sparse()) {
      return false;
    }
  }

  // Attempts to run fallback if device fusion is disabled
  if (device.is_cuda() && !canFuseOnGPU())
    return false;
  if (device.is_cpu() && !canFuseOnCPULegacy())
    return false;
  if (device.is_xpu())
    return false;

  // Validates sizes and expands inputs as needed
  auto maybe_map_size = canRunKernel(spec, inputs);

  // Tries to run fallback if map size can't be computed
  if (!maybe_map_size)
    return false;
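  // Kernels containing random ops fall back to the unfused path whenever any
  // input would need broadcasting.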
  if (spec.hasRandom()) {
    bool hasBroadcast = shouldExpandArgs(spec, inputs, *maybe_map_size);
    if (hasBroadcast)
      return false;
  }
  expandArgs(spec, inputs, *maybe_map_size, /*dry_run=*/false);

  // Retrieves the kernel, compiling (and caching) if necessary
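  // The kernel cache is keyed by the inputs' characteristics (ArgSpec) and the
  // device index, so each distinct input configuration gets its own compiled
  // kernel.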
  ArgSpec arg_spec{inputs, device.index()};
  auto maybe_kernel = spec.findKernel(arg_spec);
  if (!maybe_kernel) {
    const auto kernel = compileKernel(spec, arg_spec, *maybe_map_size, device);
    spec.cacheKernel(arg_spec, kernel);
  }
  maybe_kernel = spec.findKernel(arg_spec);
  AT_ASSERT(maybe_kernel);

  if (code_out) {
    *code_out = maybe_kernel.value()->code();
  }

  // Launches fusion
  std::vector<at::Tensor> outputs;
  launchFusion(*(*maybe_kernel), device, inputs, all_inputs, outputs);

  // Updates stack
  drop(stack, spec.nInputs());
  stack.insert(
      stack.end(),
      std::make_move_iterator(outputs.begin()),
      std::make_move_iterator(outputs.end()));

  return true;
}

} // namespace torch::jit::fuser