1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/lib/comparators.h"
17
18 #include <cmath>
19 #include <limits>
20 #include <vector>
21
22 #include "absl/container/inlined_vector.h"
23 #include "tensorflow/compiler/xla/client/lib/constants.h"
24 #include "tensorflow/compiler/xla/client/xla_builder.h"
25 #include "tensorflow/compiler/xla/primitive_util.h"
26 #include "tensorflow/compiler/xla/test.h"
27 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
28 #include "tensorflow/compiler/xla/tests/test_macros.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31
32 namespace xla {
33 namespace {
34
35 class ComparatorsTest : public ClientLibraryTestBase {
36 public:
ComparatorsTest()37 ComparatorsTest() : builder_(TestName()) {}
builder()38 XlaBuilder* builder() { return &builder_; }
39
40 private:
41 XlaBuilder builder_;
42 };
43
44 template <
45 PrimitiveType type,
46 typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
BuildComparatorAndComparisons(ComparatorsTest * test,bool compare_less_than,absl::InlinedVector<bool,10> * expected)47 void BuildComparatorAndComparisons(ComparatorsTest* test,
48 bool compare_less_than,
49 absl::InlinedVector<bool, 10>* expected) {
50 auto compare = compare_less_than
51 ? CreateScalarLtComputation({type}, test->builder())
52 : CreateScalarGtComputation({type}, test->builder());
53
54 auto negative_nan = ConstantR0<T>(
55 test->builder(), -T(std::numeric_limits<float>::quiet_NaN()));
56 auto positive_nan = ConstantR0<T>(test->builder(),
57 T(std::numeric_limits<float>::quiet_NaN()));
58 auto negative_zero = ConstantR0<T>(test->builder(), T(-0.));
59 auto positive_zero = ConstantR0<T>(test->builder(), T(0.));
60 auto negative_infinity = MinValue(test->builder(), type);
61 auto positive_infinity = MaxValue(test->builder(), type);
62
63 // List the values in the expected sorting order from smallest to largest.
64 std::vector<XlaOp> all_constants{negative_nan, negative_infinity,
65 negative_zero, positive_zero,
66 positive_infinity, positive_nan};
67
68 // Do pairwise comparisons.
69 std::vector<XlaOp> all_comparisons;
70 all_comparisons.reserve(std::pow(all_constants.size(), 2));
71 for (const XlaOp& lhs_constant : all_constants) {
72 for (const XlaOp& rhs_constant : all_constants) {
73 all_comparisons.push_back(Broadcast(
74 Call(test->builder(), compare, {lhs_constant, rhs_constant}), {1}));
75 }
76 }
77
78 // Concatenate the comparison results.
79 ConcatInDim(test->builder(), all_comparisons, 0);
80
81 // If we use less-than comparisons, we expect the comparison to result in true
82 // if the lhs value to be compared appears earlier in 'all_constants' than the
83 // rhs value. Likewise, if we use greater-than comparisons, we expect the
84 // comparison to return true if the rhs value appears earlier in
85 // 'all_constants' than the lhs value.
86 expected->clear();
87 for (int i = 0; i < all_constants.size(); ++i) {
88 for (int j = 0; j < all_constants.size(); ++j) {
89 expected->push_back(compare_less_than ? i < j : i > j);
90 }
91 }
92 }
93
// Runs the less-than comparator over the BF16 special values and checks the
// results against the expectations produced by the helper.
XLA_TEST_F(ComparatorsTest, CompareLtBF16) {
  constexpr bool kCompareLessThan = true;
  absl::InlinedVector<bool, 10> expected_results;
  BuildComparatorAndComparisons<BF16>(this, kCompareLessThan,
                                      &expected_results);
  // The trailing {} passes an empty argument list.
  ComputeAndCompareR1<bool>(builder(), expected_results, {});
}
100
// Runs the greater-than comparator over the BF16 special values and checks
// the results against the expectations produced by the helper.
XLA_TEST_F(ComparatorsTest, CompareGtBF16) {
  constexpr bool kCompareLessThan = false;
  absl::InlinedVector<bool, 10> expected_results;
  BuildComparatorAndComparisons<BF16>(this, kCompareLessThan,
                                      &expected_results);
  // The trailing {} passes an empty argument list.
  ComputeAndCompareR1<bool>(builder(), expected_results, {});
}
107
// Runs the less-than comparator over the F16 special values and checks the
// results against the expectations produced by the helper.
XLA_TEST_F(ComparatorsTest, CompareLtF16) {
  constexpr bool kCompareLessThan = true;
  absl::InlinedVector<bool, 10> expected_results;
  BuildComparatorAndComparisons<F16>(this, kCompareLessThan,
                                     &expected_results);
  // The trailing {} passes an empty argument list.
  ComputeAndCompareR1<bool>(builder(), expected_results, {});
}
114
// Runs the greater-than comparator over the F16 special values and checks the
// results against the expectations produced by the helper.
XLA_TEST_F(ComparatorsTest, CompareGtF16) {
  constexpr bool kCompareLessThan = false;
  absl::InlinedVector<bool, 10> expected_results;
  BuildComparatorAndComparisons<F16>(this, kCompareLessThan,
                                     &expected_results);
  // The trailing {} passes an empty argument list.
  ComputeAndCompareR1<bool>(builder(), expected_results, {});
}
121
// Runs the less-than comparator over the F32 special values and checks the
// results against the expectations produced by the helper.
XLA_TEST_F(ComparatorsTest, CompareLtF32) {
  constexpr bool kCompareLessThan = true;
  absl::InlinedVector<bool, 10> expected_results;
  BuildComparatorAndComparisons<F32>(this, kCompareLessThan,
                                     &expected_results);
  // The trailing {} passes an empty argument list.
  ComputeAndCompareR1<bool>(builder(), expected_results, {});
}
128
// Runs the greater-than comparator over the F32 special values and checks the
// results against the expectations produced by the helper.
XLA_TEST_F(ComparatorsTest, CompareGtF32) {
  constexpr bool kCompareLessThan = false;
  absl::InlinedVector<bool, 10> expected_results;
  BuildComparatorAndComparisons<F32>(this, kCompareLessThan,
                                     &expected_results);
  // The trailing {} passes an empty argument list.
  ComputeAndCompareR1<bool>(builder(), expected_results, {});
}
135
// Runs the less-than comparator over the F64 special values and checks the
// results against the expectations produced by the helper.
XLA_TEST_F(ComparatorsTest, CompareLtF64) {
  constexpr bool kCompareLessThan = true;
  absl::InlinedVector<bool, 10> expected_results;
  BuildComparatorAndComparisons<F64>(this, kCompareLessThan,
                                     &expected_results);
  // The trailing {} passes an empty argument list.
  ComputeAndCompareR1<bool>(builder(), expected_results, {});
}
142
// Runs the greater-than comparator over the F64 special values and checks the
// results against the expectations produced by the helper.
XLA_TEST_F(ComparatorsTest, CompareGtF64) {
  constexpr bool kCompareLessThan = false;
  absl::InlinedVector<bool, 10> expected_results;
  BuildComparatorAndComparisons<F64>(this, kCompareLessThan,
                                     &expected_results);
  // The trailing {} passes an empty argument list.
  ComputeAndCompareR1<bool>(builder(), expected_results, {});
}
149
150 } // namespace
151 } // namespace xla
152