1 //
2 //  Copyright (c) 2018-2019, Cem Bassoy, [email protected]
3 //
4 //  Distributed under the Boost Software License, Version 1.0. (See
5 //  accompanying file LICENSE_1_0.txt or copy at
6 //  http://www.boost.org/LICENSE_1_0.txt)
7 //
8 //  The authors gratefully acknowledge the support of
9 //  Fraunhofer IOSB, Ettlingen, Germany
10 //
11 
12 
13 #ifndef BOOST_NUMERIC_UBLAS_TENSOR_EXTENTS_HPP
14 #define BOOST_NUMERIC_UBLAS_TENSOR_EXTENTS_HPP
15 
#include <algorithm>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <limits>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>

#include <cassert>
24 
25 namespace boost {
26 namespace numeric {
27 namespace ublas {
28 
29 
30 /** @brief Template class for storing tensor extents with runtime variable size.
31  *
32  * Proxy template class of std::vector<int_type>.
33  *
34  */
template<class int_type>
class basic_extents
{
	static_assert( std::numeric_limits<typename std::vector<int_type>::value_type>::is_integer, "Static error in basic_extents: type must be of type integer.");
	static_assert(!std::numeric_limits<typename std::vector<int_type>::value_type>::is_signed,  "Static error in basic_extents: type must be of type unsigned integer.");

public:
	using base_type = std::vector<int_type>;
	using value_type = typename base_type::value_type;
	using const_reference = typename base_type::const_reference;
	using reference = typename base_type::reference;
	using size_type = typename base_type::size_type;
	using const_pointer = typename base_type::const_pointer;
	using const_iterator = typename base_type::const_iterator;


	/** @brief Default constructs empty basic_extents
	 *
	 * @note the result is empty and therefore not valid()
	 *
	 * @code auto ex = basic_extents<unsigned>{};
	 */
	constexpr explicit basic_extents()
	  : _base{}
	{
	}

	/** @brief Copy constructs basic_extents from a one-dimensional container
	 *
	 * @code auto ex = basic_extents<unsigned>(  std::vector<unsigned>(3u,3u) );
	 *
	 * @note checks if size > 1 and all elements > 0
	 *
	 * @param b one-dimensional std::vector<int_type> container
	 * @throws std::length_error if the resulting extents are not valid()
	 */
	explicit basic_extents(base_type const& b)
	  : _base(b)
	{
		if (!this->valid()){
			throw std::length_error("Error in basic_extents::basic_extents() : shape tuple is not a valid permutation: has zero elements.");
		}
	}

	/** @brief Move constructs basic_extents from a one-dimensional container
	 *
	 * @code auto ex = basic_extents<unsigned>(  std::vector<unsigned>(3u,3u) );
	 *
	 * @note checks if size > 1 and all elements > 0
	 *
	 * @param b one-dimensional container of type std::vector<int_type>
	 * @throws std::length_error if the resulting extents are not valid()
	 */
	explicit basic_extents(base_type && b)
	  : _base(std::move(b))
	{
		if (!this->valid()){
			throw std::length_error("Error in basic_extents::basic_extents() : shape tuple is not a valid permutation: has zero elements.");
		}
	}

	/** @brief Constructs basic_extents from an initializer list
	 *
	 * @code auto ex = basic_extents<unsigned>{3,2,4};
	 *
	 * @note checks if size > 1 and all elements > 0
	 *
	 * @param l one-dimensional list of type std::initializer_list<int_type>
	 * @throws std::length_error if the resulting extents are not valid()
	 */
	basic_extents(std::initializer_list<value_type> l)
	  : basic_extents( base_type(l) )
	{
	}

	/** @brief Constructs basic_extents from a range specified by two iterators
	 *
	 * @code auto ex = basic_extents<unsigned>(a.begin(), a.end());
	 *
	 * @note checks if size > 1 and all elements > 0
	 *
	 * @param first iterator pointing to the first element
	 * @param last iterator pointing to the next position after the last element
	 * @throws std::length_error if the resulting extents are not valid()
	 */
	basic_extents(const_iterator first, const_iterator last)
	  : basic_extents ( base_type( first,last ) )
	{
	}

	/** @brief Copy constructs basic_extents (no validity check) */
	basic_extents(basic_extents const& l )
	  : _base(l._base)
	{
	}

	/** @brief Move constructs basic_extents (no validity check) */
	basic_extents(basic_extents && l ) noexcept
	  : _base(std::move(l._base))
	{
	}

	~basic_extents() = default;

	/** @brief Copy-and-swap assignment; the copy/move happens when binding @p other. */
	basic_extents& operator=(basic_extents other) noexcept
	{
		swap (*this, other);
		return *this;
	}

	/** @brief Exchanges the stored extents of @p lhs and @p rhs. */
	friend void swap(basic_extents& lhs, basic_extents& rhs) noexcept
	{
		lhs._base.swap(rhs._base);
	}



	/** @brief Returns true if this has a scalar shape
	 *
	 * @returns true if (1,1,[1,...,1])
	*/
	bool is_scalar() const
	{
		return
		    _base.size() != 0 &&
		    std::all_of(_base.begin(), _base.end(),
		                [](const_reference a){ return a == 1;});
	}

	/** @brief Returns true if this has a vector shape
	 *
	 * @returns true if (1,n,[1,...,1]) or (n,1,[1,...,1]) with n > 1
	*/
	bool is_vector() const
	{
		if(_base.size() == 0){
			return false;
		}

		if(_base.size() == 1){
			return _base.at(0) > 1;
		}

		auto greater_one = [](const_reference a){ return a >  1;};
		auto equal_one   = [](const_reference a){ return a == 1;};

		// exactly one of the first two extents exceeds one, all trailing extents are one
		return
		    std::any_of(_base.begin(),   _base.begin()+2, greater_one) &&
		    std::any_of(_base.begin(),   _base.begin()+2, equal_one  ) &&
		    std::all_of(_base.begin()+2, _base.end(),     equal_one);
	}

	/** @brief Returns true if this has a matrix shape
	 *
	 * @returns true if (m,n,[1,...,1]) with m > 1 and n > 1
	*/
	bool is_matrix() const
	{
		if(_base.size() < 2){
			return false;
		}

		auto greater_one = [](const_reference a){ return a >  1;};
		auto equal_one   = [](const_reference a){ return a == 1;};

		return
		    std::all_of(_base.begin(),   _base.begin()+2, greater_one) &&
		    std::all_of(_base.begin()+2, _base.end(),     equal_one  );
	}

	/** @brief Returns true if this has a tensor shape
	 *
	 * @returns true if size() > 2 and at least one extent beyond the second is > 1
	*/
	bool is_tensor() const
	{
		if(_base.size() < 3){
			return false;
		}

		auto greater_one = [](const_reference a){ return a > 1;};

		return std::any_of(_base.begin()+2, _base.end(), greater_one);
	}

	/** @brief Returns a pointer to the contiguous extent storage. */
	const_pointer data() const
	{
		return this->_base.data();
	}

	/** @brief Unchecked read access to extent @p p. */
	const_reference operator[] (size_type p) const
	{
		return this->_base[p];
	}

	/** @brief Bounds-checked read access to extent @p p (throws std::out_of_range). */
	const_reference at (size_type p) const
	{
		return this->_base.at(p);
	}

	/** @brief Unchecked write access to extent @p p. */
	reference operator[] (size_type p)
	{
		return this->_base[p];
	}

	/** @brief Bounds-checked write access to extent @p p (throws std::out_of_range). */
	reference at (size_type p)
	{
		return this->_base.at(p);
	}


	/** @brief Returns true if no extents are stored. */
	bool empty() const
	{
		return this->_base.empty();
	}

	/** @brief Returns the number of stored extents (the tensor order). */
	size_type size() const
	{
		return this->_base.size();
	}

	/** @brief Returns true if size > 1 and all elements > 0 */
	bool valid() const
	{
		return
		    this->size() > 1 &&
		    std::none_of(_base.begin(), _base.end(),
		                 [](const_reference a){ return a == value_type(0); });
	}

	/** @brief Returns the number of elements a tensor holds with this
	 *
	 * @returns the product of all extents, or 0 if empty()
	 */
	size_type product() const
	{
		if(_base.empty()){
			return 0;
		}

		// accumulate in size_type: a literal like 1ul would truncate where
		// unsigned long is narrower than size_type (e.g. LLP64 platforms)
		return std::accumulate(_base.begin(), _base.end(), size_type{1}, std::multiplies<size_type>{});
	}


	/** @brief Eliminates singleton dimensions when size > 2
	 *
	 * The result always keeps at least two dimensions; missing ones are
	 * padded with trailing 1s so that the result stays valid().
	 *
	 * squeeze {  1,1} -> {  1,1}
	 * squeeze {  2,1} -> {  2,1}
	 * squeeze {  1,2} -> {  1,2}
	 *
	 * squeeze {1,2,3} -> {  2,3}
	 * squeeze {2,1,3} -> {  2,3}
	 * squeeze {1,3,1} -> {  3,1}
	 * squeeze {1,1,1} -> {  1,1}
	 *
	*/
	basic_extents squeeze() const
	{
		if(this->size() <= 2){
			return *this;
		}

		auto new_extent = basic_extents{};
		new_extent._base.reserve(this->size());
		std::remove_copy(this->_base.begin(), this->_base.end(),
		                 std::back_inserter(new_extent._base), value_type{1});

		// removing all singletons may leave fewer than two extents
		// (e.g. {1,3,1} -> {3}); pad to keep a valid, at least 2-d shape
		if(new_extent._base.size() < 2){
			new_extent._base.resize(2, value_type{1});
		}

		return new_extent;
	}

	/** @brief Removes all extents; the result is empty() and not valid(). */
	void clear()
	{
		this->_base.clear();
	}

	/** @brief Returns true if both hold the same extents in the same order. */
	bool operator == (basic_extents const& b) const
	{
		return _base == b._base;
	}

	/** @brief Returns true if the extents differ in size or in any element. */
	bool operator != (basic_extents const& b) const
	{
		return !( _base == b._base );
	}

	const_iterator
	begin() const
	{
		return _base.begin();
	}

	const_iterator
	end() const
	{
		return _base.end();
	}

	/** @brief Read-only access to the underlying std::vector. */
	base_type const& base() const { return _base; }

private:

	base_type _base;

};

using shape = basic_extents<std::size_t>;
330 
331 } // namespace ublas
332 } // namespace numeric
333 } // namespace boost
334 
335 #endif
336