1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2008-2014 Gael Guennebaud <[email protected]> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_SPARSE_CWISE_BINARY_OP_H 11 #define EIGEN_SPARSE_CWISE_BINARY_OP_H 12 13 namespace Eigen { 14 15 // Here we have to handle 3 cases: 16 // 1 - sparse op dense 17 // 2 - dense op sparse 18 // 3 - sparse op sparse 19 // We also need to implement a 4th iterator for: 20 // 4 - dense op dense 21 // Finally, we also need to distinguish between the product and other operations : 22 // configuration returned mode 23 // 1 - sparse op dense product sparse 24 // generic dense 25 // 2 - dense op sparse product sparse 26 // generic dense 27 // 3 - sparse op sparse product sparse 28 // generic sparse 29 // 4 - dense op dense product dense 30 // generic dense 31 // 32 // TODO to ease compiler job, we could specialize product/quotient with a scalar 33 // and fallback to cwise-unary evaluator using bind1st_op and bind2nd_op. 34 35 template<typename BinaryOp, typename Lhs, typename Rhs> 36 class CwiseBinaryOpImpl<BinaryOp, Lhs, Rhs, Sparse> 37 : public SparseMatrixBase<CwiseBinaryOp<BinaryOp, Lhs, Rhs> > 38 { 39 public: 40 typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> Derived; 41 typedef SparseMatrixBase<Derived> Base; 42 EIGEN_SPARSE_PUBLIC_INTERFACE(Derived) CwiseBinaryOpImpl()43 CwiseBinaryOpImpl() 44 { 45 EIGEN_STATIC_ASSERT(( 46 (!internal::is_same<typename internal::traits<Lhs>::StorageKind, 47 typename internal::traits<Rhs>::StorageKind>::value) 48 || ((internal::evaluator<Lhs>::Flags&RowMajorBit) == (internal::evaluator<Rhs>::Flags&RowMajorBit))), 49 THE_STORAGE_ORDER_OF_BOTH_SIDES_MUST_MATCH); 50 } 51 }; 52 53 namespace internal { 54 55 56 // Generic "sparse OP sparse" 57 template<typename XprType> struct binary_sparse_evaluator; 58 59 template<typename BinaryOp, typename Lhs, typename Rhs> 60 struct binary_evaluator<CwiseBinaryOp<BinaryOp, Lhs, Rhs>, IteratorBased, IteratorBased> 61 : evaluator_base<CwiseBinaryOp<BinaryOp, Lhs, Rhs> > 62 { 63 protected: 64 typedef typename evaluator<Lhs>::InnerIterator LhsIterator; 65 typedef typename evaluator<Rhs>::InnerIterator RhsIterator; 66 typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> XprType; 67 typedef typename traits<XprType>::Scalar Scalar; 68 typedef typename XprType::StorageIndex StorageIndex; 69 public: 70 71 class InnerIterator 72 { 73 public: 74 75 EIGEN_STRONG_INLINE InnerIterator(const binary_evaluator& aEval, Index outer) 76 : m_lhsIter(aEval.m_lhsImpl,outer), m_rhsIter(aEval.m_rhsImpl,outer), m_functor(aEval.m_functor) 77 { 78 this->operator++(); 79 } 80 81 EIGEN_STRONG_INLINE InnerIterator& operator++() 82 { 83 if (m_lhsIter && m_rhsIter && (m_lhsIter.index() == m_rhsIter.index())) 84 { 85 m_id = m_lhsIter.index(); 86 m_value = m_functor(m_lhsIter.value(), m_rhsIter.value()); 87 ++m_lhsIter; 88 ++m_rhsIter; 89 } 90 else if (m_lhsIter && (!m_rhsIter || (m_lhsIter.index() < m_rhsIter.index()))) 91 { 92 m_id = m_lhsIter.index(); 93 m_value = m_functor(m_lhsIter.value(), Scalar(0)); 94 ++m_lhsIter; 95 } 96 else if (m_rhsIter && (!m_lhsIter || (m_lhsIter.index() > m_rhsIter.index()))) 97 { 98 m_id = m_rhsIter.index(); 99 m_value = m_functor(Scalar(0), m_rhsIter.value()); 100 ++m_rhsIter; 101 } 102 else 103 { 104 m_value = Scalar(0); // this is to avoid a compilation warning 105 m_id = -1; 106 } 107 return *this; 108 } 109 110 EIGEN_STRONG_INLINE Scalar value() const { return m_value; } 111 112 EIGEN_STRONG_INLINE StorageIndex index() const { return m_id; } 113 EIGEN_STRONG_INLINE Index outer() const { return m_lhsIter.outer(); } 114 EIGEN_STRONG_INLINE Index row() const { return Lhs::IsRowMajor ? m_lhsIter.row() : index(); } 115 EIGEN_STRONG_INLINE Index col() const { return Lhs::IsRowMajor ? index() : m_lhsIter.col(); } 116 117 EIGEN_STRONG_INLINE operator bool() const { return m_id>=0; } 118 119 protected: 120 LhsIterator m_lhsIter; 121 RhsIterator m_rhsIter; 122 const BinaryOp& m_functor; 123 Scalar m_value; 124 StorageIndex m_id; 125 }; 126 127 128 enum { 129 CoeffReadCost = int(evaluator<Lhs>::CoeffReadCost) + int(evaluator<Rhs>::CoeffReadCost) + int(functor_traits<BinaryOp>::Cost), 130 Flags = XprType::Flags 131 }; 132 133 explicit binary_evaluator(const XprType& xpr) 134 : m_functor(xpr.functor()), 135 m_lhsImpl(xpr.lhs()), 136 m_rhsImpl(xpr.rhs()) 137 { 138 EIGEN_INTERNAL_CHECK_COST_VALUE(functor_traits<BinaryOp>::Cost); 139 EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); 140 } 141 142 inline Index nonZerosEstimate() const { 143 return m_lhsImpl.nonZerosEstimate() + m_rhsImpl.nonZerosEstimate(); 144 } 145 146 protected: 147 const BinaryOp m_functor; 148 evaluator<Lhs> m_lhsImpl; 149 evaluator<Rhs> m_rhsImpl; 150 }; 151 152 // dense op sparse 153 template<typename BinaryOp, typename Lhs, typename Rhs> 154 struct binary_evaluator<CwiseBinaryOp<BinaryOp, Lhs, Rhs>, IndexBased, IteratorBased> 155 : evaluator_base<CwiseBinaryOp<BinaryOp, Lhs, Rhs> > 156 { 157 protected: 158 typedef typename evaluator<Rhs>::InnerIterator RhsIterator; 159 typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> XprType; 160 typedef typename traits<XprType>::Scalar Scalar; 161 typedef typename XprType::StorageIndex StorageIndex; 162 public: 163 164 class InnerIterator 165 { 166 enum { IsRowMajor = (int(Rhs::Flags)&RowMajorBit)==RowMajorBit }; 167 public: 168 169 EIGEN_STRONG_INLINE InnerIterator(const binary_evaluator& aEval, Index outer) 170 : m_lhsEval(aEval.m_lhsImpl), m_rhsIter(aEval.m_rhsImpl,outer), m_functor(aEval.m_functor), m_value(0), m_id(-1), m_innerSize(aEval.m_expr.rhs().innerSize()) 171 { 172 this->operator++(); 173 } 174 175 EIGEN_STRONG_INLINE InnerIterator& operator++() 176 { 177 ++m_id; 178 if(m_id<m_innerSize) 179 { 180 Scalar lhsVal = m_lhsEval.coeff(IsRowMajor?m_rhsIter.outer():m_id, 181 IsRowMajor?m_id:m_rhsIter.outer()); 182 if(m_rhsIter && m_rhsIter.index()==m_id) 183 { 184 m_value = m_functor(lhsVal, m_rhsIter.value()); 185 ++m_rhsIter; 186 } 187 else 188 m_value = m_functor(lhsVal, Scalar(0)); 189 } 190 191 return *this; 192 } 193 194 EIGEN_STRONG_INLINE Scalar value() const { eigen_internal_assert(m_id<m_innerSize); return m_value; } 195 196 EIGEN_STRONG_INLINE StorageIndex index() const { return m_id; } 197 EIGEN_STRONG_INLINE Index outer() const { return m_rhsIter.outer(); } 198 EIGEN_STRONG_INLINE Index row() const { return IsRowMajor ? m_rhsIter.outer() : m_id; } 199 EIGEN_STRONG_INLINE Index col() const { return IsRowMajor ? m_id : m_rhsIter.outer(); } 200 201 EIGEN_STRONG_INLINE operator bool() const { return m_id<m_innerSize; } 202 203 protected: 204 const evaluator<Lhs> &m_lhsEval; 205 RhsIterator m_rhsIter; 206 const BinaryOp& m_functor; 207 Scalar m_value; 208 StorageIndex m_id; 209 StorageIndex m_innerSize; 210 }; 211 212 213 enum { 214 CoeffReadCost = int(evaluator<Lhs>::CoeffReadCost) + int(evaluator<Rhs>::CoeffReadCost) + int(functor_traits<BinaryOp>::Cost), 215 Flags = XprType::Flags 216 }; 217 218 explicit binary_evaluator(const XprType& xpr) 219 : m_functor(xpr.functor()), 220 m_lhsImpl(xpr.lhs()), 221 m_rhsImpl(xpr.rhs()), 222 m_expr(xpr) 223 { 224 EIGEN_INTERNAL_CHECK_COST_VALUE(functor_traits<BinaryOp>::Cost); 225 EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); 226 } 227 228 inline Index nonZerosEstimate() const { 229 return m_expr.size(); 230 } 231 232 protected: 233 const BinaryOp m_functor; 234 evaluator<Lhs> m_lhsImpl; 235 evaluator<Rhs> m_rhsImpl; 236 const XprType &m_expr; 237 }; 238 239 // sparse op dense 240 template<typename BinaryOp, typename Lhs, typename Rhs> 241 struct binary_evaluator<CwiseBinaryOp<BinaryOp, Lhs, Rhs>, IteratorBased, IndexBased> 242 : evaluator_base<CwiseBinaryOp<BinaryOp, Lhs, Rhs> > 243 { 244 protected: 245 typedef typename evaluator<Lhs>::InnerIterator LhsIterator; 246 typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> XprType; 247 typedef typename traits<XprType>::Scalar Scalar; 248 typedef typename XprType::StorageIndex StorageIndex; 249 public: 250 251 class InnerIterator 252 { 253 enum { IsRowMajor = (int(Lhs::Flags)&RowMajorBit)==RowMajorBit }; 254 public: 255 256 EIGEN_STRONG_INLINE InnerIterator(const binary_evaluator& aEval, Index outer) 257 : m_lhsIter(aEval.m_lhsImpl,outer), m_rhsEval(aEval.m_rhsImpl), m_functor(aEval.m_functor), m_value(0), m_id(-1), m_innerSize(aEval.m_expr.lhs().innerSize()) 258 { 259 this->operator++(); 260 } 261 262 EIGEN_STRONG_INLINE InnerIterator& operator++() 263 { 264 ++m_id; 265 if(m_id<m_innerSize) 266 { 267 Scalar rhsVal = m_rhsEval.coeff(IsRowMajor?m_lhsIter.outer():m_id, 268 IsRowMajor?m_id:m_lhsIter.outer()); 269 if(m_lhsIter && m_lhsIter.index()==m_id) 270 { 271 m_value = m_functor(m_lhsIter.value(), rhsVal); 272 ++m_lhsIter; 273 } 274 else 275 m_value = m_functor(Scalar(0),rhsVal); 276 } 277 278 return *this; 279 } 280 281 EIGEN_STRONG_INLINE Scalar value() const { eigen_internal_assert(m_id<m_innerSize); return m_value; } 282 283 EIGEN_STRONG_INLINE StorageIndex index() const { return m_id; } 284 EIGEN_STRONG_INLINE Index outer() const { return m_lhsIter.outer(); } 285 EIGEN_STRONG_INLINE Index row() const { return IsRowMajor ? m_lhsIter.outer() : m_id; } 286 EIGEN_STRONG_INLINE Index col() const { return IsRowMajor ? m_id : m_lhsIter.outer(); } 287 288 EIGEN_STRONG_INLINE operator bool() const { return m_id<m_innerSize; } 289 290 protected: 291 LhsIterator m_lhsIter; 292 const evaluator<Rhs> &m_rhsEval; 293 const BinaryOp& m_functor; 294 Scalar m_value; 295 StorageIndex m_id; 296 StorageIndex m_innerSize; 297 }; 298 299 300 enum { 301 CoeffReadCost = int(evaluator<Lhs>::CoeffReadCost) + int(evaluator<Rhs>::CoeffReadCost) + int(functor_traits<BinaryOp>::Cost), 302 Flags = XprType::Flags 303 }; 304 305 explicit binary_evaluator(const XprType& xpr) 306 : m_functor(xpr.functor()), 307 m_lhsImpl(xpr.lhs()), 308 m_rhsImpl(xpr.rhs()), 309 m_expr(xpr) 310 { 311 EIGEN_INTERNAL_CHECK_COST_VALUE(functor_traits<BinaryOp>::Cost); 312 EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); 313 } 314 315 inline Index nonZerosEstimate() const { 316 return m_expr.size(); 317 } 318 319 protected: 320 const BinaryOp m_functor; 321 evaluator<Lhs> m_lhsImpl; 322 evaluator<Rhs> m_rhsImpl; 323 const XprType &m_expr; 324 }; 325 326 template<typename T, 327 typename LhsKind = typename evaluator_traits<typename T::Lhs>::Kind, 328 typename RhsKind = typename evaluator_traits<typename T::Rhs>::Kind, 329 typename LhsScalar = typename traits<typename T::Lhs>::Scalar, 330 typename RhsScalar = typename traits<typename T::Rhs>::Scalar> struct sparse_conjunction_evaluator; 331 332 // "sparse .* sparse" 333 template<typename T1, typename T2, typename Lhs, typename Rhs> 334 struct binary_evaluator<CwiseBinaryOp<scalar_product_op<T1,T2>, Lhs, Rhs>, IteratorBased, IteratorBased> 335 : sparse_conjunction_evaluator<CwiseBinaryOp<scalar_product_op<T1,T2>, Lhs, Rhs> > 336 { 337 typedef CwiseBinaryOp<scalar_product_op<T1,T2>, Lhs, Rhs> XprType; 338 typedef sparse_conjunction_evaluator<XprType> Base; 339 explicit binary_evaluator(const XprType& xpr) : Base(xpr) {} 340 }; 341 // "dense .* sparse" 342 template<typename T1, typename T2, typename Lhs, typename Rhs> 343 struct binary_evaluator<CwiseBinaryOp<scalar_product_op<T1,T2>, Lhs, Rhs>, IndexBased, IteratorBased> 344 : sparse_conjunction_evaluator<CwiseBinaryOp<scalar_product_op<T1,T2>, Lhs, Rhs> > 345 { 346 typedef CwiseBinaryOp<scalar_product_op<T1,T2>, Lhs, Rhs> XprType; 347 typedef sparse_conjunction_evaluator<XprType> Base; 348 explicit binary_evaluator(const XprType& xpr) : Base(xpr) {} 349 }; 350 // "sparse .* dense" 351 template<typename T1, typename T2, typename Lhs, typename Rhs> 352 struct binary_evaluator<CwiseBinaryOp<scalar_product_op<T1,T2>, Lhs, Rhs>, IteratorBased, IndexBased> 353 : sparse_conjunction_evaluator<CwiseBinaryOp<scalar_product_op<T1,T2>, Lhs, Rhs> > 354 { 355 typedef CwiseBinaryOp<scalar_product_op<T1,T2>, Lhs, Rhs> XprType; 356 typedef sparse_conjunction_evaluator<XprType> Base; 357 explicit binary_evaluator(const XprType& xpr) : Base(xpr) {} 358 }; 359 360 // "sparse ./ dense" 361 template<typename T1, typename T2, typename Lhs, typename Rhs> 362 struct binary_evaluator<CwiseBinaryOp<scalar_quotient_op<T1,T2>, Lhs, Rhs>, IteratorBased, IndexBased> 363 : sparse_conjunction_evaluator<CwiseBinaryOp<scalar_quotient_op<T1,T2>, Lhs, Rhs> > 364 { 365 typedef CwiseBinaryOp<scalar_quotient_op<T1,T2>, Lhs, Rhs> XprType; 366 typedef sparse_conjunction_evaluator<XprType> Base; 367 explicit binary_evaluator(const XprType& xpr) : Base(xpr) {} 368 }; 369 370 // "sparse && sparse" 371 template<typename Lhs, typename Rhs> 372 struct binary_evaluator<CwiseBinaryOp<scalar_boolean_and_op, Lhs, Rhs>, IteratorBased, IteratorBased> 373 : sparse_conjunction_evaluator<CwiseBinaryOp<scalar_boolean_and_op, Lhs, Rhs> > 374 { 375 typedef CwiseBinaryOp<scalar_boolean_and_op, Lhs, Rhs> XprType; 376 typedef sparse_conjunction_evaluator<XprType> Base; 377 explicit binary_evaluator(const XprType& xpr) : Base(xpr) {} 378 }; 379 // "dense && sparse" 380 template<typename Lhs, typename Rhs> 381 struct binary_evaluator<CwiseBinaryOp<scalar_boolean_and_op, Lhs, Rhs>, IndexBased, IteratorBased> 382 : sparse_conjunction_evaluator<CwiseBinaryOp<scalar_boolean_and_op, Lhs, Rhs> > 383 { 384 typedef CwiseBinaryOp<scalar_boolean_and_op, Lhs, Rhs> XprType; 385 typedef sparse_conjunction_evaluator<XprType> Base; 386 explicit binary_evaluator(const XprType& xpr) : Base(xpr) {} 387 }; 388 // "sparse && dense" 389 template<typename Lhs, typename Rhs> 390 struct binary_evaluator<CwiseBinaryOp<scalar_boolean_and_op, Lhs, Rhs>, IteratorBased, IndexBased> 391 : sparse_conjunction_evaluator<CwiseBinaryOp<scalar_boolean_and_op, Lhs, Rhs> > 392 { 393 typedef CwiseBinaryOp<scalar_boolean_and_op, Lhs, Rhs> XprType; 394 typedef sparse_conjunction_evaluator<XprType> Base; 395 explicit binary_evaluator(const XprType& xpr) : Base(xpr) {} 396 }; 397 398 // "sparse ^ sparse" 399 template<typename XprType> 400 struct sparse_conjunction_evaluator<XprType, IteratorBased, IteratorBased> 401 : evaluator_base<XprType> 402 { 403 protected: 404 typedef typename XprType::Functor BinaryOp; 405 typedef typename XprType::Lhs LhsArg; 406 typedef typename XprType::Rhs RhsArg; 407 typedef typename evaluator<LhsArg>::InnerIterator LhsIterator; 408 typedef typename evaluator<RhsArg>::InnerIterator RhsIterator; 409 typedef typename XprType::StorageIndex StorageIndex; 410 typedef typename traits<XprType>::Scalar Scalar; 411 public: 412 413 class InnerIterator 414 { 415 public: 416 417 EIGEN_STRONG_INLINE InnerIterator(const sparse_conjunction_evaluator& aEval, Index outer) 418 : m_lhsIter(aEval.m_lhsImpl,outer), m_rhsIter(aEval.m_rhsImpl,outer), m_functor(aEval.m_functor) 419 { 420 while (m_lhsIter && m_rhsIter && (m_lhsIter.index() != m_rhsIter.index())) 421 { 422 if (m_lhsIter.index() < m_rhsIter.index()) 423 ++m_lhsIter; 424 else 425 ++m_rhsIter; 426 } 427 } 428 429 EIGEN_STRONG_INLINE InnerIterator& operator++() 430 { 431 ++m_lhsIter; 432 ++m_rhsIter; 433 while (m_lhsIter && m_rhsIter && (m_lhsIter.index() != m_rhsIter.index())) 434 { 435 if (m_lhsIter.index() < m_rhsIter.index()) 436 ++m_lhsIter; 437 else 438 ++m_rhsIter; 439 } 440 return *this; 441 } 442 443 EIGEN_STRONG_INLINE Scalar value() const { return m_functor(m_lhsIter.value(), m_rhsIter.value()); } 444 445 EIGEN_STRONG_INLINE StorageIndex index() const { return m_lhsIter.index(); } 446 EIGEN_STRONG_INLINE Index outer() const { return m_lhsIter.outer(); } 447 EIGEN_STRONG_INLINE Index row() const { return m_lhsIter.row(); } 448 EIGEN_STRONG_INLINE Index col() const { return m_lhsIter.col(); } 449 450 EIGEN_STRONG_INLINE operator bool() const { return (m_lhsIter && m_rhsIter); } 451 452 protected: 453 LhsIterator m_lhsIter; 454 RhsIterator m_rhsIter; 455 const BinaryOp& m_functor; 456 }; 457 458 459 enum { 460 CoeffReadCost = int(evaluator<LhsArg>::CoeffReadCost) + int(evaluator<RhsArg>::CoeffReadCost) + int(functor_traits<BinaryOp>::Cost), 461 Flags = XprType::Flags 462 }; 463 464 explicit sparse_conjunction_evaluator(const XprType& xpr) 465 : m_functor(xpr.functor()), 466 m_lhsImpl(xpr.lhs()), 467 m_rhsImpl(xpr.rhs()) 468 { 469 EIGEN_INTERNAL_CHECK_COST_VALUE(functor_traits<BinaryOp>::Cost); 470 EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); 471 } 472 473 inline Index nonZerosEstimate() const { 474 return (std::min)(m_lhsImpl.nonZerosEstimate(), m_rhsImpl.nonZerosEstimate()); 475 } 476 477 protected: 478 const BinaryOp m_functor; 479 evaluator<LhsArg> m_lhsImpl; 480 evaluator<RhsArg> m_rhsImpl; 481 }; 482 483 // "dense ^ sparse" 484 template<typename XprType> 485 struct sparse_conjunction_evaluator<XprType, IndexBased, IteratorBased> 486 : evaluator_base<XprType> 487 { 488 protected: 489 typedef typename XprType::Functor BinaryOp; 490 typedef typename XprType::Lhs LhsArg; 491 typedef typename XprType::Rhs RhsArg; 492 typedef evaluator<LhsArg> LhsEvaluator; 493 typedef typename evaluator<RhsArg>::InnerIterator RhsIterator; 494 typedef typename XprType::StorageIndex StorageIndex; 495 typedef typename traits<XprType>::Scalar Scalar; 496 public: 497 498 class InnerIterator 499 { 500 enum { IsRowMajor = (int(RhsArg::Flags)&RowMajorBit)==RowMajorBit }; 501 502 public: 503 504 EIGEN_STRONG_INLINE InnerIterator(const sparse_conjunction_evaluator& aEval, Index outer) 505 : m_lhsEval(aEval.m_lhsImpl), m_rhsIter(aEval.m_rhsImpl,outer), m_functor(aEval.m_functor), m_outer(outer) 506 {} 507 508 EIGEN_STRONG_INLINE InnerIterator& operator++() 509 { 510 ++m_rhsIter; 511 return *this; 512 } 513 514 EIGEN_STRONG_INLINE Scalar value() const 515 { return m_functor(m_lhsEval.coeff(IsRowMajor?m_outer:m_rhsIter.index(),IsRowMajor?m_rhsIter.index():m_outer), m_rhsIter.value()); } 516 517 EIGEN_STRONG_INLINE StorageIndex index() const { return m_rhsIter.index(); } 518 EIGEN_STRONG_INLINE Index outer() const { return m_rhsIter.outer(); } 519 EIGEN_STRONG_INLINE Index row() const { return m_rhsIter.row(); } 520 EIGEN_STRONG_INLINE Index col() const { return m_rhsIter.col(); } 521 522 EIGEN_STRONG_INLINE operator bool() const { return m_rhsIter; } 523 524 protected: 525 const LhsEvaluator &m_lhsEval; 526 RhsIterator m_rhsIter; 527 const BinaryOp& m_functor; 528 const Index m_outer; 529 }; 530 531 532 enum { 533 CoeffReadCost = int(evaluator<LhsArg>::CoeffReadCost) + int(evaluator<RhsArg>::CoeffReadCost) + int(functor_traits<BinaryOp>::Cost), 534 Flags = XprType::Flags 535 }; 536 537 explicit sparse_conjunction_evaluator(const XprType& xpr) 538 : m_functor(xpr.functor()), 539 m_lhsImpl(xpr.lhs()), 540 m_rhsImpl(xpr.rhs()) 541 { 542 EIGEN_INTERNAL_CHECK_COST_VALUE(functor_traits<BinaryOp>::Cost); 543 EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); 544 } 545 546 inline Index nonZerosEstimate() const { 547 return m_rhsImpl.nonZerosEstimate(); 548 } 549 550 protected: 551 const BinaryOp m_functor; 552 evaluator<LhsArg> m_lhsImpl; 553 evaluator<RhsArg> m_rhsImpl; 554 }; 555 556 // "sparse ^ dense" 557 template<typename XprType> 558 struct sparse_conjunction_evaluator<XprType, IteratorBased, IndexBased> 559 : evaluator_base<XprType> 560 { 561 protected: 562 typedef typename XprType::Functor BinaryOp; 563 typedef typename XprType::Lhs LhsArg; 564 typedef typename XprType::Rhs RhsArg; 565 typedef typename evaluator<LhsArg>::InnerIterator LhsIterator; 566 typedef evaluator<RhsArg> RhsEvaluator; 567 typedef typename XprType::StorageIndex StorageIndex; 568 typedef typename traits<XprType>::Scalar Scalar; 569 public: 570 571 class InnerIterator 572 { 573 enum { IsRowMajor = (int(LhsArg::Flags)&RowMajorBit)==RowMajorBit }; 574 575 public: 576 577 EIGEN_STRONG_INLINE InnerIterator(const sparse_conjunction_evaluator& aEval, Index outer) 578 : m_lhsIter(aEval.m_lhsImpl,outer), m_rhsEval(aEval.m_rhsImpl), m_functor(aEval.m_functor), m_outer(outer) 579 {} 580 581 EIGEN_STRONG_INLINE InnerIterator& operator++() 582 { 583 ++m_lhsIter; 584 return *this; 585 } 586 587 EIGEN_STRONG_INLINE Scalar value() const 588 { return m_functor(m_lhsIter.value(), 589 m_rhsEval.coeff(IsRowMajor?m_outer:m_lhsIter.index(),IsRowMajor?m_lhsIter.index():m_outer)); } 590 591 EIGEN_STRONG_INLINE StorageIndex index() const { return m_lhsIter.index(); } 592 EIGEN_STRONG_INLINE Index outer() const { return m_lhsIter.outer(); } 593 EIGEN_STRONG_INLINE Index row() const { return m_lhsIter.row(); } 594 EIGEN_STRONG_INLINE Index col() const { return m_lhsIter.col(); } 595 596 EIGEN_STRONG_INLINE operator bool() const { return m_lhsIter; } 597 598 protected: 599 LhsIterator m_lhsIter; 600 const evaluator<RhsArg> &m_rhsEval; 601 const BinaryOp& m_functor; 602 const Index m_outer; 603 }; 604 605 606 enum { 607 CoeffReadCost = int(evaluator<LhsArg>::CoeffReadCost) + int(evaluator<RhsArg>::CoeffReadCost) + int(functor_traits<BinaryOp>::Cost), 608 Flags = XprType::Flags 609 }; 610 611 explicit sparse_conjunction_evaluator(const XprType& xpr) 612 : m_functor(xpr.functor()), 613 m_lhsImpl(xpr.lhs()), 614 m_rhsImpl(xpr.rhs()) 615 { 616 EIGEN_INTERNAL_CHECK_COST_VALUE(functor_traits<BinaryOp>::Cost); 617 EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); 618 } 619 620 inline Index nonZerosEstimate() const { 621 return m_lhsImpl.nonZerosEstimate(); 622 } 623 624 protected: 625 const BinaryOp m_functor; 626 evaluator<LhsArg> m_lhsImpl; 627 evaluator<RhsArg> m_rhsImpl; 628 }; 629 630 } 631 632 /*************************************************************************** 633 * Implementation of SparseMatrixBase and SparseCwise functions/operators 634 ***************************************************************************/ 635 636 template<typename Derived> 637 template<typename OtherDerived> 638 Derived& SparseMatrixBase<Derived>::operator+=(const EigenBase<OtherDerived> &other) 639 { 640 call_assignment(derived(), other.derived(), internal::add_assign_op<Scalar,typename OtherDerived::Scalar>()); 641 return derived(); 642 } 643 644 template<typename Derived> 645 template<typename OtherDerived> 646 Derived& SparseMatrixBase<Derived>::operator-=(const EigenBase<OtherDerived> &other) 647 { 648 call_assignment(derived(), other.derived(), internal::assign_op<Scalar,typename OtherDerived::Scalar>()); 649 return derived(); 650 } 651 652 template<typename Derived> 653 template<typename OtherDerived> 654 EIGEN_STRONG_INLINE Derived & 655 SparseMatrixBase<Derived>::operator-=(const SparseMatrixBase<OtherDerived> &other) 656 { 657 return derived() = derived() - other.derived(); 658 } 659 660 template<typename Derived> 661 template<typename OtherDerived> 662 EIGEN_STRONG_INLINE Derived & 663 SparseMatrixBase<Derived>::operator+=(const SparseMatrixBase<OtherDerived>& other) 664 { 665 return derived() = derived() + other.derived(); 666 } 667 668 template<typename Derived> 669 template<typename OtherDerived> 670 Derived& SparseMatrixBase<Derived>::operator+=(const DiagonalBase<OtherDerived>& other) 671 { 672 call_assignment_no_alias(derived(), other.derived(), internal::add_assign_op<Scalar,typename OtherDerived::Scalar>()); 673 return derived(); 674 } 675 676 template<typename Derived> 677 template<typename OtherDerived> 678 Derived& SparseMatrixBase<Derived>::operator-=(const DiagonalBase<OtherDerived>& other) 679 { 680 call_assignment_no_alias(derived(), other.derived(), internal::sub_assign_op<Scalar,typename OtherDerived::Scalar>()); 681 return derived(); 682 } 683 684 template<typename Derived> 685 template<typename OtherDerived> 686 EIGEN_STRONG_INLINE const typename SparseMatrixBase<Derived>::template CwiseProductDenseReturnType<OtherDerived>::Type 687 SparseMatrixBase<Derived>::cwiseProduct(const MatrixBase<OtherDerived> &other) const 688 { 689 return typename CwiseProductDenseReturnType<OtherDerived>::Type(derived(), other.derived()); 690 } 691 692 template<typename DenseDerived, typename SparseDerived> 693 EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_sum_op<typename DenseDerived::Scalar,typename SparseDerived::Scalar>, const DenseDerived, const SparseDerived> 694 operator+(const MatrixBase<DenseDerived> &a, const SparseMatrixBase<SparseDerived> &b) 695 { 696 return CwiseBinaryOp<internal::scalar_sum_op<typename DenseDerived::Scalar,typename SparseDerived::Scalar>, const DenseDerived, const SparseDerived>(a.derived(), b.derived()); 697 } 698 699 template<typename SparseDerived, typename DenseDerived> 700 EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_sum_op<typename SparseDerived::Scalar,typename DenseDerived::Scalar>, const SparseDerived, const DenseDerived> 701 operator+(const SparseMatrixBase<SparseDerived> &a, const MatrixBase<DenseDerived> &b) 702 { 703 return CwiseBinaryOp<internal::scalar_sum_op<typename SparseDerived::Scalar,typename DenseDerived::Scalar>, const SparseDerived, const DenseDerived>(a.derived(), b.derived()); 704 } 705 706 template<typename DenseDerived, typename SparseDerived> 707 EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_difference_op<typename DenseDerived::Scalar,typename SparseDerived::Scalar>, const DenseDerived, const SparseDerived> 708 operator-(const MatrixBase<DenseDerived> &a, const SparseMatrixBase<SparseDerived> &b) 709 { 710 return CwiseBinaryOp<internal::scalar_difference_op<typename DenseDerived::Scalar,typename SparseDerived::Scalar>, const DenseDerived, const SparseDerived>(a.derived(), b.derived()); 711 } 712 713 template<typename SparseDerived, typename DenseDerived> 714 EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_difference_op<typename SparseDerived::Scalar,typename DenseDerived::Scalar>, const SparseDerived, const DenseDerived> 715 operator-(const SparseMatrixBase<SparseDerived> &a, const MatrixBase<DenseDerived> &b) 716 { 717 return CwiseBinaryOp<internal::scalar_difference_op<typename SparseDerived::Scalar,typename DenseDerived::Scalar>, const SparseDerived, const DenseDerived>(a.derived(), b.derived()); 718 } 719 720 } // end namespace Eigen 721 722 #endif // EIGEN_SPARSE_CWISE_BINARY_OP_H 723