#import <ATen/native/metal/MetalCommandBuffer.h>
#import <ATen/native/metal/MetalTensorImpl.h>
#import <ATen/native/metal/MetalTensorImplStorage.h>
#import <ATen/native/metal/MetalTensorUtils.h>
#import <ATen/native/metal/MetalContext.h>
#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
#import <ATen/native/metal/mpscnn/MPSImageUtils.h>

#include <ATen/ATen.h>
#include <torch/library.h>

namespace at::native::metal {

// Copies `src` into a freshly allocated MTLBuffer on the shared Metal device.
// Write-combined CPU caching is used because the buffer is written once by
// the CPU here and subsequently only read by the GPU.
// TODO: Move this function to MetalContext
template <typename T>
id<MTLBuffer> _makeMTLBuffer(const std::vector<T>& src) {
  id<MTLBuffer> buffer = [[MetalContext sharedInstance].device
      newBufferWithLength:src.size() * sizeof(T)
                  options:MTLResourceCPUCacheModeWriteCombined];
  memcpy(buffer.contents, src.data(), src.size() * sizeof(T));
  return buffer;
}

// Implements aten::transpose.int for the Metal backend: returns a tensor with
// dimensions `dim0` and `dim1` swapped. Negative dims are wrapped. 2-D inputs
// take a fast path through the built-in MPSImageTranspose kernel; all other
// ranks go through a custom "transpose" compute shader.
static Tensor transpose(const Tensor& input, int64_t dim0, int64_t dim1) {
  TORCH_CHECK(input.is_metal());
  auto ndims = input.dim();
  // The mobile Metal path supports at most 8 tensor dimensions.
  // (Note: this is a limit on dimensions, not on channels.)
  TORCH_CHECK(ndims <= 8);
  dim0 = maybe_wrap_dim(dim0, ndims);
  dim1 = maybe_wrap_dim(dim1, ndims);
  if (dim0 == dim1) {
    // Swapping a dimension with itself is the identity; nothing to encode.
    return input;
  }
  auto outputSizes = input.sizes().vec();
  std::swap(outputSizes[dim0], outputSizes[dim1]);
  MPSImage* X = imageFromTensor(input);
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  if (input.dim() == 2) {
    // Fast path: plain matrix transpose via the built-in MPS kernel.
    MetalTensorImplStorage mt{outputSizes};
    mt.texture()->allocateTemporaryStorage(outputSizes, commandBuffer);
    MPSImage* Y = mt.texture()->image();
    MPSImageTranspose* transpose = [[MPSImageTranspose alloc]
        initWithDevice:[MetalContext sharedInstance].device];
    [transpose encodeToCommandBuffer:commandBuffer.buffer
                         sourceImage:X
                    destinationImage:Y];
    auto output = makeTensor(std::move(mt), input.options());
    return output;
  } else {
    // General path: a custom "transpose" compute shader. The input/output
    // sizes are handed to the shader as ushort buffers, and the pipeline is
    // specialized with function constants for the swapped dims and the
    // source/destination image geometry.
    id<MTLBuffer> sizeBuf1 = _makeMTLBuffer<ushort>(
        std::vector<ushort>{input.sizes().begin(), input.sizes().end()});
    id<MTLBuffer> sizeBuf2 = _makeMTLBuffer<ushort>(
        std::vector<ushort>{outputSizes.begin(), outputSizes.end()});
    MetalTensorImplStorage mt{outputSizes};
    mt.texture()->allocateTemporaryStorage(outputSizes, commandBuffer);
    MPSImage* Y = mt.texture()->image();
    id<MTLComputeCommandEncoder> encoder =
        [commandBuffer.buffer computeCommandEncoder];
    id<MTLComputePipelineState> state =
        [[MetalContext sharedInstance] specializedPipelineState:"transpose"
                                                      Constants:@[
                                                        @(dim0),
                                                        @(dim1),
                                                        @(input.dim()),
                                                        @(X.numberOfImages),
                                                        @(X.featureChannels),
                                                        @(Y.numberOfImages),
                                                        @(Y.featureChannels),
                                                      ]];

    [encoder setComputePipelineState:state];
    [encoder setTexture:[X texture] atIndex:0];
    [encoder setTexture:[Y texture] atIndex:1];
    [encoder setBuffer:sizeBuf1 offset:0 atIndex:0];
    [encoder setBuffer:sizeBuf2 offset:0 atIndex:1];

    const auto& launchParams =
        mpscnn::spatialPointwiseKernelLaunchParams(state, Y);
    [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
            threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
    [encoder endEncoding];
    auto output = makeTensor(std::move(mt), input.options());
    return output;
  }
}

// Implements aten::t for the Metal backend: 2-D matrix transpose only.
static Tensor t(const Tensor& input) {
  TORCH_CHECK(input.is_metal());
  TORCH_CHECK(input.dim() == 2);
  // dim() == 2 is guaranteed by the check above, so always swap dims 0 and 1.
  // (The previous `input.dim() < 2 ? 0 : 1` ternary was dead code.)
  return metal::transpose(input, 0, 1);
}

TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::t"), TORCH_FN(t));
  m.impl(TORCH_SELECTIVE_NAME("aten::transpose.int"), TORCH_FN(transpose));
}

} // namespace at::native::metal