1 #include <ATen/Tensor.h>
2 #include <ATen/native/metal/MetalContext.h>
3 #include <ATen/native/metal/MetalCommandBuffer.h>
4 #include <ATen/native/metal/MetalTensorImpl.h>
5 #include <ATen/native/metal/MetalTensorImplStorage.h>
6
// Half-precision (16-bit) element type. ARM NEON builds get the native
// float16_t; elsewhere a same-width unsigned integer stands in — presumably
// as a raw bit carrier for fp16 data (TODO confirm against usage sites).
#if (defined(__ARM_NEON__) || defined(__ARM_NEON))
using fp16_t = float16_t;
#else
using fp16_t = uint16_t;
#endif
12
13 namespace at::native::metal {
14
// Accessors for the per-dimension extents (batch/N, channels/C, height/H,
// width/W) of a tensor; declared here, defined out-of-line in the
// corresponding implementation file.
uint32_t batchSize(const Tensor& tensor);
uint32_t channelsSize(const Tensor& tensor);
uint32_t heightSize(const Tensor& tensor);
uint32_t widthSize(const Tensor& tensor);
19
// When copying the result back to a CPU tensor, the memory format becomes NCHW.
// Thus, we compute the strides based on contiguous memory format.
//
// Returns a vector of the same length as `sizes` holding row-major
// (contiguous) strides; empty input yields an empty vector.
static inline std::vector<int64_t> computeStrides(
    const std::vector<int64_t>& sizes) {
  const auto dim = sizes.size();
  std::vector<int64_t> strides(dim, 0);
  if (dim > 0) {
    // Do the index arithmetic in a signed type: with the original unsigned
    // `last_idx`, `last_idx - 1` wraps to SIZE_MAX when dim == 1, and its
    // conversion to int64_t is implementation-defined before C++20.
    const int64_t last_idx = static_cast<int64_t>(dim) - 1;
    strides[last_idx] = 1;
    for (int64_t i = last_idx - 1; i >= 0; --i) {
      // Clamp zero-sized dims to 1 so strides never collapse to zero,
      // matching the contiguous-stride convention for empty tensors.
      strides[i] = strides[i + 1] * std::max<int64_t>(sizes[i + 1], 1);
    }
  }
  return strides;
}
35
getTensorImplStorage(const at::Tensor & tensor)36 static inline MetalTensorImplStorage& getTensorImplStorage(
37 const at::Tensor& tensor) {
38 using MetalTensorImpl = at::MetalTensorImpl<MetalTensorImplStorage>;
39 TORCH_CHECK(tensor.is_metal());
40 MetalTensorImpl* impl =
41 static_cast<MetalTensorImpl*>(tensor.unsafeGetTensorImpl());
42 return impl->unsafe_opaque_handle();
43 }
44
makeTensor(MetalTensorImplStorage && mt,const TensorOptions & options)45 static inline at::Tensor makeTensor(
46 MetalTensorImplStorage&& mt,
47 const TensorOptions& options) {
48 using MetalTensorImpl = at::MetalTensorImpl<MetalTensorImplStorage>;
49 auto sizes = mt.sizes(); // sizes is stored in TensorImpl
50 auto strides = mt.strides(); // strides is stored in MetalTensorImpl
51 return detail::make_tensor<MetalTensorImpl>(
52 DispatchKeySet(DispatchKey::Metal),
53 options.dtype(),
54 at::Device(at::kMetal),
55 std::move(mt),
56 std::vector<int64_t>(sizes.begin(), sizes.end()),
57 std::vector<int64_t>(strides.begin(), strides.end()));
58 }
59
// Returns the command buffer associated with a Metal tensor's texture,
// falling back to the thread's current buffer when the recorded one is
// missing or no longer valid.
static inline MetalCommandBuffer* getCommandBuffer(
    const Tensor& tensor) {
  TORCH_CHECK(tensor.is_metal());
  // NB: bind by reference — getTensorImplStorage returns a
  // MetalTensorImplStorage&, and a plain `auto` here silently copied the
  // whole storage object on every call.
  auto& implStorage = getTensorImplStorage(tensor);
  MetalCommandBuffer* cmdBuffer = implStorage.texture()->commandBuffer();
  if (!cmdBuffer || !cmdBuffer.valid) {
    cmdBuffer = [MetalCommandBuffer currentBuffer];
  }
  return cmdBuffer;
}
70
71 } // namespace at::native::metal
72