1 2#import <ATen/native/metal/MetalCommandBuffer.h> 3#import <ATen/native/metal/MetalTensorImpl.h> 4#import <ATen/native/metal/MetalTensorImplStorage.h> 5#import <ATen/native/metal/MetalTensorUtils.h> 6#import <ATen/native/metal/MetalContext.h> 7#import <ATen/native/metal/mpscnn/MPSCNNUtils.h> 8#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h> 9#import <ATen/native/metal/mpscnn/MPSImageUtils.h> 10 11#include <ATen/Tensor.h> 12#include <ATen/native/UpSample.h> 13#include <torch/library.h> 14 15namespace at::native::metal { 16 17static Tensor cat_batch(const Tensor& tensor, const ITensorListRef& tensors, MetalTensorImplStorage& mt) { 18 MetalCommandBuffer* commandBuffer = getCommandBuffer(tensor); 19 MPSImage* Y = mt.texture()->image(); 20 ushort cat_dim4_pointer = 0; 21 for (const auto& t : tensors) { 22 MPSImage* X = imageFromTensor(t); 23 MetalCommandBuffer* Xcb = getCommandBuffer(t); 24 TORCH_CHECK( 25 [commandBuffer isEqual:Xcb], 26 @"inputs have different Metal command buffers"); 27 id<MTLComputeCommandEncoder> encoder = 28 [commandBuffer.buffer computeCommandEncoder]; 29 id<MTLComputePipelineState> state = [[MetalContext sharedInstance] 30 pipelineState:mpscnn::kernelFor( 31 X, "copy_offset", "copy_offset_nonarray")]; 32 id<MTLBuffer> offsetBuffer = [[MetalContext sharedInstance].device 33 newBufferWithLength:1 * sizeof(ushort) 34 options:MTLResourceCPUCacheModeWriteCombined]; 35 ushort* offsetBufferPtr = (ushort*)[offsetBuffer contents]; 36 offsetBufferPtr[0] = cat_dim4_pointer; 37 38 [encoder setComputePipelineState:state]; 39 [encoder setTexture:[X texture] atIndex:0]; 40 [encoder setTexture:[Y texture] atIndex:1]; 41 [encoder setBuffer:offsetBuffer offset:0 atIndex:0]; 42 43 const auto& launchParams = 44 mpscnn::spatialPointwiseKernelLaunchParams(state, X); 45 [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid 46 threadsPerThreadgroup:launchParams.threadsPerThreadgroup]; 47 [encoder endEncoding]; 48 cat_dim4_pointer += t.size(0) * ((t.size(1) + 3) / 4); 49 } 50 auto output = makeTensor(std::move(mt), tensor.options()); 51 return output; 52} 53 54static Tensor cat_feature(const Tensor& tensor, const ITensorListRef& tensors, MetalTensorImplStorage& mt) { 55 MetalCommandBuffer* commandBuffer = getCommandBuffer(tensor); 56 MPSImage* Y = mt.texture()->image(); 57 ushort channel_offset = 0; 58 59 auto temp_size = tensor.sizes().vec(); 60 temp_size[1] = 4; 61 MetalTensorImplStorage tt{temp_size}; 62 tt.texture()->setCommandBuffer(commandBuffer); 63 tt.texture()->allocateTemporaryStorage(temp_size, commandBuffer); 64 MPSImage* T = tt.texture()->image(); 65 66 for (const auto& t : tensors) { 67 MPSImage* X = imageFromTensor(t); 68 MetalCommandBuffer* Xcb = getCommandBuffer(t); 69 TORCH_CHECK( 70 [commandBuffer isEqual:Xcb], 71 @"inputs have different Metal command buffers"); 72 ushort tex_offset = channel_offset % 4; 73 std::string kernelString = tex_offset == 0 ? "append_features" : "append_features_off"; 74 75 { 76 id<MTLComputeCommandEncoder> encoder = 77 [commandBuffer.buffer computeCommandEncoder]; 78 id<MTLComputePipelineState> state = [[MetalContext sharedInstance] 79 specializedPipelineState:kernelString 80 Constants:@[ 81 @(T.height), 82 @(T.width), 83 @(T.featureChannels), 84 @(T.numberOfImages), 85 @(X.height), 86 @(X.width), 87 @(X.featureChannels), 88 @(X.numberOfImages), 89 ]]; 90 id<MTLBuffer> offsetBuffer = [[MetalContext sharedInstance].device 91 newBufferWithLength:6 * sizeof(ushort) 92 options:MTLResourceCPUCacheModeWriteCombined]; 93 ushort* offsetBufferPtr = (ushort*)[offsetBuffer contents]; 94 offsetBufferPtr[0] = (X.featureChannels + tex_offset + 3) / 4; 95 offsetBufferPtr[1] = (Y.featureChannels + 3) / 4; 96 offsetBufferPtr[2] = channel_offset / 4; 97 offsetBufferPtr[3] = (X.featureChannels + 3) / 4; 98 offsetBufferPtr[4] = X.numberOfImages * offsetBufferPtr[0]; 99 offsetBufferPtr[5] = tex_offset; 100 101 [encoder setComputePipelineState:state]; 102 if (tex_offset == 0) { 103 [encoder setTexture:[X texture] atIndex:0]; 104 [encoder setTexture:[Y texture] atIndex:1]; 105 [encoder setBuffer:offsetBuffer offset:0 atIndex:0]; 106 } 107 else { 108 [encoder setTexture:[X texture] atIndex:0]; 109 [encoder setTexture:[T texture] atIndex:1]; 110 [encoder setTexture:[Y texture] atIndex:2]; 111 [encoder setBuffer:offsetBuffer offset:0 atIndex:0]; 112 } 113 114 ushort featureChannels = X.featureChannels; 115 if (channel_offset % 4 > 0) { 116 featureChannels += tex_offset; 117 } 118 const auto& launchParams = 119 metal::mpscnn::spatialPointwiseKernelLaunchParams( 120 state, X.numberOfImages, featureChannels, X.height, X.width); 121 [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid 122 threadsPerThreadgroup:launchParams.threadsPerThreadgroup]; 123 [encoder endEncoding]; 124 } 125 126 channel_offset += X.featureChannels; 127 128 { 129 id<MTLComputeCommandEncoder> encoder = 130 [commandBuffer.buffer computeCommandEncoder]; 131 132 id<MTLComputePipelineState> state = [[MetalContext sharedInstance] 133 specializedPipelineState:"store_features" 134 Constants:@[ 135 @(T.height), 136 @(T.width), 137 @(T.featureChannels), 138 @(T.numberOfImages), 139 ]]; 140 id<MTLBuffer> offsetBuffer = [[MetalContext sharedInstance].device 141 newBufferWithLength:2 * sizeof(ushort) 142 options:MTLResourceCPUCacheModeWriteCombined]; 143 ushort* offsetBufferPtr = (ushort*)[offsetBuffer contents]; 144 offsetBufferPtr[0] = channel_offset / 4; 145 offsetBufferPtr[1] = (Y.featureChannels + 3) / 4; 146 147 [encoder setComputePipelineState:state]; 148 [encoder setTexture:[Y texture] atIndex:0]; 149 [encoder setTexture:[T texture] atIndex:1]; 150 [encoder setBuffer:offsetBuffer offset:0 atIndex:0]; 151 152 const auto& launchParams = 153 metal::mpscnn::spatialPointwiseKernelLaunchParams(state, T); 154 [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid 155 threadsPerThreadgroup:launchParams.threadsPerThreadgroup]; 156 [encoder endEncoding]; 157 } 158 } 159 auto output = makeTensor(std::move(mt), tensor.options()); 160 return output; 161} 162 163static Tensor cat(const ITensorListRef& tensors, int64_t dim) { 164 TORCH_CHECK( 165 dim == 0 || dim == 1, 166 "Metal cat is implemented only for batch dimension"); 167 int64_t cat_dim_size = 0; 168 TORCH_CHECK(!tensors.empty(), "cat expected a non-empty list of Tensor"); 169 at::Tensor tensor = *tensors.begin(); 170 MetalCommandBuffer* commandBuffer = getCommandBuffer(tensor); 171 for (const auto& t : tensors) { 172 TORCH_CHECK(t.dim() == 4, "Metal cat expects 4 dimensional inputs"); 173 TORCH_CHECK(t.is_metal(), "Metal cat expects metal tensors"); 174 175 for (int d = 0; d < 4; ++d) { 176 if (d == dim) { 177 continue; 178 } 179 TORCH_CHECK( 180 t.size(d) == tensor.size(d), 181 "Metal cat inputs must have matching sizes except concatenated dimension"); 182 } 183 cat_dim_size += t.size(dim); 184 } 185 auto result_size = tensor.sizes().vec(); 186 result_size[dim] = cat_dim_size; 187 TORCH_CHECK( 188 result_size[0] * ((result_size[1] + 3) / 4) > 1, 189 "Output tensor must be a texture array"); 190 MetalTensorImplStorage mt{result_size}; 191 mt.texture()->setCommandBuffer(commandBuffer); 192 mt.texture()->allocateTemporaryStorage(result_size, commandBuffer); 193 194 if (dim == 1) { 195 return cat_feature(tensor, tensors, mt); 196 } 197 return cat_batch(tensor, tensors, mt); 198} 199 200TORCH_LIBRARY_IMPL(aten, Metal, m) { 201 m.impl(TORCH_SELECTIVE_NAME("aten::cat"), TORCH_FN(cat)); 202} 203 204} // namespace at::native::metal 205