1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"

#include <cstdint>
#include <memory>
#include <numeric>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
28
29 namespace xla {
30
CreateScalarComputation(const std::string & name,PrimitiveType type,XlaBuilder * builder,XlaOpGenerator generator)31 XlaComputation CreateScalarComputation(const std::string& name,
32 PrimitiveType type, XlaBuilder* builder,
33 XlaOpGenerator generator) {
34 std::unique_ptr<XlaBuilder> b;
35 if (type == PRED) {
36 b = builder->CreateSubBuilder(name);
37 } else {
38 b = builder->CreateSubBuilder(
39 absl::StrCat(name, "_", PrimitiveType_Name(type)));
40 }
41
42 const Shape scalar = ShapeUtil::MakeShape(type, {});
43 auto lhs = Parameter(b.get(), 0, scalar, "lhs");
44 auto rhs = Parameter(b.get(), 1, scalar, "rhs");
45 generator(lhs, rhs);
46 return b->BuildAndNoteError();
47 }
48
CreateScalarAddComputation(PrimitiveType type,XlaBuilder * builder)49 XlaComputation CreateScalarAddComputation(PrimitiveType type,
50 XlaBuilder* builder) {
51 return CreateScalarComputation(
52 "add", type, builder, [](XlaOp lhs, XlaOp rhs) { return Add(lhs, rhs); });
53 }
54
CreateScalarMultiplyComputation(PrimitiveType type,XlaBuilder * builder)55 XlaComputation CreateScalarMultiplyComputation(PrimitiveType type,
56 XlaBuilder* builder) {
57 return CreateScalarComputation(
58 "mul", type, builder, [](XlaOp lhs, XlaOp rhs) { return Mul(lhs, rhs); });
59 }
60
CreateScalarGeComputation(PrimitiveType type,XlaBuilder * builder)61 XlaComputation CreateScalarGeComputation(PrimitiveType type,
62 XlaBuilder* builder) {
63 return CreateScalarComputation(
64 "ge", type, builder, [](XlaOp lhs, XlaOp rhs) { return Ge(lhs, rhs); });
65 }
66
CreateScalarMaxComputation(PrimitiveType type,XlaBuilder * builder)67 XlaComputation CreateScalarMaxComputation(PrimitiveType type,
68 XlaBuilder* builder) {
69 return CreateScalarComputation(
70 "max", type, builder, [](XlaOp lhs, XlaOp rhs) { return Max(lhs, rhs); });
71 }
72
CreateScalarMinComputation(PrimitiveType type,XlaBuilder * builder)73 XlaComputation CreateScalarMinComputation(PrimitiveType type,
74 XlaBuilder* builder) {
75 return CreateScalarComputation(
76 "min", type, builder, [](XlaOp lhs, XlaOp rhs) { return Min(lhs, rhs); });
77 }
78
CreateScalarAndComputation(PrimitiveType type,XlaBuilder * builder)79 XlaComputation CreateScalarAndComputation(PrimitiveType type,
80 XlaBuilder* builder) {
81 return CreateScalarComputation(
82 "and", type, builder, [](XlaOp lhs, XlaOp rhs) { return And(lhs, rhs); });
83 }
84
CreateScalarOrComputation(PrimitiveType type,XlaBuilder * builder)85 XlaComputation CreateScalarOrComputation(PrimitiveType type,
86 XlaBuilder* builder) {
87 return CreateScalarComputation(
88 "or", type, builder, [](XlaOp lhs, XlaOp rhs) { return Or(lhs, rhs); });
89 }
90
CreateScalarIdentityWithZeroComputation(PrimitiveType type,XlaBuilder * builder)91 XlaComputation CreateScalarIdentityWithZeroComputation(PrimitiveType type,
92 XlaBuilder* builder) {
93 XlaComputation reducer =
94 (primitive_util::IsIntegralType(type) || type == PRED)
95 ? CreateScalarOrComputation(type, builder)
96 : CreateScalarAddComputation(type, builder);
97 return reducer;
98 }
99
Any(XlaOp predicates)100 XlaOp Any(XlaOp predicates) {
101 XlaBuilder* builder = predicates.builder();
102 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
103 auto f = ConstantR0<bool>(builder, false);
104 XlaComputation logical_or = CreateScalarOrComputation(PRED, builder);
105 TF_ASSIGN_OR_RETURN(const Shape& predicates_shape,
106 builder->GetShape(predicates));
107 std::vector<int64_t> all_dimensions(predicates_shape.rank());
108 std::iota(all_dimensions.begin(), all_dimensions.end(), 0);
109 return Reduce(predicates, f, logical_or, all_dimensions);
110 });
111 }
112
CreateMinMaxComputation(XlaBuilder * outer_builder,PrimitiveType value_type,PrimitiveType index_type,bool is_min)113 static XlaComputation CreateMinMaxComputation(XlaBuilder* outer_builder,
114 PrimitiveType value_type,
115 PrimitiveType index_type,
116 bool is_min) {
117 auto sub_builder = outer_builder->CreateSubBuilder("minmax_func");
118 XlaBuilder* b = sub_builder.get();
119 XlaOp lhs_value =
120 Parameter(b, 0, ShapeUtil::MakeShape(value_type, {}), "lhs_value");
121 XlaOp lhs_index =
122 Parameter(b, 1, ShapeUtil::MakeShape(index_type, {}), "lhs_index");
123 XlaOp rhs_value =
124 Parameter(b, 2, ShapeUtil::MakeShape(value_type, {}), "rhs_value");
125 XlaOp rhs_index =
126 Parameter(b, 3, ShapeUtil::MakeShape(index_type, {}), "rhs_index");
127
128 XlaOp cmp = is_min ? Le(lhs_value, rhs_value) : Ge(lhs_value, rhs_value);
129 XlaOp max = Select(cmp, lhs_value, rhs_value);
130 XlaOp arg_max = Select(cmp, lhs_index, rhs_index);
131 XlaOp eq = Eq(lhs_value, rhs_value);
132 XlaOp tie_id = Min(lhs_index, rhs_index);
133 arg_max = Select(eq, tie_id, arg_max);
134 Tuple(b, {max, arg_max});
135 return b->BuildAndNoteError();
136 }
137
ArgMinMax(XlaOp input,PrimitiveType output_type,int axis,bool is_min)138 XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min) {
139 XlaBuilder* builder = input.builder();
140 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
141 TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
142 XlaOp value_init_value;
143 if (is_min) {
144 value_init_value = MaxValue(builder, input_shape.element_type());
145 } else {
146 value_init_value = MinValue(builder, input_shape.element_type());
147 }
148 int64_t dimension_size = input_shape.dimensions(axis);
149 auto index_type = dimension_size <= INT32_MAX ? S32 : output_type;
150 XlaOp index_init_value = Zero(builder, index_type);
151 auto iota_shape = input_shape;
152 iota_shape.set_element_type(index_type);
153 XlaOp iota = Iota(builder, iota_shape, axis);
154
155 XlaComputation reducer = CreateMinMaxComputation(
156 builder, input_shape.element_type(), index_type, is_min);
157 XlaOp max_argmax = Reduce(builder, {input, iota},
158 {value_init_value, index_init_value}, reducer,
159 /*dimensions_to_reduce=*/{axis});
160 XlaOp argmax = GetTupleElement(max_argmax, 1);
161 if (index_type != output_type) {
162 argmax = ConvertElementType(argmax, output_type);
163 }
164 return argmax;
165 });
166 }
167
ArgMax(XlaOp input,PrimitiveType output_type,int axis)168 XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis) {
169 return ArgMinMax(input, output_type, axis, /*is_min=*/false);
170 }
171
ArgMin(XlaOp input,PrimitiveType output_type,int axis)172 XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis) {
173 return ArgMinMax(input, output_type, axis, /*is_min=*/true);
174 }
175
176 } // namespace xla
177