1 //
2 //  Copyright (c) 2018-2019, Cem Bassoy, [email protected]
3 //
4 //  Distributed under the Boost Software License, Version 1.0. (See
5 //  accompanying file LICENSE_1_0.txt or copy at
6 //  http://www.boost.org/LICENSE_1_0.txt)
7 //
8 //  The authors gratefully acknowledge the support of
9 //  Fraunhofer IOSB, Ettlingen, Germany
10 //
11 
12 #ifndef BOOST_UBLAS_TENSOR_OPERATORS_COMPARISON_HPP
13 #define BOOST_UBLAS_TENSOR_OPERATORS_COMPARISON_HPP
14 
#include <boost/numeric/ublas/tensor/expression.hpp>
#include <boost/numeric/ublas/tensor/expression_evaluation.hpp>
#include <functional>
#include <stdexcept>
#include <type_traits>
19 
namespace boost::numeric::ublas {
// Forward declaration: lets this header name tensor<...> (used by detail::compare)
// without pulling in the full tensor definition.
template<class value_t, class format_t, class storage_t>
class tensor;
} // namespace boost::numeric::ublas
24 
25 namespace boost::numeric::ublas::detail {
26 
27 template<class T, class F, class A, class BinaryPred>
compare(tensor<T,F,A> const & lhs,tensor<T,F,A> const & rhs,BinaryPred pred)28 bool compare(tensor<T,F,A> const& lhs, tensor<T,F,A> const& rhs, BinaryPred pred)
29 {
30 
31 	if(lhs.extents() != rhs.extents()){
32 		if constexpr(!std::is_same<BinaryPred,std::equal_to<>>::value && !std::is_same<BinaryPred,std::not_equal_to<>>::value)
33 			throw std::runtime_error("Error in boost::numeric::ublas::detail::compare: cannot compare tensors with different shapes.");
34 		else
35 			return false;
36 	}
37 
38 	if constexpr(std::is_same<BinaryPred,std::greater<>>::value || std::is_same<BinaryPred,std::less<>>::value)
39 		if(lhs.empty())
40 			return false;
41 
42 	for(auto i = 0u; i < lhs.size(); ++i)
43 		if(!pred(lhs(i), rhs(i)))
44 			return false;
45 	return true;
46 }
47 
48 template<class T, class F, class A, class UnaryPred>
compare(tensor<T,F,A> const & rhs,UnaryPred pred)49 bool compare(tensor<T,F,A> const& rhs, UnaryPred pred)
50 {
51 	for(auto i = 0u; i < rhs.size(); ++i)
52 		if(!pred(rhs(i)))
53 			return false;
54 	return true;
55 }
56 
57 
58 template<class T, class L, class R, class BinaryPred>
compare(tensor_expression<T,L> const & lhs,tensor_expression<T,R> const & rhs,BinaryPred pred)59 bool compare(tensor_expression<T,L> const& lhs, tensor_expression<T,R> const& rhs, BinaryPred pred)
60 {
61 	constexpr bool lhs_is_tensor = std::is_same<T,L>::value;
62 	constexpr bool rhs_is_tensor = std::is_same<T,R>::value;
63 
64 	if constexpr (lhs_is_tensor && rhs_is_tensor)
65 		return compare(static_cast<T const&>( lhs ), static_cast<T const&>( rhs ), pred);
66 	else if constexpr (lhs_is_tensor && !rhs_is_tensor)
67 		return compare(static_cast<T const&>( lhs ), T( rhs ), pred);
68 	else if constexpr (!lhs_is_tensor && rhs_is_tensor)
69 		return compare(T( lhs ), static_cast<T const&>( rhs ), pred);
70 	else
71 		return compare(T( lhs ), T( rhs ), pred);
72 
73 }
74 
75 template<class T, class D, class UnaryPred>
compare(tensor_expression<T,D> const & expr,UnaryPred pred)76 bool compare(tensor_expression<T,D> const& expr, UnaryPred pred)
77 {
78 	if constexpr (std::is_same<T,D>::value)
79 		return compare(static_cast<T const&>( expr ), pred);
80 	else
81 		return compare(T( expr ), pred);
82 }
83 
84 }
85 
86 
87 template<class T, class L, class R>
operator ==(boost::numeric::ublas::detail::tensor_expression<T,L> const & lhs,boost::numeric::ublas::detail::tensor_expression<T,R> const & rhs)88 bool operator==( boost::numeric::ublas::detail::tensor_expression<T,L> const& lhs,
89 								 boost::numeric::ublas::detail::tensor_expression<T,R> const& rhs) {
90 	return boost::numeric::ublas::detail::compare( lhs, rhs, std::equal_to<>{} );
91 }
92 template<class T, class L, class R>
operator !=(boost::numeric::ublas::detail::tensor_expression<T,L> const & lhs,boost::numeric::ublas::detail::tensor_expression<T,R> const & rhs)93 auto operator!=(boost::numeric::ublas::detail::tensor_expression<T,L> const& lhs,
94 								boost::numeric::ublas::detail::tensor_expression<T,R> const& rhs) {
95 	return boost::numeric::ublas::detail::compare( lhs, rhs, std::not_equal_to<>{}  );
96 }
97 template<class T, class L, class R>
operator <(boost::numeric::ublas::detail::tensor_expression<T,L> const & lhs,boost::numeric::ublas::detail::tensor_expression<T,R> const & rhs)98 auto operator< ( boost::numeric::ublas::detail::tensor_expression<T,L> const& lhs,
99 								 boost::numeric::ublas::detail::tensor_expression<T,R> const& rhs) {
100 	return boost::numeric::ublas::detail::compare( lhs, rhs, std::less<>{} );
101 }
102 template<class T, class L, class R>
operator <=(boost::numeric::ublas::detail::tensor_expression<T,L> const & lhs,boost::numeric::ublas::detail::tensor_expression<T,R> const & rhs)103 auto operator<=( boost::numeric::ublas::detail::tensor_expression<T,L> const& lhs,
104 								 boost::numeric::ublas::detail::tensor_expression<T,R> const& rhs) {
105 	return boost::numeric::ublas::detail::compare( lhs, rhs, std::less_equal<>{} );
106 }
107 template<class T, class L, class R>
operator >(boost::numeric::ublas::detail::tensor_expression<T,L> const & lhs,boost::numeric::ublas::detail::tensor_expression<T,R> const & rhs)108 auto operator> ( boost::numeric::ublas::detail::tensor_expression<T,L> const& lhs,
109 								 boost::numeric::ublas::detail::tensor_expression<T,R> const& rhs) {
110 	return boost::numeric::ublas::detail::compare( lhs, rhs, std::greater<>{} );
111 }
112 template<class T, class L, class R>
operator >=(boost::numeric::ublas::detail::tensor_expression<T,L> const & lhs,boost::numeric::ublas::detail::tensor_expression<T,R> const & rhs)113 auto operator>=( boost::numeric::ublas::detail::tensor_expression<T,L> const& lhs,
114 								 boost::numeric::ublas::detail::tensor_expression<T,R> const& rhs) {
115 	return boost::numeric::ublas::detail::compare( lhs, rhs, std::greater_equal<>{} );
116 }
117 
118 
119 
120 
121 
122 template<class T, class D>
operator ==(typename T::const_reference lhs,boost::numeric::ublas::detail::tensor_expression<T,D> const & rhs)123 bool operator==( typename T::const_reference lhs, boost::numeric::ublas::detail::tensor_expression<T,D> const& rhs) {
124 	return boost::numeric::ublas::detail::compare( rhs, [lhs](auto const& r){ return lhs == r; } );
125 }
126 template<class T, class D>
operator !=(typename T::const_reference lhs,boost::numeric::ublas::detail::tensor_expression<T,D> const & rhs)127 auto operator!=( typename T::const_reference lhs, boost::numeric::ublas::detail::tensor_expression<T,D> const& rhs) {
128 	return boost::numeric::ublas::detail::compare( rhs, [lhs](auto const& r){ return lhs != r; } );
129 }
130 template<class T, class D>
operator <(typename T::const_reference lhs,boost::numeric::ublas::detail::tensor_expression<T,D> const & rhs)131 auto operator< ( typename T::const_reference lhs, boost::numeric::ublas::detail::tensor_expression<T,D> const& rhs) {
132 	return boost::numeric::ublas::detail::compare( rhs, [lhs](auto const& r){ return lhs <  r; } );
133 }
134 template<class T, class D>
operator <=(typename T::const_reference lhs,boost::numeric::ublas::detail::tensor_expression<T,D> const & rhs)135 auto operator<=( typename T::const_reference lhs, boost::numeric::ublas::detail::tensor_expression<T,D> const& rhs) {
136 	return boost::numeric::ublas::detail::compare( rhs, [lhs](auto const& r){ return lhs <= r; } );
137 }
138 template<class T, class D>
operator >(typename T::const_reference lhs,boost::numeric::ublas::detail::tensor_expression<T,D> const & rhs)139 auto operator> ( typename T::const_reference lhs, boost::numeric::ublas::detail::tensor_expression<T,D> const& rhs) {
140 	return boost::numeric::ublas::detail::compare( rhs, [lhs](auto const& r){ return lhs >  r; } );
141 }
142 template<class T, class D>
operator >=(typename T::const_reference lhs,boost::numeric::ublas::detail::tensor_expression<T,D> const & rhs)143 auto operator>=( typename T::const_reference lhs, boost::numeric::ublas::detail::tensor_expression<T,D> const& rhs) {
144 	return boost::numeric::ublas::detail::compare( rhs, [lhs](auto const& r){ return lhs >= r; } );
145 }
146 
147 
148 
149 template<class T, class D>
operator ==(boost::numeric::ublas::detail::tensor_expression<T,D> const & lhs,typename T::const_reference rhs)150 bool operator==( boost::numeric::ublas::detail::tensor_expression<T,D> const& lhs, typename T::const_reference rhs) {
151 	return boost::numeric::ublas::detail::compare( lhs, [rhs](auto const& l){ return l == rhs; } );
152 }
153 template<class T, class D>
operator !=(boost::numeric::ublas::detail::tensor_expression<T,D> const & lhs,typename T::const_reference rhs)154 auto operator!=( boost::numeric::ublas::detail::tensor_expression<T,D> const& lhs, typename T::const_reference rhs) {
155 	return boost::numeric::ublas::detail::compare( lhs, [rhs](auto const& l){ return l != rhs; } );
156 }
157 template<class T, class D>
operator <(boost::numeric::ublas::detail::tensor_expression<T,D> const & lhs,typename T::const_reference rhs)158 auto operator< ( boost::numeric::ublas::detail::tensor_expression<T,D> const& lhs, typename T::const_reference rhs) {
159 	return boost::numeric::ublas::detail::compare( lhs, [rhs](auto const& l){ return l <  rhs; } );
160 }
161 template<class T, class D>
operator <=(boost::numeric::ublas::detail::tensor_expression<T,D> const & lhs,typename T::const_reference rhs)162 auto operator<=( boost::numeric::ublas::detail::tensor_expression<T,D> const& lhs, typename T::const_reference rhs) {
163 	return boost::numeric::ublas::detail::compare( lhs, [rhs](auto const& l){ return l <= rhs; } );
164 }
165 template<class T, class D>
operator >(boost::numeric::ublas::detail::tensor_expression<T,D> const & lhs,typename T::const_reference rhs)166 auto operator> ( boost::numeric::ublas::detail::tensor_expression<T,D> const& lhs, typename T::const_reference rhs) {
167 	return boost::numeric::ublas::detail::compare( lhs, [rhs](auto const& l){ return l >  rhs; } );
168 }
169 template<class T, class D>
operator >=(boost::numeric::ublas::detail::tensor_expression<T,D> const & lhs,typename T::const_reference rhs)170 auto operator>=( boost::numeric::ublas::detail::tensor_expression<T,D> const& lhs, typename T::const_reference rhs) {
171 	return boost::numeric::ublas::detail::compare( lhs, [rhs](auto const& l){ return l >= rhs; } );
172 }
173 
174 
175 #endif
176