//
//  Copyright (c) 2018-2019, Cem Bassoy, cem.bassoy@gmail.com
//
//  Distributed under the Boost Software License, Version 1.0. (See
//  accompanying file LICENSE_1_0.txt or copy at
//  http://www.boost.org/LICENSE_1_0.txt)
//
//  The authors gratefully acknowledge the support of
//  Fraunhofer IOSB, Ettlingen, Germany
//


#ifndef BOOST_NUMERIC_UBLAS_TENSOR_EXTENTS_HPP
#define BOOST_NUMERIC_UBLAS_TENSOR_EXTENTS_HPP

#include <algorithm>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <limits>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>

#include <cassert>

namespace boost {
namespace numeric {
namespace ublas {


/** @brief Template class for storing tensor extents with runtime variable size.
 *
 * Proxy template class of std::vector<int_type>.
 *
 * An extents object is valid() when it holds at least two elements and none
 * of them is zero. The data-taking constructors enforce this; the default
 * constructor (and clear()) leave the object empty, i.e. not valid().
 *
 * @tparam int_type unsigned integer type of a single extent
 */
template<class int_type>
class basic_extents
{
  static_assert( std::numeric_limits<int_type>::is_integer, "Static error in basic_extents: type must be of type integer.");
  static_assert(!std::numeric_limits<int_type>::is_signed,  "Static error in basic_extents: type must be of type unsigned integer.");

public:
  using base_type       = std::vector<int_type>;
  using value_type      = typename base_type::value_type;
  using const_reference = typename base_type::const_reference;
  using reference       = typename base_type::reference;
  using size_type       = typename base_type::size_type;
  using const_pointer   = typename base_type::const_pointer;
  using const_iterator  = typename base_type::const_iterator;


  /** @brief Default constructs an empty basic_extents
   *
   * @code auto ex = basic_extents<unsigned>{};
   *
   * @note the result is empty and therefore not valid()
   */
  constexpr explicit basic_extents()
    : _base{}
  {
  }

  /** @brief Copy constructs basic_extents from a one-dimensional container
   *
   * @code auto ex = basic_extents<unsigned>( std::vector<unsigned>(3u,3u) );
   *
   * @param b one-dimensional std::vector<int_type> container
   * @throws std::length_error if b.size() <= 1 or any element of b is 0
   */
  explicit basic_extents(base_type const& b)
    : _base(b)
  {
    if (!this->valid()){
      throw std::length_error("Error in basic_extents::basic_extents() : shape tuple is not a valid permutation: has zero elements.");
    }
  }

  /** @brief Move constructs basic_extents from a one-dimensional container
   *
   * @code auto ex = basic_extents<unsigned>( std::vector<unsigned>(3u,3u) );
   *
   * @param b one-dimensional container of type std::vector<int_type>
   * @throws std::length_error if b.size() <= 1 or any element of b is 0
   */
  explicit basic_extents(base_type && b)
    : _base(std::move(b))
  {
    if (!this->valid()){
      throw std::length_error("Error in basic_extents::basic_extents() : shape tuple is not a valid permutation: has zero elements.");
    }
  }

  /** @brief Constructs basic_extents from an initializer list
   *
   * @code auto ex = basic_extents<unsigned>{3,2,4};
   *
   * @param l one-dimensional list of type std::initializer_list<int_type>
   * @throws std::length_error if l.size() <= 1 or any element of l is 0
   */
  basic_extents(std::initializer_list<value_type> l)
    : basic_extents( base_type(l) )
  {
  }

  /** @brief Constructs basic_extents from a range specified by two iterators
   *
   * @code auto ex = basic_extents<unsigned>(a.begin(), a.end());
   *
   * @param first iterator pointing to the first element
   * @param last iterator pointing to the next position after the last element
   * @throws std::length_error if the range holds <= 1 elements or contains a 0
   */
  basic_extents(const_iterator first, const_iterator last)
    : basic_extents( base_type(first, last) )
  {
  }

  /** @brief Copy constructs basic_extents (no validation, copies as-is) */
  basic_extents(basic_extents const& l) = default;

  /** @brief Move constructs basic_extents (no validation, moves as-is) */
  basic_extents(basic_extents && l) noexcept = default;

  ~basic_extents() = default;

  /** @brief Copy/move assignment via copy-and-swap */
  basic_extents& operator=(basic_extents other) noexcept
  {
    swap(*this, other);
    return *this;
  }

  /** @brief Exchanges the extents stored in lhs and rhs */
  friend void swap(basic_extents& lhs, basic_extents& rhs) noexcept
  {
    lhs._base.swap(rhs._base);
  }



  /** @brief Returns true if this has a scalar shape
   *
   * @returns true if (1,1,[1,...,1])
   */
  bool is_scalar() const
  {
    return
      _base.size() != 0 &&
      std::all_of(_base.begin(), _base.end(),
                  [](const_reference a){ return a == 1;});
  }

  /** @brief Returns true if this has a vector shape
   *
   * @returns true if (1,n,[1,...,1]) or (n,1,[1,...,1]) with n > 1
   */
  bool is_vector() const
  {
    if(_base.size() == 0){
      return false;
    }

    if(_base.size() == 1){
      return _base.at(0) > 1;
    }

    auto greater_one = [](const_reference a){ return a > 1;};
    auto equal_one   = [](const_reference a){ return a == 1;};

    // Exactly one of the first two extents is > 1 and the other is 1,
    // all trailing extents are 1.
    return
      std::any_of(_base.begin(),   _base.begin()+2, greater_one) &&
      std::any_of(_base.begin(),   _base.begin()+2, equal_one  ) &&
      std::all_of(_base.begin()+2, _base.end(),     equal_one  );
  }

  /** @brief Returns true if this has a matrix shape
   *
   * @returns true if (m,n,[1,...,1]) with m > 1 and n > 1
   */
  bool is_matrix() const
  {
    if(_base.size() < 2){
      return false;
    }

    auto greater_one = [](const_reference a){ return a > 1;};
    auto equal_one   = [](const_reference a){ return a == 1;};

    return
      std::all_of(_base.begin(),   _base.begin()+2, greater_one) &&
      std::all_of(_base.begin()+2, _base.end(),     equal_one  );
  }

  /** @brief Returns true if this is has a tensor shape
   *
   * @returns true if !empty() && !is_scalar() && !is_vector() && !is_matrix()
   */
  bool is_tensor() const
  {
    if(_base.size() < 3){
      return false;
    }

    auto greater_one = [](const_reference a){ return a > 1;};

    // Any extent beyond the second one being > 1 makes this a tensor shape.
    return std::any_of(_base.begin()+2, _base.end(), greater_one);
  }

  /** @brief Returns a pointer to the underlying contiguous extent storage */
  const_pointer data() const
  {
    return this->_base.data();
  }

  /** @brief Unchecked read access to the p-th extent */
  const_reference operator[] (size_type p) const
  {
    return this->_base[p];
  }

  /** @brief Bounds-checked read access to the p-th extent
   *  @throws std::out_of_range if p >= size()
   */
  const_reference at (size_type p) const
  {
    return this->_base.at(p);
  }

  /** @brief Unchecked write access to the p-th extent
   *  @note assignments are not validated; writing 0 makes valid() false
   */
  reference operator[] (size_type p)
  {
    return this->_base[p];
  }

  /** @brief Bounds-checked write access to the p-th extent
   *  @throws std::out_of_range if p >= size()
   *  @note assignments are not validated; writing 0 makes valid() false
   */
  reference at (size_type p)
  {
    return this->_base.at(p);
  }


  /** @brief Returns true if no extents are stored */
  bool empty() const
  {
    return this->_base.empty();
  }

  /** @brief Returns the number of extents (the rank) */
  size_type size() const
  {
    return this->_base.size();
  }

  /** @brief Returns true if size > 1 and all elements > 0 */
  bool valid() const
  {
    return
      this->size() > 1 &&
      std::none_of(_base.begin(), _base.end(),
                   [](const_reference a){ return a == value_type(0); });
  }

  /** @brief Returns the number of elements a tensor holds with this
   *
   * @returns the product of all extents, or 0 if this is empty
   */
  size_type product() const
  {
    if(_base.empty()){
      return 0;
    }

    // Accumulate in size_type: seeding with `1ul` would compute in
    // `unsigned long`, which is only 32 bits wide on LLP64 platforms.
    return std::accumulate(_base.begin(), _base.end(),
                           size_type{1}, std::multiplies<size_type>());
  }


  /** @brief Eliminates singleton dimensions when size > 2
   *
   * squeeze {  1,1} -> {  1,1}
   * squeeze {  2,1} -> {  2,1}
   * squeeze {  1,2} -> {  1,2}
   *
   * squeeze {1,2,3} -> {  2,3}
   * squeeze {2,1,3} -> {  2,3}
   * squeeze {1,3,1} -> {  3,1}
   * squeeze {1,1,1} -> {  1,1}
   *
   * The result always holds at least two extents: if removing the
   * singleton dimensions leaves fewer than two, trailing 1s are appended.
   */
  basic_extents squeeze() const
  {
    if(this->size() <= 2){
      return *this;
    }

    basic_extents squeezed{};
    squeezed._base.reserve(_base.size());
    std::remove_copy(_base.begin(), _base.end(),
                     std::back_inserter(squeezed._base), value_type{1});

    // Keep the "at least two extents" invariant, e.g. {1,3,1} -> {3,1}.
    if(squeezed._base.size() < 2){
      squeezed._base.resize(2, value_type{1});
    }
    return squeezed;
  }

  /** @brief Removes all extents; the object becomes empty (not valid()) */
  void clear()
  {
    this->_base.clear();
  }

  /** @brief Element-wise equality of the stored extents */
  bool operator == (basic_extents const& b) const
  {
    return _base == b._base;
  }

  /** @brief Element-wise inequality of the stored extents */
  bool operator != (basic_extents const& b) const
  {
    return !( _base == b._base );
  }

  /** @brief Iterator to the first extent */
  const_iterator
  begin() const
  {
    return _base.begin();
  }

  /** @brief Iterator past the last extent */
  const_iterator
  end() const
  {
    return _base.end();
  }

  /** @brief Read-only access to the underlying container */
  base_type const& base() const { return _base; }

private:

  base_type _base;

};

using shape = basic_extents<std::size_t>;

} // namespace ublas
} // namespace numeric
} // namespace boost

#endif