xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/client/lib/comparators_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/lib/comparators.h"
17 
18 #include <cmath>
19 #include <limits>
20 #include <vector>
21 
22 #include "absl/container/inlined_vector.h"
23 #include "tensorflow/compiler/xla/client/lib/constants.h"
24 #include "tensorflow/compiler/xla/client/xla_builder.h"
25 #include "tensorflow/compiler/xla/primitive_util.h"
26 #include "tensorflow/compiler/xla/test.h"
27 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
28 #include "tensorflow/compiler/xla/tests/test_macros.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 
32 namespace xla {
33 namespace {
34 
35 class ComparatorsTest : public ClientLibraryTestBase {
36  public:
ComparatorsTest()37   ComparatorsTest() : builder_(TestName()) {}
builder()38   XlaBuilder* builder() { return &builder_; }
39 
40  private:
41   XlaBuilder builder_;
42 };
43 
44 template <
45     PrimitiveType type,
46     typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
BuildComparatorAndComparisons(ComparatorsTest * test,bool compare_less_than,absl::InlinedVector<bool,10> * expected)47 void BuildComparatorAndComparisons(ComparatorsTest* test,
48                                    bool compare_less_than,
49                                    absl::InlinedVector<bool, 10>* expected) {
50   auto compare = compare_less_than
51                      ? CreateScalarLtComputation({type}, test->builder())
52                      : CreateScalarGtComputation({type}, test->builder());
53 
54   auto negative_nan = ConstantR0<T>(
55       test->builder(), -T(std::numeric_limits<float>::quiet_NaN()));
56   auto positive_nan = ConstantR0<T>(test->builder(),
57                                     T(std::numeric_limits<float>::quiet_NaN()));
58   auto negative_zero = ConstantR0<T>(test->builder(), T(-0.));
59   auto positive_zero = ConstantR0<T>(test->builder(), T(0.));
60   auto negative_infinity = MinValue(test->builder(), type);
61   auto positive_infinity = MaxValue(test->builder(), type);
62 
63   // List the values in the expected sorting order from smallest to largest.
64   std::vector<XlaOp> all_constants{negative_nan,      negative_infinity,
65                                    negative_zero,     positive_zero,
66                                    positive_infinity, positive_nan};
67 
68   // Do pairwise comparisons.
69   std::vector<XlaOp> all_comparisons;
70   all_comparisons.reserve(std::pow(all_constants.size(), 2));
71   for (const XlaOp& lhs_constant : all_constants) {
72     for (const XlaOp& rhs_constant : all_constants) {
73       all_comparisons.push_back(Broadcast(
74           Call(test->builder(), compare, {lhs_constant, rhs_constant}), {1}));
75     }
76   }
77 
78   // Concatenate the comparison results.
79   ConcatInDim(test->builder(), all_comparisons, 0);
80 
81   // If we use less-than comparisons, we expect the comparison to result in true
82   // if the lhs value to be compared appears earlier in 'all_constants' than the
83   // rhs value. Likewise, if we use greater-than comparisons, we expect the
84   // comparison to return true if the rhs value appears earlier in
85   // 'all_constants' than the lhs value.
86   expected->clear();
87   for (int i = 0; i < all_constants.size(); ++i) {
88     for (int j = 0; j < all_constants.size(); ++j) {
89       expected->push_back(compare_less_than ? i < j : i > j);
90     }
91   }
92 }
93 
XLA_TEST_F(ComparatorsTest, CompareLtBF16) {
  // Pairwise "<" over the special BF16 values must follow the sorting order
  // in which BuildComparatorAndComparisons lists them.
  constexpr bool kCompareLessThan = true;
  absl::InlinedVector<bool, 10> want;
  BuildComparatorAndComparisons<BF16>(this, kCompareLessThan, &want);
  ComputeAndCompareR1<bool>(builder(), want, {});
}
100 
XLA_TEST_F(ComparatorsTest, CompareGtBF16) {
  // Pairwise ">" over the special BF16 values must follow the reverse of the
  // sorting order in which BuildComparatorAndComparisons lists them.
  constexpr bool kCompareLessThan = false;
  absl::InlinedVector<bool, 10> want;
  BuildComparatorAndComparisons<BF16>(this, kCompareLessThan, &want);
  ComputeAndCompareR1<bool>(builder(), want, {});
}
107 
XLA_TEST_F(ComparatorsTest, CompareLtF16) {
  // Pairwise "<" over the special F16 values must follow the sorting order
  // in which BuildComparatorAndComparisons lists them.
  constexpr bool kCompareLessThan = true;
  absl::InlinedVector<bool, 10> want;
  BuildComparatorAndComparisons<F16>(this, kCompareLessThan, &want);
  ComputeAndCompareR1<bool>(builder(), want, {});
}
114 
XLA_TEST_F(ComparatorsTest, CompareGtF16) {
  // Pairwise ">" over the special F16 values must follow the reverse of the
  // sorting order in which BuildComparatorAndComparisons lists them.
  constexpr bool kCompareLessThan = false;
  absl::InlinedVector<bool, 10> want;
  BuildComparatorAndComparisons<F16>(this, kCompareLessThan, &want);
  ComputeAndCompareR1<bool>(builder(), want, {});
}
121 
XLA_TEST_F(ComparatorsTest, CompareLtF32) {
  // Pairwise "<" over the special F32 values must follow the sorting order
  // in which BuildComparatorAndComparisons lists them.
  constexpr bool kCompareLessThan = true;
  absl::InlinedVector<bool, 10> want;
  BuildComparatorAndComparisons<F32>(this, kCompareLessThan, &want);
  ComputeAndCompareR1<bool>(builder(), want, {});
}
128 
XLA_TEST_F(ComparatorsTest, CompareGtF32) {
  // Pairwise ">" over the special F32 values must follow the reverse of the
  // sorting order in which BuildComparatorAndComparisons lists them.
  constexpr bool kCompareLessThan = false;
  absl::InlinedVector<bool, 10> want;
  BuildComparatorAndComparisons<F32>(this, kCompareLessThan, &want);
  ComputeAndCompareR1<bool>(builder(), want, {});
}
135 
XLA_TEST_F(ComparatorsTest, CompareLtF64) {
  // Pairwise "<" over the special F64 values must follow the sorting order
  // in which BuildComparatorAndComparisons lists them.
  constexpr bool kCompareLessThan = true;
  absl::InlinedVector<bool, 10> want;
  BuildComparatorAndComparisons<F64>(this, kCompareLessThan, &want);
  ComputeAndCompareR1<bool>(builder(), want, {});
}
142 
XLA_TEST_F(ComparatorsTest, CompareGtF64) {
  // Pairwise ">" over the special F64 values must follow the reverse of the
  // sorting order in which BuildComparatorAndComparisons lists them.
  constexpr bool kCompareLessThan = false;
  absl::InlinedVector<bool, 10> want;
  BuildComparatorAndComparisons<F64>(this, kCompareLessThan, &want);
  ComputeAndCompareR1<bool>(builder(), want, {});
}
149 
150 }  // namespace
151 }  // namespace xla
152