xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/client/lib/arithmetic.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
17 
18 #include <string>
19 
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/compiler/xla/client/lib/constants.h"
22 #include "tensorflow/compiler/xla/client/xla_builder.h"
23 #include "tensorflow/compiler/xla/client/xla_computation.h"
24 #include "tensorflow/compiler/xla/shape_util.h"
25 #include "tensorflow/compiler/xla/status_macros.h"
26 #include "tensorflow/compiler/xla/types.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 
29 namespace xla {
30 
CreateScalarComputation(const std::string & name,PrimitiveType type,XlaBuilder * builder,XlaOpGenerator generator)31 XlaComputation CreateScalarComputation(const std::string& name,
32                                        PrimitiveType type, XlaBuilder* builder,
33                                        XlaOpGenerator generator) {
34   std::unique_ptr<XlaBuilder> b;
35   if (type == PRED) {
36     b = builder->CreateSubBuilder(name);
37   } else {
38     b = builder->CreateSubBuilder(
39         absl::StrCat(name, "_", PrimitiveType_Name(type)));
40   }
41 
42   const Shape scalar = ShapeUtil::MakeShape(type, {});
43   auto lhs = Parameter(b.get(), 0, scalar, "lhs");
44   auto rhs = Parameter(b.get(), 1, scalar, "rhs");
45   generator(lhs, rhs);
46   return b->BuildAndNoteError();
47 }
48 
CreateScalarAddComputation(PrimitiveType type,XlaBuilder * builder)49 XlaComputation CreateScalarAddComputation(PrimitiveType type,
50                                           XlaBuilder* builder) {
51   return CreateScalarComputation(
52       "add", type, builder, [](XlaOp lhs, XlaOp rhs) { return Add(lhs, rhs); });
53 }
54 
CreateScalarMultiplyComputation(PrimitiveType type,XlaBuilder * builder)55 XlaComputation CreateScalarMultiplyComputation(PrimitiveType type,
56                                                XlaBuilder* builder) {
57   return CreateScalarComputation(
58       "mul", type, builder, [](XlaOp lhs, XlaOp rhs) { return Mul(lhs, rhs); });
59 }
60 
CreateScalarGeComputation(PrimitiveType type,XlaBuilder * builder)61 XlaComputation CreateScalarGeComputation(PrimitiveType type,
62                                          XlaBuilder* builder) {
63   return CreateScalarComputation(
64       "ge", type, builder, [](XlaOp lhs, XlaOp rhs) { return Ge(lhs, rhs); });
65 }
66 
CreateScalarMaxComputation(PrimitiveType type,XlaBuilder * builder)67 XlaComputation CreateScalarMaxComputation(PrimitiveType type,
68                                           XlaBuilder* builder) {
69   return CreateScalarComputation(
70       "max", type, builder, [](XlaOp lhs, XlaOp rhs) { return Max(lhs, rhs); });
71 }
72 
CreateScalarMinComputation(PrimitiveType type,XlaBuilder * builder)73 XlaComputation CreateScalarMinComputation(PrimitiveType type,
74                                           XlaBuilder* builder) {
75   return CreateScalarComputation(
76       "min", type, builder, [](XlaOp lhs, XlaOp rhs) { return Min(lhs, rhs); });
77 }
78 
CreateScalarAndComputation(PrimitiveType type,XlaBuilder * builder)79 XlaComputation CreateScalarAndComputation(PrimitiveType type,
80                                           XlaBuilder* builder) {
81   return CreateScalarComputation(
82       "and", type, builder, [](XlaOp lhs, XlaOp rhs) { return And(lhs, rhs); });
83 }
84 
CreateScalarOrComputation(PrimitiveType type,XlaBuilder * builder)85 XlaComputation CreateScalarOrComputation(PrimitiveType type,
86                                          XlaBuilder* builder) {
87   return CreateScalarComputation(
88       "or", type, builder, [](XlaOp lhs, XlaOp rhs) { return Or(lhs, rhs); });
89 }
90 
CreateScalarIdentityWithZeroComputation(PrimitiveType type,XlaBuilder * builder)91 XlaComputation CreateScalarIdentityWithZeroComputation(PrimitiveType type,
92                                                        XlaBuilder* builder) {
93   XlaComputation reducer =
94       (primitive_util::IsIntegralType(type) || type == PRED)
95           ? CreateScalarOrComputation(type, builder)
96           : CreateScalarAddComputation(type, builder);
97   return reducer;
98 }
99 
Any(XlaOp predicates)100 XlaOp Any(XlaOp predicates) {
101   XlaBuilder* builder = predicates.builder();
102   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
103     auto f = ConstantR0<bool>(builder, false);
104     XlaComputation logical_or = CreateScalarOrComputation(PRED, builder);
105     TF_ASSIGN_OR_RETURN(const Shape& predicates_shape,
106                         builder->GetShape(predicates));
107     std::vector<int64_t> all_dimensions(predicates_shape.rank());
108     std::iota(all_dimensions.begin(), all_dimensions.end(), 0);
109     return Reduce(predicates, f, logical_or, all_dimensions);
110   });
111 }
112 
CreateMinMaxComputation(XlaBuilder * outer_builder,PrimitiveType value_type,PrimitiveType index_type,bool is_min)113 static XlaComputation CreateMinMaxComputation(XlaBuilder* outer_builder,
114                                               PrimitiveType value_type,
115                                               PrimitiveType index_type,
116                                               bool is_min) {
117   auto sub_builder = outer_builder->CreateSubBuilder("minmax_func");
118   XlaBuilder* b = sub_builder.get();
119   XlaOp lhs_value =
120       Parameter(b, 0, ShapeUtil::MakeShape(value_type, {}), "lhs_value");
121   XlaOp lhs_index =
122       Parameter(b, 1, ShapeUtil::MakeShape(index_type, {}), "lhs_index");
123   XlaOp rhs_value =
124       Parameter(b, 2, ShapeUtil::MakeShape(value_type, {}), "rhs_value");
125   XlaOp rhs_index =
126       Parameter(b, 3, ShapeUtil::MakeShape(index_type, {}), "rhs_index");
127 
128   XlaOp cmp = is_min ? Le(lhs_value, rhs_value) : Ge(lhs_value, rhs_value);
129   XlaOp max = Select(cmp, lhs_value, rhs_value);
130   XlaOp arg_max = Select(cmp, lhs_index, rhs_index);
131   XlaOp eq = Eq(lhs_value, rhs_value);
132   XlaOp tie_id = Min(lhs_index, rhs_index);
133   arg_max = Select(eq, tie_id, arg_max);
134   Tuple(b, {max, arg_max});
135   return b->BuildAndNoteError();
136 }
137 
ArgMinMax(XlaOp input,PrimitiveType output_type,int axis,bool is_min)138 XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min) {
139   XlaBuilder* builder = input.builder();
140   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
141     TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
142     XlaOp value_init_value;
143     if (is_min) {
144       value_init_value = MaxValue(builder, input_shape.element_type());
145     } else {
146       value_init_value = MinValue(builder, input_shape.element_type());
147     }
148     int64_t dimension_size = input_shape.dimensions(axis);
149     auto index_type = dimension_size <= INT32_MAX ? S32 : output_type;
150     XlaOp index_init_value = Zero(builder, index_type);
151     auto iota_shape = input_shape;
152     iota_shape.set_element_type(index_type);
153     XlaOp iota = Iota(builder, iota_shape, axis);
154 
155     XlaComputation reducer = CreateMinMaxComputation(
156         builder, input_shape.element_type(), index_type, is_min);
157     XlaOp max_argmax = Reduce(builder, {input, iota},
158                               {value_init_value, index_init_value}, reducer,
159                               /*dimensions_to_reduce=*/{axis});
160     XlaOp argmax = GetTupleElement(max_argmax, 1);
161     if (index_type != output_type) {
162       argmax = ConvertElementType(argmax, output_type);
163     }
164     return argmax;
165   });
166 }
167 
ArgMax(XlaOp input,PrimitiveType output_type,int axis)168 XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis) {
169   return ArgMinMax(input, output_type, axis, /*is_min=*/false);
170 }
171 
ArgMin(XlaOp input,PrimitiveType output_type,int axis)172 XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis) {
173   return ArgMinMax(input, output_type, axis, /*is_min=*/true);
174 }
175 
176 }  // namespace xla
177