#import <ATen/native/metal/MetalCommandBuffer.h>
#import <ATen/native/metal/MetalContext.h>
#import <ATen/native/metal/MetalTensorImpl.h>
#import <ATen/native/metal/MetalTensorImplStorage.h>
#import <ATen/native/metal/MetalTensorUtils.h>
#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
#import <ATen/native/metal/mpscnn/MPSImageUtils.h>

#include <ATen/Tensor.h>
#include <ATen/native/UpSample.h>
#include <torch/library.h>

namespace at::native::metal {

// Metal implementation of aten::upsample_nearest2d.vec.
//
// Encodes an MPSCNNUpsamplingNearest kernel that reads the input's MPSImage
// and writes into a freshly allocated output texture of shape
// {N, C, output_height, output_width}.
//
// Parameters:
//   input         - Metal tensor; indexed as (N, C, H, W) via size(0..3).
//   output_size   - optional explicit {H_out, W_out}.
//   scale_factors - optional {scale_h, scale_w}.
//   Both optionals are forwarded to upsample::compute_output_size, which
//   resolves the final output height/width.
//
// Returns: a tensor built from the output texture storage with
// input.options(); for an empty input, a tensor of the output shape is
// returned without encoding any kernel.
//
// Raises (TORCH_CHECK): if the input is not a Metal tensor, if
// upsample_2d_shape_check rejects the shapes, if the output size is not a
// positive integer multiple of the input size, or if the OS is older than
// iOS 11.0.
static Tensor upsample_nearest2d_vec(
    const Tensor& input,
    at::OptionalIntArrayRef output_size,
    std::optional<ArrayRef<double>> scale_factors) {
  TORCH_CHECK(input.is_metal());
  auto osize =
      upsample::compute_output_size(input.sizes(), output_size, scale_factors);
  int64_t output_height = osize[0];
  int64_t output_width = osize[1];
  int64_t nbatch = input.size(0);
  int64_t channels = input.size(1);
  int64_t input_height = input.size(2);
  int64_t input_width = input.size(3);
  upsample_2d_shape_check(
      input,
      Tensor(),
      nbatch,
      channels,
      input_height,
      input_width,
      output_height,
      output_width);
  std::vector<int64_t> outputSizes{
      nbatch, channels, output_height, output_width};
  if (input.numel() == 0) {
    return makeTensor({outputSizes}, input.options());
  }

  // The kernel is constructed with integer scale factors only. Derive them
  // from the resolved sizes rather than from `scale_factors`, so that:
  //  - calls that pass only `output_size` work instead of failing on an
  //    empty optional, and
  //  - fractional scale factors are rejected instead of being silently
  //    truncated to a scale that disagrees with the output texture size.
  // input.numel() != 0 at this point, so both divisors are non-zero.
  TORCH_CHECK(
      output_height >= input_height && output_width >= input_width &&
          output_height % input_height == 0 &&
          output_width % input_width == 0,
      "Metal upsample_nearest2d: output size (",
      output_height,
      " x ",
      output_width,
      ") must be a positive integer multiple of input size (",
      input_height,
      " x ",
      input_width,
      "); only integer scale factors are supported");
  const NSUInteger scaleFactorX = (NSUInteger)(output_width / input_width);
  const NSUInteger scaleFactorY = (NSUInteger)(output_height / input_height);

  MPSImage* X = imageFromTensor(input);
  MetalTensorImplStorage mt{outputSizes};
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  mt.texture()->allocateTemporaryStorage(outputSizes, commandBuffer);
  MPSImage* Y = mt.texture()->image();
  if (@available(iOS 11.0, *)) {
    MPSCNNUpsamplingNearest* kernel = [[MPSCNNUpsamplingNearest alloc]
             initWithDevice:[MetalContext sharedInstance].device
        integerScaleFactorX:scaleFactorX
        integerScaleFactorY:scaleFactorY];
    [kernel encodeToCommandBuffer:commandBuffer.buffer
                      sourceImage:X
                 destinationImage:Y];
  } else {
    TORCH_CHECK(
        false,
        "MPSCNNUpsamplingNearest is only available on iOS 11.0 and above");
  }
  return makeTensor(std::move(mt), input.options());
}

// Register the Metal kernel for aten::upsample_nearest2d.vec.
TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("aten::upsample_nearest2d.vec"),
      TORCH_FN(upsample_nearest2d_vec));
}

} // namespace at::native::metal