xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/comparison_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/comparison_util.h"
17 
18 #include <optional>
19 #include <string>
20 
21 #include "absl/base/attributes.h"
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/string_view.h"
25 #include "tensorflow/compiler/xla/util.h"
26 #include "tensorflow/compiler/xla/xla_data.pb.h"
27 
28 namespace xla {
29 namespace {
30 
31 // Verifies that this is a valid Comparison: (1) not a partial ordering on
32 // integers, and (2) a valid PrimitiveType.
IsValidComparison(xla::PrimitiveType type,Comparison::Order order)33 bool IsValidComparison(xla::PrimitiveType type, Comparison::Order order) {
34   switch (type) {
35     case F16:
36     case F32:
37     case BF16:
38     case F64:
39     case C64:
40     case C128:
41       return true;
42     case S8:
43     case S16:
44     case S32:
45     case S64:
46     case PRED:
47     case U8:
48     case U16:
49     case U32:
50     case U64:
51       return order == Comparison::Order::kTotal;
52     case TUPLE:
53     case OPAQUE_TYPE:
54     case TOKEN:
55     case PRIMITIVE_TYPE_INVALID:
56     case PrimitiveType_INT_MAX_SENTINEL_DO_NOT_USE_:
57     case PrimitiveType_INT_MIN_SENTINEL_DO_NOT_USE_:
58       return false;
59   }
60 }
61 
62 // Returns the X32 primitive type for each Type.
DefaultPrimitiveType(Comparison::Type type)63 PrimitiveType DefaultPrimitiveType(Comparison::Type type) {
64   switch (type) {
65     case Comparison::Type::kFloat:
66     case Comparison::Type::kFloatTotalOrder:
67       return PrimitiveType::F32;
68     case Comparison::Type::kSigned:
69       return PrimitiveType::S32;
70     case Comparison::Type::kUnsigned:
71       return PrimitiveType::U32;
72   }
73 }
74 
75 // Returns the default ordering for each Comparison::Type.
DefaultOrdering(Comparison::Type type)76 Comparison::Order DefaultOrdering(Comparison::Type type) {
77   switch (type) {
78     case Comparison::Type::kFloat:
79       return Comparison::Order::kPartial;
80     case Comparison::Type::kFloatTotalOrder:
81     case Comparison::Type::kSigned:
82     case Comparison::Type::kUnsigned:
83       return Comparison::Order::kTotal;
84   }
85 }
86 
87 // Returns the expected ordering for each primitive type.
DefaultOrdering(PrimitiveType type)88 Comparison::Order DefaultOrdering(PrimitiveType type) {
89   switch (type) {
90     case S8:
91     case S16:
92     case S32:
93     case S64:
94     case PRED:
95     case U8:
96     case U16:
97     case U32:
98     case U64:
99       return Comparison::Order::kTotal;
100     case BF16:
101     case F16:
102     case F32:
103     case F64:
104     case C64:
105     case C128:
106       return Comparison::Order::kPartial;
107     default:
108       LOG(FATAL) << "Unsupported type: " << PrimitiveType_Name(type);
109   }
110 }
111 
112 // Returns the converse of `direction`.
Converse(Comparison::Direction direction)113 Comparison::Direction Converse(Comparison::Direction direction) {
114   switch (direction) {
115     case Comparison::Direction::kEq:
116       return Comparison::Direction::kEq;
117     case Comparison::Direction::kNe:
118       return Comparison::Direction::kNe;
119     case Comparison::Direction::kGe:
120       return Comparison::Direction::kLe;
121     case Comparison::Direction::kGt:
122       return Comparison::Direction::kLt;
123     case Comparison::Direction::kLe:
124       return Comparison::Direction::kGe;
125     case Comparison::Direction::kLt:
126       return Comparison::Direction::kGt;
127   }
128 }
129 
130 // Returns the inverse of `direction`.
Inverse(Comparison::Direction direction)131 Comparison::Direction Inverse(Comparison::Direction direction) {
132   switch (direction) {
133     case Comparison::Direction::kEq:
134       return Comparison::Direction::kNe;
135     case Comparison::Direction::kNe:
136       return Comparison::Direction::kEq;
137     case Comparison::Direction::kGe:
138       return Comparison::Direction::kLt;
139     case Comparison::Direction::kGt:
140       return Comparison::Direction::kLe;
141     case Comparison::Direction::kLe:
142       return Comparison::Direction::kGt;
143     case Comparison::Direction::kLt:
144       return Comparison::Direction::kGe;
145   }
146 }
147 
148 }  // namespace
149 
ComparisonDirectionToString(Comparison::Direction direction)150 std::string ComparisonDirectionToString(Comparison::Direction direction) {
151   switch (direction) {
152     case Comparison::Direction::kEq:
153       return "EQ";
154     case Comparison::Direction::kNe:
155       return "NE";
156     case Comparison::Direction::kGe:
157       return "GE";
158     case Comparison::Direction::kGt:
159       return "GT";
160     case Comparison::Direction::kLe:
161       return "LE";
162     case Comparison::Direction::kLt:
163       return "LT";
164     default:
165       LOG(FATAL) << "Attempted to print uninitialized comparison direction";
166   }
167 }
168 
ComparisonTypeToString(Comparison::Type type)169 std::string ComparisonTypeToString(Comparison::Type type) {
170   switch (type) {
171     case Comparison::Type::kFloat:
172       return "FLOAT";
173     case Comparison::Type::kFloatTotalOrder:
174       return "TOTALORDER";
175     case Comparison::Type::kSigned:
176       return "SIGNED";
177     case Comparison::Type::kUnsigned:
178       return "UNSIGNED";
179   }
180 }
181 
ComparisonPrimitiveTypeToString(PrimitiveType type)182 std::string ComparisonPrimitiveTypeToString(PrimitiveType type) {
183   return PrimitiveType_Name(type);
184 }
185 
ComparisonOrderToString(Comparison::Order order)186 std::string ComparisonOrderToString(Comparison::Order order) {
187   switch (order) {
188     case Comparison::Order::kPartial:
189       return "PARTIALORDER";
190     case Comparison::Order::kTotal:
191       return "TOTALORDER";
192   }
193 }
194 
StringToComparisonDirection(absl::string_view direction)195 StatusOr<Comparison::Direction> StringToComparisonDirection(
196     absl::string_view direction) {
197   static auto* map =
198       new absl::flat_hash_map<std::string, Comparison::Direction>({
199           {"EQ", Comparison::Direction::kEq},
200           {"NE", Comparison::Direction::kNe},
201           {"GE", Comparison::Direction::kGe},
202           {"GT", Comparison::Direction::kGt},
203           {"LE", Comparison::Direction::kLe},
204           {"LT", Comparison::Direction::kLt},
205       });
206   auto it = map->find(direction);
207   if (it == map->end()) {
208     return InvalidArgument("Unknown comparison direction: %s", direction);
209   }
210   return it->second;
211 }
212 
StringToComparisonOrder(absl::string_view order)213 StatusOr<Comparison::Order> StringToComparisonOrder(absl::string_view order) {
214   static auto* map = new absl::flat_hash_map<std::string, Comparison::Order>({
215       {"TOTALORDER", Comparison::Order::kTotal},
216       {"PARTIALORDER", Comparison::Order::kPartial},
217   });
218   auto it = map->find(order);
219   if (it == map->end()) {
220     return InvalidArgument("Unknown comparison type: %s", order);
221   }
222   return it->second;
223 }
224 
StringToComparisonType(absl::string_view comparison)225 StatusOr<Comparison::Type> StringToComparisonType(
226     absl::string_view comparison) {
227   static auto* map = new absl::flat_hash_map<std::string, Comparison::Type>({
228       {"FLOAT", Comparison::Type::kFloat},
229       {"TOTALORDER", Comparison::Type::kFloatTotalOrder},
230       {"SIGNED", Comparison::Type::kSigned},
231       {"UNSIGNED", Comparison::Type::kUnsigned},
232   });
233   auto it = map->find(comparison);
234   if (it == map->end()) {
235     return InvalidArgument("Unknown comparison type: %s", comparison);
236   }
237   return it->second;
238 }
239 
DefaultComparisonType(PrimitiveType type)240 Comparison::Type Comparison::DefaultComparisonType(PrimitiveType type) {
241   switch (type) {
242     case S8:
243     case S16:
244     case S32:
245     case S64:
246       return Type::kSigned;
247     case PRED:
248     case U8:
249     case U16:
250     case U32:
251     case U64:
252       return Type::kUnsigned;
253     case F16:
254     case F32:
255     case BF16:
256     case F64:
257     case C64:
258     case C128:
259       return Type::kFloat;
260     default:
261       LOG(FATAL) << "Unexpected: " << PrimitiveType_Name(type);
262   }
263 }
264 
Comparison(Direction dir,PrimitiveType type,Order order)265 Comparison::Comparison(Direction dir, PrimitiveType type, Order order)
266     : dir_(dir),
267       primitive_type_(type),
268       order_(order),
269       type_(DefaultComparisonType(type)) {
270   CHECK(IsValidComparison(primitive_type_, order_));
271 }
272 
Comparison(Direction dir,PrimitiveType type)273 Comparison::Comparison(Direction dir, PrimitiveType type)
274     : dir_(dir),
275       primitive_type_(type),
276       order_(DefaultOrdering(type)),
277       type_(DefaultComparisonType(type)) {
278   CHECK(IsValidComparison(primitive_type_, order_));
279 }
280 
Comparison(Direction dir,Type type)281 Comparison::Comparison(Direction dir, Type type)
282     : dir_(dir),
283       primitive_type_(DefaultPrimitiveType(type)),
284       order_(DefaultOrdering(type)),
285       type_(type) {
286   CHECK(IsValidComparison(primitive_type_, order_));
287 }
288 
Converse() const289 Comparison Comparison::Converse() const {
290   return Comparison(xla::Converse(dir_), primitive_type_, order_);
291 }
292 
Inverse() const293 std::optional<Comparison> Comparison::Inverse() const {
294   if (IsPartialOrder()) {
295     // We assume comparisons don't have inverses unless they are total order,
296     // e.g., a partial order floating point comparison can return true if one
297     // operand is NaN.
298     return std::nullopt;
299   }
300   switch (primitive_type_) {
301     case F16:
302     case F32:
303     case BF16:
304     case F64:
305     case C64:
306     case C128:
307     case S8:
308     case S16:
309     case S32:
310     case S64:
311     case PRED:
312     case U8:
313     case U16:
314     case U32:
315     case U64:
316       return Comparison(xla::Inverse(dir_), primitive_type_, order_);
317     case TUPLE:
318     case OPAQUE_TYPE:
319     case TOKEN:
320     case PRIMITIVE_TYPE_INVALID:
321     case PrimitiveType_INT_MAX_SENTINEL_DO_NOT_USE_:
322     case PrimitiveType_INT_MIN_SENTINEL_DO_NOT_USE_:
323       return std::nullopt;
324   }
325 }
326 
IsReflexive() const327 bool Comparison::IsReflexive() const {
328   switch (dir_) {
329     case Direction::kEq:
330     case Direction::kGe:
331     case Direction::kLe:
332       return IsTotalOrder();
333     case Direction::kNe:
334     case Direction::kGt:
335     case Direction::kLt:
336       return false;
337   }
338 }
339 
IsAntireflexive() const340 bool Comparison::IsAntireflexive() const {
341   switch (dir_) {
342     case Direction::kNe:
343       return IsTotalOrder();
344     case Direction::kGt:
345     case Direction::kLt:
346       return true;
347     case Direction::kEq:
348     case Direction::kGe:
349     case Direction::kLe:
350       return false;
351   }
352 }
353 
ToString(std::string prefix1,std::string prefix2,std::string prefix3) const354 std::string Comparison::ToString(std::string prefix1, std::string prefix2,
355                                  std::string prefix3) const {
356   return absl::StrCat(prefix1, ComparisonDirectionToString(dir_), prefix2,
357                       ComparisonPrimitiveTypeToString(primitive_type_), prefix3,
358                       ComparisonOrderToString(order_));
359 }
360 }  // namespace xla
361