// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <[email protected]>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONCATENATION_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONCATENATION_H

namespace Eigen {

/** \class TensorConcatenationOp
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor concatenation class.
  *
  * Expression of the concatenation of two tensors along a given axis: both
  * operands must have the same rank, and their dimensions must match on every
  * axis except the concatenation axis.
  */
namespace internal {
template<typename Axis, typename LhsXprType, typename RhsXprType>
struct traits<TensorConcatenationOp<Axis, LhsXprType, RhsXprType> >
{
  // Type promotion to handle the case where the types of the lhs and the rhs are different.
  typedef typename promote_storage_type<typename LhsXprType::Scalar,
                                        typename RhsXprType::Scalar>::ret Scalar;
  typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
                                        typename traits<RhsXprType>::StorageKind>::ret StorageKind;
  typedef typename promote_index_type<typename traits<LhsXprType>::Index,
                                      typename traits<RhsXprType>::Index>::type Index;
  typedef typename LhsXprType::Nested LhsNested;
  typedef typename RhsXprType::Nested RhsNested;
  typedef typename remove_reference<LhsNested>::type _LhsNested;
  typedef typename remove_reference<RhsNested>::type _RhsNested;
  static const int NumDimensions = traits<LhsXprType>::NumDimensions;
  static const int Layout = traits<LhsXprType>::Layout;
  enum { Flags = 0 };
  typedef typename conditional<Pointer_type_promotion<typename LhsXprType::Scalar, Scalar>::val,
                               typename traits<LhsXprType>::PointerType,
                               typename traits<RhsXprType>::PointerType>::type PointerType;
};

template<typename Axis, typename LhsXprType, typename RhsXprType>
struct eval<TensorConcatenationOp<Axis, LhsXprType, RhsXprType>, Eigen::Dense>
{
  typedef const TensorConcatenationOp<Axis, LhsXprType, RhsXprType>& type;
};

template<typename Axis, typename LhsXprType, typename RhsXprType>
struct nested<TensorConcatenationOp<Axis, LhsXprType, RhsXprType>, 1, typename eval<TensorConcatenationOp<Axis, LhsXprType, RhsXprType> >::type>
{
  typedef TensorConcatenationOp<Axis, LhsXprType, RhsXprType> type;
};

} // end namespace internal


template<typename Axis, typename LhsXprType, typename RhsXprType>
class TensorConcatenationOp : public TensorBase<TensorConcatenationOp<Axis, LhsXprType, RhsXprType>, WriteAccessors>
{
  public:
    typedef TensorBase<TensorConcatenationOp<Axis, LhsXprType, RhsXprType>, WriteAccessors> Base;
    typedef typename internal::traits<TensorConcatenationOp>::Scalar Scalar;
    typedef typename internal::traits<TensorConcatenationOp>::StorageKind StorageKind;
    typedef typename internal::traits<TensorConcatenationOp>::Index Index;
    typedef typename internal::nested<TensorConcatenationOp>::type Nested;
    typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType,
                                                    typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;
    typedef typename NumTraits<Scalar>::Real RealScalar;

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorConcatenationOp(const LhsXprType& lhs, const RhsXprType& rhs, Axis axis)
        : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_axis(axis) {}

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename LhsXprType::Nested>::type&
    lhsExpression() const { return m_lhs_xpr; }

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename RhsXprType::Nested>::type&
    rhsExpression() const { return m_rhs_xpr; }

    EIGEN_DEVICE_FUNC const Axis& axis() const { return m_axis; }

    EIGEN_TENSOR_INHERIT_ASSIGNMENT_OPERATORS(TensorConcatenationOp)
  protected:
    typename LhsXprType::Nested m_lhs_xpr;
    typename RhsXprType::Nested m_rhs_xpr;
    const Axis m_axis;
};


// Eval as rvalue
template<typename Axis, typename LeftArgType, typename RightArgType, typename Device>
struct TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device>
{
  typedef TensorConcatenationOp<Axis, LeftArgType, RightArgType> XprType;
  typedef typename XprType::Index Index;
  static const int NumDims = internal::array_size<typename TensorEvaluator<LeftArgType, Device>::Dimensions>::value;
  static const int RightNumDims = internal::array_size<typename TensorEvaluator<RightArgType, Device>::Dimensions>::value;
  typedef DSizes<Index, NumDims> Dimensions;
  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  typedef StorageMemory<CoeffReturnType, Device> Storage;
  typedef typename Storage::Type EvaluatorPointerType;
  enum {
    IsAligned = false,
    PacketAccess = TensorEvaluator<LeftArgType, Device>::PacketAccess &&
                   TensorEvaluator<RightArgType, Device>::PacketAccess,
    BlockAccess = false,
    PreferBlockAccess = TensorEvaluator<LeftArgType, Device>::PreferBlockAccess ||
                        TensorEvaluator<RightArgType, Device>::PreferBlockAccess,
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
    RawAccess = false
  };

  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
  typedef internal::TensorBlockNotImplemented TensorBlock;
  //===--------------------------------------------------------------------===//

  EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : m_leftImpl(op.lhsExpression(), device), m_rightImpl(op.rhsExpression(), device), m_axis(op.axis())
  {
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) ==
                         static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout) || NumDims == 1),
                        YOU_MADE_A_PROGRAMMING_MISTAKE);
    EIGEN_STATIC_ASSERT((NumDims == RightNumDims), YOU_MADE_A_PROGRAMMING_MISTAKE);
    EIGEN_STATIC_ASSERT((NumDims > 0), YOU_MADE_A_PROGRAMMING_MISTAKE);

    eigen_assert(0 <= m_axis && m_axis < NumDims);
    const Dimensions& lhs_dims = m_leftImpl.dimensions();
    const Dimensions& rhs_dims = m_rightImpl.dimensions();
    {
      int i = 0;
      for (; i < m_axis; ++i) {
        eigen_assert(lhs_dims[i] > 0);
        eigen_assert(lhs_dims[i] == rhs_dims[i]);
        m_dimensions[i] = lhs_dims[i];
      }
      eigen_assert(lhs_dims[i] > 0);  // Now i == m_axis.
      eigen_assert(rhs_dims[i] > 0);
      m_dimensions[i] = lhs_dims[i] + rhs_dims[i];
      for (++i; i < NumDims; ++i) {
        eigen_assert(lhs_dims[i] > 0);
        eigen_assert(lhs_dims[i] == rhs_dims[i]);
        m_dimensions[i] = lhs_dims[i];
      }
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      m_leftStrides[0] = 1;
      m_rightStrides[0] = 1;
      m_outputStrides[0] = 1;

      for (int j = 1; j < NumDims; ++j) {
        m_leftStrides[j] = m_leftStrides[j-1] * lhs_dims[j-1];
        m_rightStrides[j] = m_rightStrides[j-1] * rhs_dims[j-1];
        m_outputStrides[j] = m_outputStrides[j-1] * m_dimensions[j-1];
      }
    } else {
      m_leftStrides[NumDims - 1] = 1;
      m_rightStrides[NumDims - 1] = 1;
      m_outputStrides[NumDims - 1] = 1;

      for (int j = NumDims - 2; j >= 0; --j) {
        m_leftStrides[j] = m_leftStrides[j+1] * lhs_dims[j+1];
        m_rightStrides[j] = m_rightStrides[j+1] * rhs_dims[j+1];
        m_outputStrides[j] = m_outputStrides[j+1] * m_dimensions[j+1];
      }
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  // TODO(phli): Add short-circuit memcpy evaluation if underlying data are linear?
  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType)
  {
    m_leftImpl.evalSubExprsIfNeeded(NULL);
    m_rightImpl.evalSubExprsIfNeeded(NULL);
    return true;
  }

  EIGEN_STRONG_INLINE void cleanup()
  {
    m_leftImpl.cleanup();
    m_rightImpl.cleanup();
  }

  // TODO(phli): attempt to speed this up. The integer divisions and modulo are slow.
  // See CL/76180724 comments for more ideas.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
  {
    // Collect dimension-wise indices (subs).
    array<Index, NumDims> subs;
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      for (int i = NumDims - 1; i > 0; --i) {
        subs[i] = index / m_outputStrides[i];
        index -= subs[i] * m_outputStrides[i];
      }
      subs[0] = index;
    } else {
      for (int i = 0; i < NumDims - 1; ++i) {
        subs[i] = index / m_outputStrides[i];
        index -= subs[i] * m_outputStrides[i];
      }
      subs[NumDims - 1] = index;
    }

    const Dimensions& left_dims = m_leftImpl.dimensions();
    if (subs[m_axis] < left_dims[m_axis]) {
      Index left_index;
      if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
        left_index = subs[0];
        EIGEN_UNROLL_LOOP
        for (int i = 1; i < NumDims; ++i) {
          left_index += (subs[i] % left_dims[i]) * m_leftStrides[i];
        }
      } else {
        left_index = subs[NumDims - 1];
        EIGEN_UNROLL_LOOP
        for (int i = NumDims - 2; i >= 0; --i) {
          left_index += (subs[i] % left_dims[i]) * m_leftStrides[i];
        }
      }
      return m_leftImpl.coeff(left_index);
    } else {
      subs[m_axis] -= left_dims[m_axis];
      const Dimensions& right_dims = m_rightImpl.dimensions();
      Index right_index;
      if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
        right_index = subs[0];
        EIGEN_UNROLL_LOOP
        for (int i = 1; i < NumDims; ++i) {
          right_index += (subs[i] % right_dims[i]) * m_rightStrides[i];
        }
      } else {
        right_index = subs[NumDims - 1];
        EIGEN_UNROLL_LOOP
        for (int i = NumDims - 2; i >= 0; --i) {
          right_index += (subs[i] % right_dims[i]) * m_rightStrides[i];
        }
      }
      return m_rightImpl.coeff(right_index);
    }
  }

  // TODO(phli): Add a real vectorization.
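  // In the meantime, packet() below gathers packetSize coefficients one by one
  // through coeff() and reassembles them with an aligned load, so packet access
  // is no faster than the scalar path.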
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
  {
    const int packetSize = PacketType<CoeffReturnType, Device>::size;
    EIGEN_STATIC_ASSERT((packetSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index + packetSize - 1 < dimensions().TotalSize());

    EIGEN_ALIGN_MAX CoeffReturnType values[packetSize];
    EIGEN_UNROLL_LOOP
    for (int i = 0; i < packetSize; ++i) {
      values[i] = coeff(index+i);
    }
    PacketReturnType rslt = internal::pload<PacketReturnType>(values);
    return rslt;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    const double compute_cost = NumDims * (2 * TensorOpCost::AddCost<Index>() +
                                           2 * TensorOpCost::MulCost<Index>() +
                                           TensorOpCost::DivCost<Index>() +
                                           TensorOpCost::ModCost<Index>());
    const double lhs_size = m_leftImpl.dimensions().TotalSize();
    const double rhs_size = m_rightImpl.dimensions().TotalSize();
    return (lhs_size / (lhs_size + rhs_size)) *
               m_leftImpl.costPerCoeff(vectorized) +
           (rhs_size / (lhs_size + rhs_size)) *
               m_rightImpl.costPerCoeff(vectorized) +
           TensorOpCost(0, 0, compute_cost);
  }

  EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }

#ifdef EIGEN_USE_SYCL
  // binding placeholder accessors to a command group handler for SYCL
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
    m_leftImpl.bind(cgh);
    m_rightImpl.bind(cgh);
  }
#endif

 protected:
  Dimensions m_dimensions;
  array<Index, NumDims> m_outputStrides;
  array<Index, NumDims> m_leftStrides;
  array<Index, NumDims> m_rightStrides;
  TensorEvaluator<LeftArgType, Device> m_leftImpl;
  TensorEvaluator<RightArgType, Device> m_rightImpl;
  const Axis m_axis;
};

// Eval as lvalue
template<typename Axis, typename LeftArgType, typename RightArgType, typename Device>
struct TensorEvaluator<TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device>
  : public TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device>
{
  typedef TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device> Base;
  typedef TensorConcatenationOp<Axis, LeftArgType, RightArgType> XprType;
  typedef typename Base::Dimensions Dimensions;
  enum {
    IsAligned = false,
    PacketAccess = TensorEvaluator<LeftArgType, Device>::PacketAccess &&
                   TensorEvaluator<RightArgType, Device>::PacketAccess,
    BlockAccess = false,
    PreferBlockAccess = TensorEvaluator<LeftArgType, Device>::PreferBlockAccess ||
                        TensorEvaluator<RightArgType, Device>::PreferBlockAccess,
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
    RawAccess = false
  };

  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
  typedef internal::TensorBlockNotImplemented TensorBlock;
  //===--------------------------------------------------------------------===//

  EIGEN_STRONG_INLINE TensorEvaluator(XprType& op, const Device& device)
    : Base(op, device)
  {
    EIGEN_STATIC_ASSERT((static_cast<int>(Layout) == static_cast<int>(ColMajor)), YOU_MADE_A_PROGRAMMING_MISTAKE);
  }

  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;

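  // coeffRef() mirrors coeff() in the rvalue evaluator above: decompose the
  // linear output index into per-dimension subscripts (ColMajor only, as
  // asserted in the constructor) and return a reference into whichever
  // sub-expression owns that slice of the concatenation axis.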
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType& coeffRef(Index index)
  {
    // Collect dimension-wise indices (subs).
    array<Index, Base::NumDims> subs;
    for (int i = Base::NumDims - 1; i > 0; --i) {
      subs[i] = index / this->m_outputStrides[i];
      index -= subs[i] * this->m_outputStrides[i];
    }
    subs[0] = index;

    const Dimensions& left_dims = this->m_leftImpl.dimensions();
    if (subs[this->m_axis] < left_dims[this->m_axis]) {
      Index left_index = subs[0];
      for (int i = 1; i < Base::NumDims; ++i) {
        left_index += (subs[i] % left_dims[i]) * this->m_leftStrides[i];
      }
      return this->m_leftImpl.coeffRef(left_index);
    } else {
      subs[this->m_axis] -= left_dims[this->m_axis];
      const Dimensions& right_dims = this->m_rightImpl.dimensions();
      Index right_index = subs[0];
      for (int i = 1; i < Base::NumDims; ++i) {
        right_index += (subs[i] % right_dims[i]) * this->m_rightStrides[i];
      }
      return this->m_rightImpl.coeffRef(right_index);
    }
  }

  template <int StoreMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  void writePacket(Index index, const PacketReturnType& x)
  {
    const int packetSize = PacketType<CoeffReturnType, Device>::size;
    EIGEN_STATIC_ASSERT((packetSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index + packetSize - 1 < this->dimensions().TotalSize());

    EIGEN_ALIGN_MAX CoeffReturnType values[packetSize];
    internal::pstore<CoeffReturnType, PacketReturnType>(values, x);
    for (int i = 0; i < packetSize; ++i) {
      coeffRef(index+i) = values[i];
    }
  }
};

} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_CONCATENATION_H
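
// Illustrative usage sketch (comment only, not part of the header): a
// TensorConcatenationOp is normally created through TensorBase::concatenate.
// The tensor names and sizes below are arbitrary examples and assume the
// Tensor module header has been included.
//
//   Eigen::Tensor<float, 2> a(2, 3);
//   Eigen::Tensor<float, 2> b(4, 3);
//   a.setRandom();
//   b.setRandom();
//   // Concatenate along axis 0; all other dimensions must match.
//   Eigen::Tensor<float, 2> c = a.concatenate(b, 0);  // c has dimensions (6, 3)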