xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/MetalTensorImpl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #ifndef MetalTensorImpl_h
2 #define MetalTensorImpl_h
3 
4 #include <ATen/OpaqueTensorImpl.h>
5 #include <ATen/WrapDimUtils.h>
6 #import <ATen/native/metal/MetalTensorImplStorage.h>
7 #import <ATen/native/metal/mpscnn/MPSImageWrapper.h>
8 
9 namespace at {
10 template <typename OpaqueHandle>
11 struct TORCH_API MetalTensorImpl : public OpaqueTensorImpl<OpaqueHandle> {
MetalTensorImplMetalTensorImpl12   MetalTensorImpl(
13       at::DispatchKeySet key_set,
14       const caffe2::TypeMeta& data_type,
15       c10::Device device,
16       OpaqueHandle opaque_handle,
17       c10::IntArrayRef sizes,
18       c10::IntArrayRef strides)
19       : OpaqueTensorImpl<OpaqueHandle>(
20             key_set,
21             data_type,
22             device,
23             opaque_handle,
24             sizes),
25         strides_(strides.vec()) {
26   }
27 
28   // TODO: manually storing strides here is dumb
29 
strides_customMetalTensorImpl30   IntArrayRef strides_custom() const override {
31     return strides_;
32   }
33 
sym_strides_customMetalTensorImpl34   c10::SymIntArrayRef sym_strides_custom() const override {
35     return c10::fromIntArrayRefKnownNonNegative(strides_);
36   }
37 
is_contiguous_customMetalTensorImpl38   bool is_contiguous_custom(c10::MemoryFormat memory_format) const override {
39     return true;
40   }
41 
42  private:
tensorimpl_type_nameMetalTensorImpl43   const char* tensorimpl_type_name() const override {
44     return "MetalTensorImpl";
45   }
46 
47   SmallVector<int64_t, 5> strides_;
48 };
49 } // namespace at
50 
51 #endif /* MetalTensorImpl_h*/
52