1 #pragma once 2 3 #include <ATen/OpaqueTensorImpl.h> 4 5 namespace at { 6 // The only difference from OpaqueTensorImpl is faking strides(), stride(), 7 // is_contiguous(). The main intention for this is to be able to run torchscript 8 // model on Vulkan backend. Strides are not supported on Vulkan side, plan to 9 // support them. 10 template <typename OpaqueHandle> 11 struct VulkanOpaqueTensorImpl : public OpaqueTensorImpl<OpaqueHandle> { VulkanOpaqueTensorImplVulkanOpaqueTensorImpl12 VulkanOpaqueTensorImpl( 13 at::DispatchKeySet key_set, 14 const caffe2::TypeMeta data_type, 15 c10::Device device, 16 OpaqueHandle opaque_handle, 17 c10::IntArrayRef sizes, 18 c10::IntArrayRef strides) 19 : OpaqueTensorImpl<OpaqueHandle>( 20 key_set, 21 data_type, 22 device, 23 opaque_handle, 24 sizes, 25 false), 26 strides_(strides.vec()) {} 27 strides_customVulkanOpaqueTensorImpl28 IntArrayRef strides_custom() const override { 29 return strides_; 30 } 31 sym_strides_customVulkanOpaqueTensorImpl32 SymIntArrayRef sym_strides_custom() const override { 33 return c10::fromIntArrayRefKnownNonNegative(strides_); 34 } 35 is_contiguous_customVulkanOpaqueTensorImpl36 bool is_contiguous_custom(c10::MemoryFormat memory_format) const override { 37 (void)memory_format; 38 return true; 39 } 40 41 private: tensorimpl_type_nameVulkanOpaqueTensorImpl42 const char* tensorimpl_type_name() const override { 43 return "VulkanOpaqueTensorImpl"; 44 } 45 46 // TODO: storing strides separately is unnecessary, the base TensorImpl 47 // has space for them 48 SmallVector<int64_t, 5> strides_; 49 }; 50 51 } // namespace at 52