xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/ops/MetalUpsamplingNearest.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#import <ATen/native/metal/MetalCommandBuffer.h>
2#import <ATen/native/metal/MetalContext.h>
3#import <ATen/native/metal/MetalTensorImpl.h>
4#import <ATen/native/metal/MetalTensorImplStorage.h>
5#import <ATen/native/metal/MetalTensorUtils.h>
6#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
7#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
8#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
9
10#include <ATen/Tensor.h>
11#include <ATen/native/UpSample.h>
12#include <torch/library.h>
13
14namespace at::native::metal {
15
// Nearest-neighbor 2D upsampling of a Metal tensor, encoded with
// MPSCNNUpsamplingNearest.
//
// Parameters:
//   input         - Metal tensor; dimensions 0..3 are read as batch,
//                   channels, height and width.
//   output_size   - optional target {height, width}.
//   scale_factors - optional {scale_h, scale_w}.
// Both optionals are forwarded to upsample::compute_output_size, which
// resolves the output height/width.
//
// Returns a Metal tensor of shape {batch, channels, out_height, out_width}.
//
// MPSCNNUpsamplingNearest is constructed with integer scale factors only:
//   * when scale_factors is given, each factor is truncated to an integer by
//     the cast (unchanged from the previous behavior);
//     NOTE(review): a fractional factor therefore does not match the output
//     size computed above -- confirm whether that should be rejected.
//   * when only output_size is given, the factor is derived per axis as
//     output / input, which must be an exact positive integer multiple.
//     Previously this path called .value() on an empty optional and threw
//     std::bad_optional_access.
static Tensor upsample_nearest2d_vec(
    const Tensor& input,
    at::OptionalIntArrayRef output_size,
    std::optional<ArrayRef<double>> scale_factors) {
  TORCH_CHECK(input.is_metal());
  auto osize =
      upsample::compute_output_size(input.sizes(), output_size, scale_factors);
  auto scale_h = upsample::get_scale_value(scale_factors, 0);
  auto scale_w = upsample::get_scale_value(scale_factors, 1);
  int64_t output_height = osize[0];
  int64_t output_width = osize[1];
  int64_t nbatch = input.size(0);
  int64_t channels = input.size(1);
  int64_t input_height = input.size(2);
  int64_t input_width = input.size(3);
  upsample_2d_shape_check(
      input,
      Tensor(),
      nbatch,
      channels,
      input_height,
      input_width,
      output_height,
      output_width);
  std::vector<int64_t> outputSizes{
      nbatch, channels, output_height, output_width};
  // Nothing to encode for an empty input; hand back an empty tensor of the
  // resolved output shape.
  if (input.numel() == 0) {
    return makeTensor({outputSizes}, input.options());
  }

  // Resolve the integer factor for one axis. Runs only after the empty-input
  // early return, so in_extent is non-zero and the division is safe.
  const auto integer_scale = [](const auto& scale,
                                int64_t out_extent,
                                int64_t in_extent) -> NSUInteger {
    if (scale.has_value()) {
      // Explicit factor: truncate toward zero, as before.
      return (NSUInteger)scale.value();
    }
    TORCH_CHECK(
        out_extent >= in_extent && out_extent % in_extent == 0,
        "Metal upsample_nearest2d requires the output size (",
        out_extent,
        ") to be a positive integer multiple of the input size (",
        in_extent,
        ") when scale_factors is not provided");
    return (NSUInteger)(out_extent / in_extent);
  };
  // Resolved before any temporary storage is allocated so that an invalid
  // request fails without touching the command buffer.
  const NSUInteger scale_factor_x =
      integer_scale(scale_w, output_width, input_width);
  const NSUInteger scale_factor_y =
      integer_scale(scale_h, output_height, input_height);

  MPSImage* X = imageFromTensor(input);
  MetalTensorImplStorage mt{outputSizes};
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  mt.texture()->allocateTemporaryStorage(outputSizes, commandBuffer);
  MPSImage* Y = mt.texture()->image();
  if (@available(iOS 11.0, *)) {
    MPSCNNUpsamplingNearest* kernel = [[MPSCNNUpsamplingNearest alloc]
             initWithDevice:[MetalContext sharedInstance].device
        integerScaleFactorX:scale_factor_x
        integerScaleFactorY:scale_factor_y];
    [kernel encodeToCommandBuffer:commandBuffer.buffer
                      sourceImage:X
                 destinationImage:Y];
  } else {
      TORCH_CHECK(
          false,
          "MPSCNNUpsamplingNearest is only available on iOS 11.0 and above");
  }
  auto output = makeTensor(std::move(mt), input.options());
  return output;
}
66
// Registers upsample_nearest2d_vec as the implementation of
// aten::upsample_nearest2d.vec under the Metal dispatch key.
TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("aten::upsample_nearest2d.vec"),
      TORCH_FN(upsample_nearest2d_vec));
}
70
71} // namespace at::native::metal
72