// aten/src/ATen/native/metal/MetalTensorUtils.h
#include <ATen/Tensor.h>
#include <ATen/native/metal/MetalContext.h>
#include <ATen/native/metal/MetalCommandBuffer.h>
#include <ATen/native/metal/MetalTensorImpl.h>
#include <ATen/native/metal/MetalTensorImplStorage.h>

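// On ARM builds with NEON support, fp16_t is the hardware half-precision
// float type; on other targets, half values are carried as raw uint16_t bits.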
#if (defined(__ARM_NEON__) || defined(__ARM_NEON))
typedef float16_t fp16_t;
#else
typedef uint16_t fp16_t;
#endif

namespace at::native::metal {

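// NCHW dimension helpers: the batch (N), channel (C), height (H), and
// width (W) extents of a tensor.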
uint32_t batchSize(const Tensor& tensor);
uint32_t channelsSize(const Tensor& tensor);
uint32_t heightSize(const Tensor& tensor);
uint32_t widthSize(const Tensor& tensor);

// When copying the result back to a CPU tensor, the memory format becomes
// NCHW. Thus, we compute the strides based on the contiguous memory format.
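// For example, sizes {1, 3, 4, 4} yield strides {48, 16, 4, 1}.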
static inline std::vector<int64_t> computeStrides(
    const std::vector<int64_t>& sizes) {
  const auto dim = sizes.size();
  std::vector<int64_t> strides(dim, 0);
  if (dim > 0) {
    const auto last_idx = dim - 1;
    strides[last_idx] = 1;
    for (int64_t i = last_idx - 1; i >= 0; --i) {
      strides[i] = strides[i + 1] * std::max<int64_t>(sizes[i + 1], 1);
    }
  }
  return strides;
}

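// Returns the MetalTensorImplStorage that backs a Metal tensor.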
static inline MetalTensorImplStorage& getTensorImplStorage(
    const at::Tensor& tensor) {
  using MetalTensorImpl = at::MetalTensorImpl<MetalTensorImplStorage>;
  TORCH_CHECK(tensor.is_metal());
  MetalTensorImpl* impl =
      static_cast<MetalTensorImpl*>(tensor.unsafeGetTensorImpl());
  return impl->unsafe_opaque_handle();
}

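// Wraps a MetalTensorImplStorage into an at::Tensor that dispatches to the
// Metal backend, propagating the sizes and strides recorded in the storage.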
static inline at::Tensor makeTensor(
    MetalTensorImplStorage&& mt,
    const TensorOptions& options) {
  using MetalTensorImpl = at::MetalTensorImpl<MetalTensorImplStorage>;
  auto sizes = mt.sizes(); // sizes are stored in TensorImpl
  auto strides = mt.strides(); // strides are stored in MetalTensorImpl
  return detail::make_tensor<MetalTensorImpl>(
      DispatchKeySet(DispatchKey::Metal),
      options.dtype(),
      at::Device(at::kMetal),
      std::move(mt),
      std::vector<int64_t>(sizes.begin(), sizes.end()),
      std::vector<int64_t>(strides.begin(), strides.end()));
}

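// Returns the command buffer recorded on the tensor's texture, falling back
// to [MetalCommandBuffer currentBuffer] when none is recorded or the recorded
// one is no longer valid.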
static inline MetalCommandBuffer* getCommandBuffer(
    const Tensor& tensor) {
  TORCH_CHECK(tensor.is_metal());
  auto implStorage = getTensorImplStorage(tensor);
  MetalCommandBuffer* cmdBuffer = implStorage.texture()->commandBuffer();
  if (!cmdBuffer || !cmdBuffer.valid) {
    cmdBuffer = [MetalCommandBuffer currentBuffer];
  }
  return cmdBuffer;
}

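// Minimal usage sketch (illustrative only; assumes an already-populated
// MetalTensorImplStorage named `storage`):
//   at::Tensor t = makeTensor(std::move(storage),
//                             at::TensorOptions().dtype(at::kFloat));
//   MetalCommandBuffer* cb = getCommandBuffer(t);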

} // namespace at::native::metal