xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <ATen/OpaqueTensorImpl.h>

#include <utility>
4 
5 namespace at {
// The only difference from OpaqueTensorImpl is that this class fakes
// strides(), stride(), and is_contiguous(). The main purpose is to allow
// TorchScript models to run on the Vulkan backend. Strides are not yet
// supported on the Vulkan side; there are plans to support them.
10 template <typename OpaqueHandle>
11 struct VulkanOpaqueTensorImpl : public OpaqueTensorImpl<OpaqueHandle> {
VulkanOpaqueTensorImplVulkanOpaqueTensorImpl12   VulkanOpaqueTensorImpl(
13       at::DispatchKeySet key_set,
14       const caffe2::TypeMeta data_type,
15       c10::Device device,
16       OpaqueHandle opaque_handle,
17       c10::IntArrayRef sizes,
18       c10::IntArrayRef strides)
19       : OpaqueTensorImpl<OpaqueHandle>(
20             key_set,
21             data_type,
22             device,
23             opaque_handle,
24             sizes,
25             false),
26         strides_(strides.vec()) {}
27 
strides_customVulkanOpaqueTensorImpl28   IntArrayRef strides_custom() const override {
29     return strides_;
30   }
31 
sym_strides_customVulkanOpaqueTensorImpl32   SymIntArrayRef sym_strides_custom() const override {
33     return c10::fromIntArrayRefKnownNonNegative(strides_);
34   }
35 
is_contiguous_customVulkanOpaqueTensorImpl36   bool is_contiguous_custom(c10::MemoryFormat memory_format) const override {
37     (void)memory_format;
38     return true;
39   }
40 
41  private:
tensorimpl_type_nameVulkanOpaqueTensorImpl42   const char* tensorimpl_type_name() const override {
43     return "VulkanOpaqueTensorImpl";
44   }
45 
46   // TODO: storing strides separately is unnecessary, the base TensorImpl
47   // has space for them
48   SmallVector<int64_t, 5> strides_;
49 };
50 
51 } // namespace at
52