1 #pragma once 2 3 #include <c10/core/MemoryFormat.h> 4 #include <c10/core/SymIntArrayRef.h> 5 #include <c10/core/TensorImpl.h> 6 #include <c10/util/Exception.h> 7 8 namespace at { 9 10 // An "Opaque" TensorImpl -- there are no strides and (for now) 11 // even data() is not supported (thus no pointer arithmetic). 12 13 // NOTE: We could allow data() in the future, but would have to ensure pointer 14 // arithmetic code is properly guarded. 15 // 16 // NOTE: This does not support resize_ (and other metadata-changing ops) because 17 // of `shallow_copy_and_detach`. We would need to define an interface to 18 // "shallow copy" in order to add support. 19 20 template <typename OpaqueHandle> 21 struct TORCH_API OpaqueTensorImpl : public TensorImpl { 22 // public constructor for now... 23 OpaqueTensorImpl( 24 at::DispatchKeySet key_set, 25 const caffe2::TypeMeta data_type, 26 c10::Device device, 27 OpaqueHandle opaque_handle, 28 c10::IntArrayRef sizes, 29 bool is_non_overlapping_and_dense = true) TensorImplOpaqueTensorImpl30 : TensorImpl(key_set, data_type, device), 31 opaque_handle_(std::move(opaque_handle)) { 32 set_storage_access_should_throw(); 33 set_custom_sizes_strides(SizesStridesPolicy::CustomStrides); 34 sizes_and_strides_.set_sizes(sizes); 35 refresh_numel(); 36 // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer) 37 is_non_overlapping_and_dense_ = is_non_overlapping_and_dense; 38 } 39 40 // Destructor doesn't call release_resources because it's 41 // unnecessary; don't forget to change that if needed! release_resourcesOpaqueTensorImpl42 void release_resources() override { 43 TensorImpl::release_resources(); 44 opaque_handle_ = {}; 45 } 46 set_sizeOpaqueTensorImpl47 void set_size(int64_t dim, int64_t new_size) override { 48 AT_ERROR("opaque tensors do not have set_size"); 49 } 50 set_strideOpaqueTensorImpl51 void set_stride(int64_t dim, int64_t new_stride) override { 52 AT_ERROR("opaque tensors do not have set_stride"); 53 } 54 set_storage_offsetOpaqueTensorImpl55 void set_storage_offset(int64_t storage_offset) override { 56 AT_ERROR("opaque tensors do not have set_storage_offset"); 57 } 58 59 #ifdef DEBUG has_storageOpaqueTensorImpl60 bool has_storage() const override { 61 TORCH_INTERNAL_ASSERT_DEBUG_ONLY( 62 !storage_, "OpaqueTensorImpl assumes that storage_ is never set"); 63 return false; 64 } 65 #endif 66 67 /** 68 * Return a TensorImpl that is a shallow-copy of this TensorImpl. 69 * 70 * For usage of `version_counter` and `allow_tensor_metadata_change`, 71 * see NOTE [ TensorImpl Shallow-Copying ]. 72 */ shallow_copy_and_detachOpaqueTensorImpl73 c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach( 74 const c10::VariableVersion& version_counter, 75 bool allow_tensor_metadata_change) const override { 76 auto impl = c10::make_intrusive<OpaqueTensorImpl<OpaqueHandle>>( 77 key_set(), 78 dtype(), 79 device(), 80 opaque_handle_, 81 sizes_and_strides_.sizes_arrayref()); 82 copy_tensor_metadata( 83 /*src_opaque_impl=*/this, 84 /*dest_opaque_impl=*/impl.get(), 85 /*version_counter=*/version_counter, 86 /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); 87 impl->refresh_numel(); 88 return impl; 89 } 90 91 /** 92 * Return a TensorImpl that is a shallow-copy of this TensorImpl. 93 * 94 * For usage of `version_counter` and `allow_tensor_metadata_change`, 95 * see NOTE [ TensorImpl Shallow-Copying ]. 96 */ shallow_copy_and_detachOpaqueTensorImpl97 c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach( 98 c10::VariableVersion&& version_counter, 99 bool allow_tensor_metadata_change) const override { 100 auto impl = c10::make_intrusive<OpaqueTensorImpl<OpaqueHandle>>( 101 key_set(), 102 dtype(), 103 device(), 104 opaque_handle_, 105 sizes_and_strides_.sizes_arrayref()); 106 copy_tensor_metadata( 107 /*src_opaque_impl=*/this, 108 /*dest_opaque_impl=*/impl.get(), 109 /*version_counter=*/std::move(version_counter), 110 /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); 111 impl->refresh_numel(); 112 return impl; 113 } 114 115 /** 116 * Shallow-copies data from another TensorImpl into this TensorImpl. 117 * 118 * For why this function doesn't check this TensorImpl's 119 * `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ]. 120 */ shallow_copy_fromOpaqueTensorImpl121 void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override { 122 AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set())); 123 auto opaque_impl = 124 static_cast<const OpaqueTensorImpl<OpaqueHandle>*>(impl.get()); 125 copy_tensor_metadata( 126 /*src_impl=*/opaque_impl, 127 /*dest_impl=*/this, 128 /*version_counter=*/version_counter(), 129 /*allow_tensor_metadata_change=*/allow_tensor_metadata_change()); 130 refresh_numel(); 131 } 132 opaque_handleOpaqueTensorImpl133 const OpaqueHandle& opaque_handle() const { 134 return opaque_handle_; 135 } 136 unsafe_opaque_handleOpaqueTensorImpl137 OpaqueHandle& unsafe_opaque_handle() { 138 return opaque_handle_; 139 } 140 141 protected: 142 /** 143 * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer / 144 * storage_offset) from one TensorImpl to another TensorImpl. 145 * 146 * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE 147 * [ TensorImpl Shallow-Copying ]. 148 */ copy_tensor_metadataOpaqueTensorImpl149 static void copy_tensor_metadata( 150 const OpaqueTensorImpl<OpaqueHandle>* src_opaque_impl, 151 OpaqueTensorImpl<OpaqueHandle>* dest_opaque_impl, 152 const c10::VariableVersion& version_counter, 153 bool allow_tensor_metadata_change) { 154 TensorImpl::copy_tensor_metadata( 155 src_opaque_impl, 156 dest_opaque_impl, 157 version_counter, 158 allow_tensor_metadata_change); 159 160 // OpaqueTensorImpl-specific fields. 161 dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_; 162 } 163 copy_tensor_metadataOpaqueTensorImpl164 static void copy_tensor_metadata( 165 const OpaqueTensorImpl<OpaqueHandle>* src_opaque_impl, 166 OpaqueTensorImpl<OpaqueHandle>* dest_opaque_impl, 167 c10::VariableVersion&& version_counter, 168 bool allow_tensor_metadata_change) { 169 TensorImpl::copy_tensor_metadata( 170 src_opaque_impl, 171 dest_opaque_impl, 172 std::move(version_counter), 173 allow_tensor_metadata_change); 174 175 // OpaqueTensorImpl-specific fields. 176 dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_; 177 } 178 179 private: tensorimpl_type_nameOpaqueTensorImpl180 const char* tensorimpl_type_name() const override { 181 return "OpaqueTensorImpl"; 182 } 183 184 OpaqueHandle opaque_handle_; 185 }; 186 187 } // namespace at 188