xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/ops/MetalSoftmax.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#include <ATen/Tensor.h>
2#import <ATen/native/metal/MetalCommandBuffer.h>
3#import <ATen/native/metal/MetalTensorImpl.h>
4#import <ATen/native/metal/MetalTensorImplStorage.h>
5#import <ATen/native/metal/MetalTensorUtils.h>
6#import <ATen/native/metal/MetalContext.h>
7#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
8#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
9
10#include <ATen/ATen.h>
11#include <torch/library.h>
12
13namespace at::native::metal {
14
/// Computes softmax (or log-softmax, depending on `T`) of a 2-D Metal tensor
/// along `dim` using an MPSCNN softmax kernel.
///
/// The MPSCNN softmax kernels operate on feature channels of an MPSImage, so
/// the 2-D input is viewed as a 4-D tensor whose index-1 (channel) axis holds
/// the dimension being normalized:
///   dim == 0: [N, M] -> [1, N, M, 1]
///   dim == 1: [N, M] -> [N, M, 1, 1]
/// The result is viewed back to the original sizes before returning.
///
/// @tparam T     MPSCNN kernel class to instantiate (MPSCNNSoftMax or
///               MPSCNNLogSoftMax in this file).
/// @param input  Metal tensor; must have exactly 2 dimensions.
/// @param dim    Dimension to normalize over. Negative values count from the
///               end (-1 -> 1, -2 -> 0); out-of-range values raise.
/// @param dtype  Accepted for schema compatibility with `aten::*softmax.int`;
///               it is not read by this implementation.
/// @return       A Metal tensor with the same sizes and options as `input`.
template <typename T>
Tensor mpscnn_softmax(
    const Tensor& input,
    int64_t dim,
    std::optional<ScalarType> dtype) {
  TORCH_CHECK(input.is_metal(), "mpscnn_softmax: expected a Metal tensor");
  // TODO: [T87180544] Implement softmax/log_softmax in metal shaders
  TORCH_CHECK(
      input.dim() == 2,
      "mpscnn_softmax: Metal softmax only supports 2-D inputs, got a ",
      input.dim(),
      "-D tensor");
  // Wrap negative dims (and reject out-of-range ones) before picking the
  // layout below; without this, dim == -2 would be handled as dim == 1.
  const int64_t wrapped_dim = c10::maybe_wrap_dim(dim, input.dim());
  if (input.numel() == 0) {
    return makeTensor({input.sizes().vec()}, input.options());
  }
  // Place the normalized dimension on the channel axis (index 1).
  std::vector<int64_t> newSize(4, 1);
  if (wrapped_dim == 0) {
    newSize[1] = input.size(0);
    newSize[2] = input.size(1);
  } else {
    newSize[0] = input.size(0);
    newSize[1] = input.size(1);
  }
  auto input_ = input.view(newSize);
  MPSImage* X = imageFromTensor(input_);
  // MPSCNNSoftmax kernels operate on feature channels
  // https://developer.apple.com/documentation/metalperformanceshaders/mpscnnsoftmax?changes=_1&language=objc
  T* softmax = [[T alloc] initWithDevice:[MetalContext sharedInstance].device];
  MetalTensorImplStorage mt{newSize};
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input_);
  mt.texture()->allocateTemporaryStorage(newSize, commandBuffer);
  MPSImage* Y = mt.texture()->image();
  [softmax encodeToCommandBuffer:commandBuffer.buffer
                     sourceImage:X
                destinationImage:Y];
  // Restore the original 2-D sizes.
  return makeTensor(std::move(mt), input.options()).view(input.sizes());
}
50
// Metal kernel for `aten::log_softmax.int`: log-softmax of a 2-D tensor along
// `dim`, delegated to mpscnn_softmax with the MPSCNNLogSoftMax kernel.
static Tensor log_softmax_int(const Tensor& input, int64_t dim, std::optional<ScalarType> dtype) {
  using LogSoftmaxKernel = MPSCNNLogSoftMax;
  return mpscnn_softmax<LogSoftmaxKernel>(input, dim, dtype);
}
57
// Metal kernel for `aten::softmax.int`: softmax of a 2-D tensor along `dim`,
// delegated to mpscnn_softmax with the MPSCNNSoftMax kernel.
static Tensor softmax_int(const Tensor& input, int64_t dim, std::optional<ScalarType> dtype) {
  using SoftmaxKernel = MPSCNNSoftMax;
  return mpscnn_softmax<SoftmaxKernel>(input, dim, dtype);
}
64
// Registers the Metal-backend implementations of aten::log_softmax.int and
// aten::softmax.int. TORCH_LIBRARY_IMPL expands to a function definition, so
// no trailing semicolon follows the closing brace (it would be a stray empty
// declaration flagged by -Wextra-semi).
TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::log_softmax.int"), TORCH_FN(metal::log_softmax_int));
  m.impl(TORCH_SELECTIVE_NAME("aten::softmax.int"), TORCH_FN(metal::softmax_int));
}
69
70} // namespace at::native::metal
71