xref: /aosp_15_r20/external/pytorch/aten/src/ATen/OpaqueTensorImpl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <c10/core/MemoryFormat.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/Exception.h>

#include <utility>
7 
8 namespace at {
9 
10 // An "Opaque" TensorImpl -- there are no strides and (for now)
11 // even data() is not supported (thus no pointer arithmetic).
12 
13 // NOTE: We could allow data() in the future, but would have to ensure pointer
14 // arithmetic code is properly guarded.
15 //
16 // NOTE: This does not support resize_ (and other metadata-changing ops) because
17 // of `shallow_copy_and_detach`. We would need to define an interface to
18 // "shallow copy" in order to add support.
19 
20 template <typename OpaqueHandle>
21 struct TORCH_API OpaqueTensorImpl : public TensorImpl {
22   // public constructor for now...
23   OpaqueTensorImpl(
24       at::DispatchKeySet key_set,
25       const caffe2::TypeMeta data_type,
26       c10::Device device,
27       OpaqueHandle opaque_handle,
28       c10::IntArrayRef sizes,
29       bool is_non_overlapping_and_dense = true)
TensorImplOpaqueTensorImpl30       : TensorImpl(key_set, data_type, device),
31         opaque_handle_(std::move(opaque_handle)) {
32     set_storage_access_should_throw();
33     set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
34     sizes_and_strides_.set_sizes(sizes);
35     refresh_numel();
36     // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
37     is_non_overlapping_and_dense_ = is_non_overlapping_and_dense;
38   }
39 
40   // Destructor doesn't call release_resources because it's
41   // unnecessary; don't forget to change that if needed!
release_resourcesOpaqueTensorImpl42   void release_resources() override {
43     TensorImpl::release_resources();
44     opaque_handle_ = {};
45   }
46 
set_sizeOpaqueTensorImpl47   void set_size(int64_t dim, int64_t new_size) override {
48     AT_ERROR("opaque tensors do not have set_size");
49   }
50 
set_strideOpaqueTensorImpl51   void set_stride(int64_t dim, int64_t new_stride) override {
52     AT_ERROR("opaque tensors do not have set_stride");
53   }
54 
set_storage_offsetOpaqueTensorImpl55   void set_storage_offset(int64_t storage_offset) override {
56     AT_ERROR("opaque tensors do not have set_storage_offset");
57   }
58 
59 #ifdef DEBUG
has_storageOpaqueTensorImpl60   bool has_storage() const override {
61     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
62         !storage_, "OpaqueTensorImpl assumes that storage_ is never set");
63     return false;
64   }
65 #endif
66 
67   /**
68    * Return a TensorImpl that is a shallow-copy of this TensorImpl.
69    *
70    * For usage of `version_counter` and `allow_tensor_metadata_change`,
71    * see NOTE [ TensorImpl Shallow-Copying ].
72    */
shallow_copy_and_detachOpaqueTensorImpl73   c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
74       const c10::VariableVersion& version_counter,
75       bool allow_tensor_metadata_change) const override {
76     auto impl = c10::make_intrusive<OpaqueTensorImpl<OpaqueHandle>>(
77         key_set(),
78         dtype(),
79         device(),
80         opaque_handle_,
81         sizes_and_strides_.sizes_arrayref());
82     copy_tensor_metadata(
83         /*src_opaque_impl=*/this,
84         /*dest_opaque_impl=*/impl.get(),
85         /*version_counter=*/version_counter,
86         /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
87     impl->refresh_numel();
88     return impl;
89   }
90 
91   /**
92    * Return a TensorImpl that is a shallow-copy of this TensorImpl.
93    *
94    * For usage of `version_counter` and `allow_tensor_metadata_change`,
95    * see NOTE [ TensorImpl Shallow-Copying ].
96    */
shallow_copy_and_detachOpaqueTensorImpl97   c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
98       c10::VariableVersion&& version_counter,
99       bool allow_tensor_metadata_change) const override {
100     auto impl = c10::make_intrusive<OpaqueTensorImpl<OpaqueHandle>>(
101         key_set(),
102         dtype(),
103         device(),
104         opaque_handle_,
105         sizes_and_strides_.sizes_arrayref());
106     copy_tensor_metadata(
107         /*src_opaque_impl=*/this,
108         /*dest_opaque_impl=*/impl.get(),
109         /*version_counter=*/std::move(version_counter),
110         /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
111     impl->refresh_numel();
112     return impl;
113   }
114 
115   /**
116    * Shallow-copies data from another TensorImpl into this TensorImpl.
117    *
118    * For why this function doesn't check this TensorImpl's
119    * `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ].
120    */
shallow_copy_fromOpaqueTensorImpl121   void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
122     AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
123     auto opaque_impl =
124         static_cast<const OpaqueTensorImpl<OpaqueHandle>*>(impl.get());
125     copy_tensor_metadata(
126         /*src_impl=*/opaque_impl,
127         /*dest_impl=*/this,
128         /*version_counter=*/version_counter(),
129         /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
130     refresh_numel();
131   }
132 
opaque_handleOpaqueTensorImpl133   const OpaqueHandle& opaque_handle() const {
134     return opaque_handle_;
135   }
136 
unsafe_opaque_handleOpaqueTensorImpl137   OpaqueHandle& unsafe_opaque_handle() {
138     return opaque_handle_;
139   }
140 
141  protected:
142   /**
143    * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
144    * storage_offset) from one TensorImpl to another TensorImpl.
145    *
146    * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE
147    * [ TensorImpl Shallow-Copying ].
148    */
copy_tensor_metadataOpaqueTensorImpl149   static void copy_tensor_metadata(
150       const OpaqueTensorImpl<OpaqueHandle>* src_opaque_impl,
151       OpaqueTensorImpl<OpaqueHandle>* dest_opaque_impl,
152       const c10::VariableVersion& version_counter,
153       bool allow_tensor_metadata_change) {
154     TensorImpl::copy_tensor_metadata(
155         src_opaque_impl,
156         dest_opaque_impl,
157         version_counter,
158         allow_tensor_metadata_change);
159 
160     // OpaqueTensorImpl-specific fields.
161     dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_;
162   }
163 
copy_tensor_metadataOpaqueTensorImpl164   static void copy_tensor_metadata(
165       const OpaqueTensorImpl<OpaqueHandle>* src_opaque_impl,
166       OpaqueTensorImpl<OpaqueHandle>* dest_opaque_impl,
167       c10::VariableVersion&& version_counter,
168       bool allow_tensor_metadata_change) {
169     TensorImpl::copy_tensor_metadata(
170         src_opaque_impl,
171         dest_opaque_impl,
172         std::move(version_counter),
173         allow_tensor_metadata_change);
174 
175     // OpaqueTensorImpl-specific fields.
176     dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_;
177   }
178 
179  private:
tensorimpl_type_nameOpaqueTensorImpl180   const char* tensorimpl_type_name() const override {
181     return "OpaqueTensorImpl";
182   }
183 
184   OpaqueHandle opaque_handle_;
185 };
186 
187 } // namespace at
188