#import <ATen/native/metal/MetalCommandBuffer.h>
#import <ATen/native/metal/MetalTensorImpl.h>
#import <ATen/native/metal/MetalTensorImplStorage.h>
#import <ATen/native/metal/MetalTensorUtils.h>
#import <ATen/native/metal/MetalContext.h>
#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
#import <ATen/native/metal/mpscnn/MPSImageUtils.h>

#include <torch/library.h>

namespace at::native::metal {

// Returns a Metal tensor backed by a non-temporary MPSImage.
//
// - If `input` is already backed by an MPSImage that is not a temporary
//   image, `input` itself is returned and no GPU work is encoded.
// - Otherwise a new MetalTensorImplStorage with the same sizes is allocated
//   on `input`'s command buffer. A copy kernel from the source image into
//   the new image is encoded into that buffer. The result wraps the new
//   storage using `input.options()`.
//
// NOTE(review): presumably the point is that temporary images cannot be
// read back by the CPU, hence the copy into a regular image. Confirm
// against MPSImage+Tensor's `isTemporaryImage`.
// NOTE(review): this function only encodes the copy. It does not commit or
// wait on the command buffer; presumably the owner of that buffer does so
// before the data is read. Confirm.
static Tensor copy_to_host(const Tensor& input) {
  TORCH_CHECK(input.is_metal());
  MPSImage* X = imageFromTensor(input);
  // Fast path: already a persistent image, nothing to copy.
  if (X && !X.isTemporaryImage) {
    return input;
  }
  // Every step below reads from X (sizes, texture, kernel selection, launch
  // params). With a nil X, each of those would silently message nil and
  // produce empty/zero values instead of a real copy, so fail loudly here.
  TORCH_CHECK(
      X != nil, "copy_to_host: input Metal tensor has no backing MPSImage");

  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  auto&& sizes = [X sizes];

  // Destination storage: same sizes as the source, bound to the same
  // command buffer so the copy is ordered with the work that produced X.
  MetalTensorImplStorage mt{sizes};
  mt.texture()->setCommandBuffer(commandBuffer);
  mt.texture()->allocateStorage(sizes);
  MPSImage* Y = mt.texture()->image();

  id<MTLComputeCommandEncoder> encoder =
      [commandBuffer.buffer computeCommandEncoder];
  // kernelFor picks between the "copy" and "copy_nonarray" variants based on
  // X (presumably array vs. non-array texture — confirm in MPSCNNUtils).
  // The kernel is specialized on the source image's channels/height/width.
  id<MTLComputePipelineState> state = [[MetalContext sharedInstance]
      specializedPipelineState:metal::mpscnn::kernelFor(
                                   X, "copy", "copy_nonarray")
                     Constants:@[
                       @(X.featureChannels),
                       @(X.height),
                       @(X.width)
                     ]];

  [encoder setComputePipelineState:state];
  [encoder setTexture:[X texture] atIndex:0]; // source
  [encoder setTexture:[Y texture] atIndex:1]; // destination

  const auto& launchParams =
      metal::mpscnn::spatialPointwiseKernelLaunchParams(state, X);
  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
  [encoder endEncoding];

  auto output = makeTensor(std::move(mt), input.options());
  return output;
}

// Registers copy_to_host as the Metal-backend kernel for the
// `metal::copy_to_host` operator.
TORCH_LIBRARY_IMPL(metal, Metal, m) {
  m.impl(TORCH_SELECTIVE_NAME("metal::copy_to_host"), TORCH_FN(copy_to_host));
}

} // namespace at::native::metal