1 #pragma once 2 3 #include <ATen/core/List.h> 4 #include <ATen/core/Tensor.h> 5 6 namespace at { 7 class Tensor; 8 class OptionalTensorRef; 9 } 10 11 12 namespace c10::detail { 13 14 /* 15 * Specializations of `IListRefTagImplBase` that implement the default 16 * implementation for `IListRefTag::Unboxed`. 17 */ 18 template <typename T, typename ListElemT> 19 class IListRefTagImplBase<IListRefTag::Unboxed, T, ListElemT> { 20 public: 21 using elem_type = ListElemT; 22 using list_type = ArrayRef<elem_type>; 23 24 /* 25 * These `unwrap` static methods unwraps the inner containers out 26 * of `IListRef<T>` (and `IListRefIterator<T>`). They are required when 27 * the macro `TORCH_ILISTREF_UNWRAP` is called. 28 */ unwrap(const IListRef<T> & ilist)29 static const list_type& unwrap(const IListRef<T>& ilist) { 30 return ilist.payload_.unboxed; 31 } 32 unwrap(IListRefIterator<T> & it)33 static typename list_type::const_iterator& unwrap(IListRefIterator<T>& it) { 34 return it.payload_.unboxed_iterator; 35 } 36 unwrap(const IListRefIterator<T> & it)37 static const typename list_type::const_iterator& unwrap( 38 const IListRefIterator<T>& it) { 39 return it.payload_.unboxed_iterator; 40 } 41 42 /* 43 * We have these function (besides the `unwrap`s above) because the 44 * implementation for both `IListRef::operator[]` and `IListRefIterator::operator*` 45 * weren't syntatically equal for the existing tags at the time 46 * (`Unboxed` and `Boxed`). 47 */ front(const list_type & lst)48 static IListRefConstRef<T> front(const list_type& lst) { 49 return lst.front(); 50 } 51 iterator_get(const typename list_type::const_iterator & it)52 static IListRefConstRef<T> iterator_get( 53 const typename list_type::const_iterator& it) { 54 return *it; 55 } 56 }; 57 58 /* 59 * Specializations of `IListRefTagImplBase` that implement the default 60 * implementation for `IListRefTag::Boxed`. 61 */ 62 template <typename T, typename ListElemT> 63 class IListRefTagImplBase<IListRefTag::Boxed, T, ListElemT> { 64 public: 65 using elem_type = ListElemT; 66 using list_type = List<elem_type>; 67 unwrap(const IListRef<T> & ilist)68 static const list_type& unwrap(const IListRef<T>& ilist) { 69 return *ilist.payload_.boxed; 70 } 71 unwrap(IListRefIterator<T> & it)72 static typename list_type::const_iterator& unwrap(IListRefIterator<T>& it) { 73 return it.payload_.boxed_iterator; 74 } 75 unwrap(const IListRefIterator<T> & it)76 static const typename list_type::const_iterator& unwrap( 77 const IListRefIterator<T>& it) { 78 return it.payload_.boxed_iterator; 79 } 80 front(const list_type & lst)81 static IListRefConstRef<T> front(const list_type& lst) { 82 return lst[0]; 83 } 84 iterator_get(const typename list_type::const_iterator & it)85 static IListRefConstRef<T> iterator_get( 86 const typename list_type::const_iterator& it) { 87 return (*it).get().toTensor(); 88 } 89 }; 90 91 /* 92 * Specializations of `IListRefTagImplBase` that implement the default 93 * implementation for `IListRefTag::Materialized`. 94 */ 95 template <typename T> 96 class IListRefTagImplBase<IListRefTag::Materialized, T, MaterializedIListRefElem<T>> { 97 public: 98 using elem_type = MaterializedIListRefElem<T>; 99 using list_type = MaterializedIListRef<T>; 100 unwrap(const IListRef<T> & ilist)101 static const list_type& unwrap(const IListRef<T>& ilist) { 102 return *ilist.payload_.materialized; 103 } 104 unwrap(IListRefIterator<T> & it)105 static typename list_type::const_iterator& unwrap(IListRefIterator<T>& it) { 106 return it.payload_.materialized_iterator; 107 } 108 unwrap(const IListRefIterator<T> & it)109 static const typename list_type::const_iterator& unwrap( 110 const IListRefIterator<T>& it) { 111 return it.payload_.materialized_iterator; 112 } 113 front(const list_type & lst)114 static IListRefConstRef<T> front(const list_type& lst) { 115 return lst[0]; 116 } 117 iterator_get(const typename list_type::const_iterator & it)118 static IListRefConstRef<T> iterator_get( 119 const typename list_type::const_iterator& it) { 120 return *it; 121 } 122 }; 123 124 /* 125 * [Note: ITensorListRef] 126 * Specializations necessary for `IListRef<at::Tensor>` type. 127 * 128 * Since the default implementations are usually done with supporting 129 * `Tensor` in mind, we only have to inherit from the base implementations. 130 */ 131 template <> 132 class IListRefTagImpl<IListRefTag::Unboxed, at::Tensor> 133 : public IListRefTagImplBase<IListRefTag::Unboxed, at::Tensor> {}; 134 135 template <> 136 class IListRefTagImpl<IListRefTag::Boxed, at::Tensor> 137 : public IListRefTagImplBase<IListRefTag::Boxed, at::Tensor> {}; 138 139 template <> 140 class IListRefTagImpl<IListRefTag::Materialized, at::Tensor> 141 : public IListRefTagImplBase< 142 IListRefTag::Materialized, 143 at::Tensor, 144 MaterializedIListRefElem<at::Tensor>> {}; 145 146 /* 147 * [Note: IOptTensorListRef] 148 * Specializations necessary for `IListRef<at::OptionalTensorRef>` type. 149 * 150 * We can't get an `at::OptionalTensorRef` directly from an instance of 151 * `List<optional<Tensor>>` (the type that corresponds to the boxed world). 152 * 153 * So, the default implementation won't help us. Thus, we have to implement 154 * this method ourselves. 155 */ 156 template <> 157 class IListRefTagImpl<IListRefTag::Unboxed, at::OptionalTensorRef> 158 : public IListRefTagImplBase<IListRefTag::Unboxed, at::OptionalTensorRef> {}; 159 160 template <> 161 class IListRefTagImpl<IListRefTag::Boxed, at::OptionalTensorRef> 162 : public IListRefTagImplBase<IListRefTag::Boxed, at::OptionalTensorRef, std::optional<at::Tensor>> { 163 164 public: 165 /* 166 * Given an instance of the types corresponding to the `Boxed` tag, we override 167 * the default implementation, so that we can return a `at::OptionalTensorRef`. 168 */ iterator_get(const typename list_type::const_iterator & it)169 static IListRefConstRef<at::OptionalTensorRef> iterator_get( 170 const typename list_type::const_iterator& it) { 171 const auto& ivalue = (*it).get(); 172 if (!ivalue.isNone()) { 173 const auto& tensor = ivalue.toTensor(); 174 return (tensor.defined()) ? tensor : at::OptionalTensorRef{}; 175 } 176 return {}; 177 } 178 }; 179 180 template <> 181 class IListRefTagImpl<IListRefTag::Materialized, at::OptionalTensorRef> 182 : public IListRefTagImplBase< 183 IListRefTag::Materialized, 184 at::OptionalTensorRef, 185 MaterializedIListRefElem<at::OptionalTensorRef>> {}; 186 187 } // namespace c10::detail 188 189 190 namespace at { 191 192 // [Note: ITensorListRef] 193 using ITensorListRef = c10::IListRef<at::Tensor>; 194 using ITensorListRefIterator = c10::IListRefIterator<at::Tensor>; 195 using MaterializedITensorListRef = c10::detail::MaterializedIListRef<at::Tensor>; 196 // [Note: IOptTensorListRef] 197 using IOptTensorListRef = c10::IListRef<at::OptionalTensorRef>; 198 using IOptTensorListRefIterator = c10::IListRefIterator<at::OptionalTensorRef>; 199 using MaterializedIOptTensorListRef = c10::detail::MaterializedIListRef<at::OptionalTensorRef>; 200 201 } // namespace at 202