#include <ATen/Tensor.h>
#import <ATen/native/metal/MetalCommandBuffer.h>
#import <ATen/native/metal/MetalTensorImpl.h>
#import <ATen/native/metal/MetalTensorImplStorage.h>
#import <ATen/native/metal/MetalTensorUtils.h>
#import <ATen/native/metal/MetalContext.h>
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
#import <ATen/native/metal/mpscnn/MPSImageUtils.h>

#include <ATen/ATen.h>
#include <torch/library.h>

namespace at::native::metal {

// Shared implementation of softmax / log_softmax for 2-D Metal tensors.
//
// T is the MPSCNN kernel class to instantiate (the callers in this file pass
// MPSCNNSoftMax and MPSCNNLogSoftMax).
//
// input: must be a Metal tensor with exactly two dimensions.
// dim:   dimension to normalize over; accepts -2..1. Negative values count
//        from the end, so -2 selects dimension 0 and -1 selects dimension 1.
// dtype: accepted to match the operator signature; it is not read here.
//
// Returns a Metal tensor with the same sizes as `input`.
// Fails a TORCH_CHECK if the input is not a 2-D Metal tensor, and a
// TORCH_CHECK_INDEX if `dim` is out of range.
template <typename T>
Tensor mpscnn_softmax(
    const Tensor& input,
    int64_t dim,
    std::optional<ScalarType> dtype) {
  TORCH_CHECK(input.is_metal());
  // TODO: [T87180544] Implement softmax/log_softmax in metal shaders
  TORCH_CHECK(input.dim() == 2);

  // Normalize `dim` up front. Previously only `dim == 0` was recognized, so
  // `dim == -2` (and any out-of-range value) silently fell through to the
  // "dimension 1" layout and produced results along the wrong axis.
  const int64_t ndim = input.dim();
  TORCH_CHECK_INDEX(
      dim >= -ndim && dim < ndim,
      "Dimension out of range (expected to be in range of [",
      -ndim,
      ", ",
      ndim - 1,
      "], but got ",
      dim,
      ")");
  const int64_t wrappedDim = dim < 0 ? dim + ndim : dim;

  if (input.numel() == 0) {
    return makeTensor({input.sizes().vec()}, input.options());
  }

  // Build a 4-D shape that puts the dimension being normalized at index 1,
  // which is the axis this code treats as the feature-channel axis.
  std::vector<int64_t> newSize(4, 1);
  if (wrappedDim == 0) {
    newSize[1] = input.size(0);
    newSize[2] = input.size(1);
  } else {
    newSize[0] = input.size(0);
    newSize[1] = input.size(1);
  }
  auto input_ = input.view(newSize);
  MPSImage* X = imageFromTensor(input_);
  // MPSCNNSoftmax kernels operate on feature channels
  // https://developer.apple.com/documentation/metalperformanceshaders/mpscnnsoftmax?changes=_1&language=objc
  T* softmax = [[T alloc] initWithDevice:[MetalContext sharedInstance].device];
  MetalTensorImplStorage mt{newSize};
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input_);
  mt.texture()->allocateTemporaryStorage(newSize, commandBuffer);
  MPSImage* Y = mt.texture()->image();
  [softmax encodeToCommandBuffer:commandBuffer.buffer
                     sourceImage:X
                destinationImage:Y];
  // restore the original sizes
  auto output = makeTensor(std::move(mt), input.options()).view(input.sizes());
  return output;
}

// aten::log_softmax.int for the Metal backend; forwards to mpscnn_softmax
// with the MPSCNNLogSoftMax kernel.
static Tensor log_softmax_int(
    const Tensor& input,
    int64_t dim,
    std::optional<ScalarType> dtype) {
  return mpscnn_softmax<MPSCNNLogSoftMax>(input, dim, dtype);
}

// aten::softmax.int for the Metal backend; forwards to mpscnn_softmax with
// the MPSCNNSoftMax kernel.
static Tensor softmax_int(
    const Tensor& input,
    int64_t dim,
    std::optional<ScalarType> dtype) {
  return mpscnn_softmax<MPSCNNSoftMax>(input, dim, dtype);
}

// Register the Metal implementations of the two operators.
TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::log_softmax.int"), TORCH_FN(metal::log_softmax_int));
  m.impl(TORCH_SELECTIVE_NAME("aten::softmax.int"), TORCH_FN(metal::softmax_int));
}

} // namespace at::native::metal