xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/ops/MetalCopy.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#import <ATen/native/metal/MetalCommandBuffer.h>
2#import <ATen/native/metal/MetalTensorImpl.h>
3#import <ATen/native/metal/MetalTensorImplStorage.h>
4#import <ATen/native/metal/MetalTensorUtils.h>
5#import <ATen/native/metal/MetalContext.h>
6#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
7#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
8#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
9
10#include <torch/library.h>
11
namespace at::native::metal {

// Returns a Metal tensor whose data lives in a non-temporary MPSImage.
//
// - If the input's image is already a regular (non-temporary) MPSImage, the
//   input tensor itself is returned unchanged.
// - Otherwise, a new MetalTensorImplStorage with the same sizes is allocated
//   via allocateStorage(). A "copy" compute kernel that copies the input image
//   into the new storage is then encoded on the input's command buffer.
//
// NOTE(review): this function only encodes the copy; it does not commit or
// wait on the command buffer. Presumably the command buffer's owner does that
// before the data is read back (the op name suggests host readback) — confirm.
//
// Throws (via TORCH_CHECK) if `input` is not a Metal tensor or has no backing
// MPSImage.
static Tensor copy_to_host(const Tensor& input) {
  TORCH_CHECK(input.is_metal(), "copy_to_host expects a Metal tensor");
  MPSImage* X = imageFromTensor(input);
  // Without a backing image there is nothing to copy from. A nil image would
  // otherwise fall through to the copy path below and encode a kernel that
  // reads from a nil texture.
  TORCH_CHECK(X != nil,
              "copy_to_host: input Metal tensor has no backing MPSImage");
  if (!X.isTemporaryImage) {
    // Already backed by a regular image; no copy needed.
    return input;
  }

  // The destination storage and the copy kernel both use the command buffer
  // returned by getCommandBuffer(input).
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  auto&& sizes = [X sizes];
  MetalTensorImplStorage mt{sizes};
  mt.texture()->setCommandBuffer(commandBuffer);
  mt.texture()->allocateStorage(sizes);
  MPSImage* Y = mt.texture()->image();

  // Encode the copy X -> Y. The pipeline is specialized on X's channel count,
  // height and width. kernelFor chooses between the "copy" and
  // "copy_nonarray" kernel names for X (presumably texture-array vs.
  // single-texture images — confirm in MPSCNNUtils).
  id<MTLComputeCommandEncoder> encoder =
      [commandBuffer.buffer computeCommandEncoder];
  id<MTLComputePipelineState> state = [[MetalContext sharedInstance]
      specializedPipelineState:metal::mpscnn::kernelFor(
                                   X, "copy", "copy_nonarray")
                     Constants:@[
                       @(X.featureChannels),
                       @(X.height),
                       @(X.width)
                     ]];

  [encoder setComputePipelineState:state];
  [encoder setTexture:[X texture] atIndex:0];  // source
  [encoder setTexture:[Y texture] atIndex:1];  // destination

  const auto& launchParams =
      metal::mpscnn::spatialPointwiseKernelLaunchParams(state, X);
  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
  [encoder endEncoding];

  return makeTensor(std::move(mt), input.options());
}

// Registers copy_to_host as the Metal-backend implementation of the
// metal::copy_to_host operator.
TORCH_LIBRARY_IMPL(metal, Metal, m) {
  m.impl(TORCH_SELECTIVE_NAME("metal::copy_to_host"), TORCH_FN(copy_to_host));
}

} // namespace at::native::metal
56