1 #ifndef MetalTensorImpl_h 2 #define MetalTensorImpl_h 3 4 #include <ATen/OpaqueTensorImpl.h> 5 #include <ATen/WrapDimUtils.h> 6 #import <ATen/native/metal/MetalTensorImplStorage.h> 7 #import <ATen/native/metal/mpscnn/MPSImageWrapper.h> 8 9 namespace at { 10 template <typename OpaqueHandle> 11 struct TORCH_API MetalTensorImpl : public OpaqueTensorImpl<OpaqueHandle> { MetalTensorImplMetalTensorImpl12 MetalTensorImpl( 13 at::DispatchKeySet key_set, 14 const caffe2::TypeMeta& data_type, 15 c10::Device device, 16 OpaqueHandle opaque_handle, 17 c10::IntArrayRef sizes, 18 c10::IntArrayRef strides) 19 : OpaqueTensorImpl<OpaqueHandle>( 20 key_set, 21 data_type, 22 device, 23 opaque_handle, 24 sizes), 25 strides_(strides.vec()) { 26 } 27 28 // TODO: manually storing strides here is dumb 29 strides_customMetalTensorImpl30 IntArrayRef strides_custom() const override { 31 return strides_; 32 } 33 sym_strides_customMetalTensorImpl34 c10::SymIntArrayRef sym_strides_custom() const override { 35 return c10::fromIntArrayRefKnownNonNegative(strides_); 36 } 37 is_contiguous_customMetalTensorImpl38 bool is_contiguous_custom(c10::MemoryFormat memory_format) const override { 39 return true; 40 } 41 42 private: tensorimpl_type_nameMetalTensorImpl43 const char* tensorimpl_type_name() const override { 44 return "MetalTensorImpl"; 45 } 46 47 SmallVector<int64_t, 5> strides_; 48 }; 49 } // namespace at 50 51 #endif /* MetalTensorImpl_h*/ 52