xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/IListRef_inl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/List.h>
4 #include <ATen/core/Tensor.h>
5 
// Forward declarations of the element types that the `IListRef`
// specializations in this header are instantiated with.
namespace at {
class OptionalTensorRef;
class Tensor;
} // namespace at
10 
11 
12 namespace c10::detail {
13 
14 /*
15  * Specializations of `IListRefTagImplBase` that implement the default
16  * implementation for `IListRefTag::Unboxed`.
17  */
18 template <typename T, typename ListElemT>
19 class IListRefTagImplBase<IListRefTag::Unboxed, T, ListElemT> {
20  public:
21   using elem_type = ListElemT;
22   using list_type = ArrayRef<elem_type>;
23 
24   /*
25    * These `unwrap` static methods unwraps the inner containers out
26    * of `IListRef<T>` (and `IListRefIterator<T>`). They are required when
27    * the macro `TORCH_ILISTREF_UNWRAP` is called.
28    */
unwrap(const IListRef<T> & ilist)29   static const list_type& unwrap(const IListRef<T>& ilist) {
30     return ilist.payload_.unboxed;
31   }
32 
unwrap(IListRefIterator<T> & it)33   static typename list_type::const_iterator& unwrap(IListRefIterator<T>& it) {
34     return it.payload_.unboxed_iterator;
35   }
36 
unwrap(const IListRefIterator<T> & it)37   static const typename list_type::const_iterator& unwrap(
38       const IListRefIterator<T>& it) {
39     return it.payload_.unboxed_iterator;
40   }
41 
42   /*
43    * We have these function (besides the `unwrap`s above) because the
44    * implementation for both `IListRef::operator[]` and `IListRefIterator::operator*`
45    * weren't syntatically equal for the existing tags at the time
46    * (`Unboxed` and `Boxed`).
47    */
front(const list_type & lst)48   static IListRefConstRef<T> front(const list_type& lst) {
49     return lst.front();
50   }
51 
iterator_get(const typename list_type::const_iterator & it)52   static IListRefConstRef<T> iterator_get(
53       const typename list_type::const_iterator& it) {
54     return *it;
55   }
56 };
57 
58 /*
59  * Specializations of `IListRefTagImplBase` that implement the default
60  * implementation for `IListRefTag::Boxed`.
61  */
62 template <typename T, typename ListElemT>
63 class IListRefTagImplBase<IListRefTag::Boxed, T, ListElemT> {
64  public:
65   using elem_type = ListElemT;
66   using list_type = List<elem_type>;
67 
unwrap(const IListRef<T> & ilist)68   static const list_type& unwrap(const IListRef<T>& ilist) {
69     return *ilist.payload_.boxed;
70   }
71 
unwrap(IListRefIterator<T> & it)72   static typename list_type::const_iterator& unwrap(IListRefIterator<T>& it) {
73     return it.payload_.boxed_iterator;
74   }
75 
unwrap(const IListRefIterator<T> & it)76   static const typename list_type::const_iterator& unwrap(
77       const IListRefIterator<T>& it) {
78     return it.payload_.boxed_iterator;
79   }
80 
front(const list_type & lst)81   static IListRefConstRef<T> front(const list_type& lst) {
82     return lst[0];
83   }
84 
iterator_get(const typename list_type::const_iterator & it)85   static IListRefConstRef<T> iterator_get(
86       const typename list_type::const_iterator& it) {
87     return (*it).get().toTensor();
88   }
89 };
90 
91 /*
92  * Specializations of `IListRefTagImplBase` that implement the default
93  * implementation for `IListRefTag::Materialized`.
94  */
95 template <typename T>
96 class IListRefTagImplBase<IListRefTag::Materialized, T, MaterializedIListRefElem<T>> {
97  public:
98   using elem_type = MaterializedIListRefElem<T>;
99   using list_type = MaterializedIListRef<T>;
100 
unwrap(const IListRef<T> & ilist)101   static const list_type& unwrap(const IListRef<T>& ilist) {
102     return *ilist.payload_.materialized;
103   }
104 
unwrap(IListRefIterator<T> & it)105   static typename list_type::const_iterator& unwrap(IListRefIterator<T>& it) {
106     return it.payload_.materialized_iterator;
107   }
108 
unwrap(const IListRefIterator<T> & it)109   static const typename list_type::const_iterator& unwrap(
110       const IListRefIterator<T>& it) {
111     return it.payload_.materialized_iterator;
112   }
113 
front(const list_type & lst)114   static IListRefConstRef<T> front(const list_type& lst) {
115     return lst[0];
116   }
117 
iterator_get(const typename list_type::const_iterator & it)118   static IListRefConstRef<T> iterator_get(
119       const typename list_type::const_iterator& it) {
120     return *it;
121   }
122 };
123 
124 /*
125  * [Note: ITensorListRef]
126  * Specializations necessary for `IListRef<at::Tensor>` type.
127  *
128  * Since the default implementations are usually done with supporting
129  * `Tensor` in mind, we only have to inherit from the base implementations.
130  */
131 template <>
132 class IListRefTagImpl<IListRefTag::Unboxed, at::Tensor>
133     : public IListRefTagImplBase<IListRefTag::Unboxed, at::Tensor> {};
134 
135 template <>
136 class IListRefTagImpl<IListRefTag::Boxed, at::Tensor>
137     : public IListRefTagImplBase<IListRefTag::Boxed, at::Tensor> {};
138 
139 template <>
140 class IListRefTagImpl<IListRefTag::Materialized, at::Tensor>
141     : public IListRefTagImplBase<
142           IListRefTag::Materialized,
143           at::Tensor,
144           MaterializedIListRefElem<at::Tensor>> {};
145 
146 /*
147  * [Note: IOptTensorListRef]
148  * Specializations necessary for `IListRef<at::OptionalTensorRef>` type.
149  *
150  * We can't get an `at::OptionalTensorRef` directly from an instance of
151  * `List<optional<Tensor>>` (the type that corresponds to the boxed world).
152  *
153  * So, the default implementation won't help us. Thus, we have to implement
154  * this method ourselves.
155  */
156 template <>
157 class IListRefTagImpl<IListRefTag::Unboxed, at::OptionalTensorRef>
158     : public IListRefTagImplBase<IListRefTag::Unboxed, at::OptionalTensorRef> {};
159 
160 template <>
161 class IListRefTagImpl<IListRefTag::Boxed, at::OptionalTensorRef>
162     : public IListRefTagImplBase<IListRefTag::Boxed, at::OptionalTensorRef, std::optional<at::Tensor>> {
163 
164  public:
165   /*
166    * Given an instance of the types corresponding to the `Boxed` tag, we override
167    * the default implementation, so that we can return a `at::OptionalTensorRef`.
168    */
iterator_get(const typename list_type::const_iterator & it)169   static IListRefConstRef<at::OptionalTensorRef> iterator_get(
170       const typename list_type::const_iterator& it) {
171     const auto& ivalue = (*it).get();
172     if (!ivalue.isNone()) {
173         const auto& tensor = ivalue.toTensor();
174         return (tensor.defined()) ? tensor : at::OptionalTensorRef{};
175     }
176     return {};
177   }
178 };
179 
180 template <>
181 class IListRefTagImpl<IListRefTag::Materialized, at::OptionalTensorRef>
182     : public IListRefTagImplBase<
183           IListRefTag::Materialized,
184           at::OptionalTensorRef,
185           MaterializedIListRefElem<at::OptionalTensorRef>> {};
186 
187 } // namespace c10::detail
188 
189 
190 namespace at {
191 
192 // [Note: ITensorListRef]
193 using ITensorListRef = c10::IListRef<at::Tensor>;
194 using ITensorListRefIterator = c10::IListRefIterator<at::Tensor>;
195 using MaterializedITensorListRef = c10::detail::MaterializedIListRef<at::Tensor>;
196 // [Note: IOptTensorListRef]
197 using IOptTensorListRef = c10::IListRef<at::OptionalTensorRef>;
198 using IOptTensorListRefIterator = c10::IListRefIterator<at::OptionalTensorRef>;
199 using MaterializedIOptTensorListRef = c10::detail::MaterializedIListRef<at::OptionalTensorRef>;
200 
201 } // namespace at
202