1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/python/ops.h"

#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/types/span.h"
#include "pybind11/attr.h"
#include "pybind11/pybind11.h"
#include "tensorflow/compiler/xla/client/lib/approx_topk.h"
#include "tensorflow/compiler/xla/client/lib/approx_topk_shape.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/lib/lu_decomposition.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/qr.h"
#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h"
#include "tensorflow/compiler/xla/client/lib/sorting.h"
#include "tensorflow/compiler/xla/client/lib/svd.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
38
39 namespace xla {
40
41 namespace py = pybind11;
42
BuildOpsSubmodule(py::module * m)43 void BuildOpsSubmodule(py::module* m) {
44 // ops submodule, containing free functions that add operators to an
45 // XlaBuilder.
46 py::module ops = m->def_submodule("ops", "XLA operations");
47
48 py::enum_<TriangularSolveOptions::Transpose>(
49 ops, "TriangularSolveOptions_Transpose")
50 .value("TRANSPOSE_INVALID", TriangularSolveOptions::TRANSPOSE_INVALID)
51 .value("NO_TRANSPOSE", TriangularSolveOptions::NO_TRANSPOSE)
52 .value("TRANSPOSE", TriangularSolveOptions::TRANSPOSE)
53 .value("ADJOINT", TriangularSolveOptions::ADJOINT);
54
55 py::enum_<RandomAlgorithm>(ops, "RandomAlgorithm")
56 .value("RNG_DEFAULT", RandomAlgorithm::RNG_DEFAULT)
57 .value("RNG_THREE_FRY", RandomAlgorithm::RNG_THREE_FRY)
58 .value("RNG_PHILOX", RandomAlgorithm::RNG_PHILOX);
59
60 py::enum_<CustomCallSchedule>(ops, "CustomCallSchedule")
61 .value("SCHEDULE_NONE", CustomCallSchedule::SCHEDULE_NONE)
62 .value("SCHEDULE_LATEST", CustomCallSchedule::SCHEDULE_LATEST)
63 .value("SCHEDULE_EARLIEST", CustomCallSchedule::SCHEDULE_EARLIEST);
64
65 py::enum_<CustomCallApiVersion>(ops, "CustomCallApiVersion")
66 .value("API_VERSION_ORIGINAL", CustomCallApiVersion::API_VERSION_ORIGINAL)
67 .value("API_VERSION_STATUS_RETURNING",
68 CustomCallApiVersion::API_VERSION_STATUS_RETURNING)
69 .value("API_VERSION_STATUS_RETURNING_UNIFIED",
70 CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED);
71
72 ops.def("AfterAll", &AfterAll, py::arg("builder"), py::arg("tokens"));
73 ops.def("AllGather", &AllGather, py::arg("operand"),
74 py::arg("all_gather_dimension"), py::arg("shard_count"),
75 py::arg("replica_groups") = py::list(),
76 py::arg("channel_id") = std::nullopt,
77 py::arg("shape_with_layout") = std::nullopt,
78 py::arg("use_global_device_ids") = std::nullopt);
79 ops.def("AllReduce",
80 static_cast<XlaOp (*)(
81 XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>,
82 const std::optional<ChannelHandle>&, const std::optional<Shape>&,
83 const std::optional<bool>)>(&AllReduce),
84 py::arg("operand"), py::arg("computation"),
85 py::arg("replica_groups") = py::list(),
86 py::arg("channel_id") = std::nullopt,
87 py::arg("shape_with_layout") = std::nullopt,
88 py::arg("use_global_device_ids") = std::nullopt);
89 ops.def("ReduceScatter", &ReduceScatter, py::arg("operand"),
90 py::arg("computation"), py::arg("scatter_dimension"),
91 py::arg("shard_count"), py::arg("replica_groups") = py::list(),
92 py::arg("channel_id") = std::nullopt,
93 py::arg("layout") = std::nullopt,
94 py::arg("use_global_device_ids") = std::nullopt);
95 ops.def("AllToAll", &AllToAll, py::arg("operand"), py::arg("split_dimension"),
96 py::arg("concat_dimension"), py::arg("split_count"),
97 py::arg("replica_groups") = py::list(),
98 py::arg("layout") = std::nullopt);
99 ops.def("ApproxTopK", &ApproxTopK, py::arg("builder"), py::arg("operands"),
100 py::arg("init_values"), py::arg("top_k"), py::arg("reduction_dim"),
101 py::arg("comparator"), py::arg("recall_target") = 0.9,
102 py::arg("aggregate_to_topk") = true,
103 py::arg("reduction_input_size_override") = -1);
104 ops.def("ApproxTopKFallback", &ApproxTopKFallback, py::arg("builder"),
105 py::arg("operands"), py::arg("init_values"), py::arg("top_k"),
106 py::arg("reduction_dim"), py::arg("comparator"),
107 py::arg("recall_target") = 0.9, py::arg("aggregate_to_topk") = true,
108 py::arg("reduction_input_size_override") = -1);
109 ops.def("ApproxTopKReductionOutputSize", &ApproxTopKReductionOutputSize,
110 py::arg("input_size"), py::arg("rank"), py::arg("top_k"),
111 py::arg("recall_target"), py::arg("aggregate_to_topk") = true,
112 py::arg("input_size_override") = -1);
113 ops.def("BitcastConvertType", &BitcastConvertType, py::arg("operand"),
114 py::arg("new_element_type"));
115 ops.def("Broadcast", &Broadcast, py::arg("operand"), py::arg("sizes"));
116 ops.def("BroadcastInDim", &BroadcastInDim, py::arg("operand"),
117 py::arg("shape"), py::arg("broadcast_dimensions"));
118 ops.def("Call", &Call, py::arg("builder"), py::arg("computation"),
119 py::arg("operands"));
120 ops.def("Cholesky", &Cholesky, py::arg("a"), py::arg("lower") = true);
121 ops.def("Clamp", &Clamp, py::arg("min"), py::arg("operand"), py::arg("max"));
122 ops.def("Collapse", &Collapse, py::arg("operand"), py::arg("dimensions"));
123 ops.def("CollectivePermute", &CollectivePermute, py::arg("operand"),
124 py::arg("source_target_pairs"));
125 ops.def("ConcatInDim", &ConcatInDim, py::arg("builder"), py::arg("operands"),
126 py::arg("dimension"));
127 ops.def("Conditional",
128 static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaComputation* const>,
129 absl::Span<const XlaOp>)>(&Conditional),
130 py::arg("branch_index"), py::arg("branch_computations"),
131 py::arg("branch_operands"));
132 ops.def("Conditional",
133 static_cast<XlaOp (*)(XlaOp, XlaOp, const XlaComputation&, XlaOp,
134 const XlaComputation&)>(&Conditional),
135 py::arg("predicate"), py::arg("true_operand"),
136 py::arg("true_computation"), py::arg("false_operand"),
137 py::arg("false_computation"));
138 ops.def("Constant", &ConstantLiteral, py::arg("builder"), py::arg("literal"));
139 ops.def("ConstantLiteral", &ConstantLiteral, py::arg("builder"),
140 py::arg("literal"));
141 ops.def("ConvGeneralDilated", &ConvGeneralDilated, py::arg("lhs"),
142 py::arg("rhs"), py::arg("window_strides"), py::arg("padding"),
143 py::arg("lhs_dilation"), py::arg("rhs_dilation"),
144 py::arg("dimension_numbers"), py::arg("feature_group_count") = 1,
145 py::arg("batch_group_count") = 1,
146 py::arg("precision_config") = nullptr,
147 py::arg("preferred_element_type") = std::nullopt,
148 py::arg("window_reversal") = std::nullopt);
149 ops.def("ConvertElementType", &ConvertElementType, py::arg("operand"),
150 py::arg("new_element_type"));
151 ops.def("CreateToken", &CreateToken, py::arg("builder"));
152 ops.def("CrossReplicaSum",
153 static_cast<XlaOp (*)(XlaOp, absl::Span<const ReplicaGroup>)>(
154 &CrossReplicaSum),
155 py::arg("operand"), py::arg("replica_groups") = py::list());
156 ops.def(
157 "CustomCall",
158 [](XlaBuilder* builder, const py::bytes& call_target_name,
159 absl::Span<const XlaOp> operands, const Shape& shape,
160 const py::bytes& opaque, bool has_side_effect,
161 CustomCallSchedule schedule,
162 CustomCallApiVersion api_version) -> XlaOp {
163 return CustomCall(builder, call_target_name, operands, shape, opaque,
164 has_side_effect, /*output_operand_aliasing=*/{},
165 /*literal=*/nullptr, schedule, api_version);
166 },
167 py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
168 py::arg("shape"), py::arg("opaque") = py::bytes(""),
169 py::arg("has_side_effect") = false,
170 py::arg("schedule") = CustomCallSchedule::SCHEDULE_NONE,
171 py::arg("api_version") = CustomCallApiVersion::API_VERSION_ORIGINAL);
172 ops.def(
173 "CustomCallWithLayout",
174 [](XlaBuilder* builder, const py::bytes& call_target_name,
175 absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
176 absl::Span<const Shape> operand_shapes_with_layout,
177 const py::bytes& opaque, bool has_side_effect,
178 CustomCallSchedule schedule,
179 CustomCallApiVersion api_version) -> XlaOp {
180 return CustomCallWithLayout(
181 builder, call_target_name, operands, shape_with_layout,
182 operand_shapes_with_layout, opaque, has_side_effect,
183 /*output_operand_aliasing=*/{},
184 /*literal=*/nullptr, schedule, api_version);
185 },
186 py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
187 py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"),
188 py::arg("opaque") = py::bytes(""), py::arg("has_side_effect") = false,
189 py::arg("schedule") = CustomCallSchedule::SCHEDULE_NONE,
190 py::arg("api_version") = CustomCallApiVersion::API_VERSION_ORIGINAL);
191 ops.def(
192 "CustomCallWithAliasing",
193 [](XlaBuilder* builder, const py::bytes& call_target_name,
194 absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
195 absl::Span<const Shape> operand_shapes_with_layout,
196 const py::bytes& opaque, bool has_side_effect,
197 absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
198 output_operand_aliasing,
199 const Literal* literal, CustomCallSchedule schedule,
200 CustomCallApiVersion api_version) -> XlaOp {
201 return CustomCallWithLayout(
202 builder, call_target_name, operands, shape_with_layout,
203 operand_shapes_with_layout, opaque, has_side_effect,
204 output_operand_aliasing, literal, schedule, api_version);
205 },
206 py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
207 py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"),
208 py::arg("opaque") = py::bytes(""), py::arg("has_side_effect") = false,
209 py::arg("output_operand_aliasing"), py::arg("literal") = nullptr,
210 py::arg("schedule") = CustomCallSchedule::SCHEDULE_NONE,
211 py::arg("api_version") = CustomCallApiVersion::API_VERSION_ORIGINAL);
212 ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"),
213 py::arg("precision_config") = nullptr,
214 py::arg("preferred_element_type") = std::nullopt);
215 ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"),
216 py::arg("dimension_numbers"), py::arg("precision_config") = nullptr,
217 py::arg("preferred_element_type") = std::nullopt);
218 ops.def("DynamicReshape",
219 static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaOp>,
220 absl::Span<const int64_t>,
221 const std::vector<bool>&)>(&DynamicReshape),
222 py::arg("operand"), py::arg("dim_sizes"), py::arg("new_size_bounds"),
223 py::arg("dims_are_dynamic"));
224 ops.def("DynamicSlice",
225 static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaOp>,
226 absl::Span<const int64_t>)>(&DynamicSlice),
227 py::arg("operand"), py::arg("start_indices"), py::arg("slice_sizes"));
228 ops.def("DynamicUpdateSlice",
229 static_cast<XlaOp (*)(XlaOp, XlaOp, absl::Span<const XlaOp>)>(
230 &DynamicUpdateSlice),
231 py::arg("operand"), py::arg("update"), py::arg("start_indices"));
232 ops.def(
233 "Eigh",
234 [](XlaOp a, bool lower, int64_t max_iter, float epsilon,
235 bool sort_eigenvalues) -> std::pair<XlaOp, XlaOp> {
236 auto eigh =
237 SelfAdjointEig(a, lower, max_iter, epsilon, sort_eigenvalues);
238 return std::make_pair(eigh.v, eigh.w);
239 },
240 py::arg("a"), py::arg("lower") = true, py::arg("max_iter") = 15,
241 py::arg("epsilon") = 1e-5, py::arg("sort_eigenvalues") = true);
242 ops.def("Fft", &Fft, py::arg("operand"), py::arg("fft_type"),
243 py::arg("fft_length"));
244 ops.def("Gather", &Gather, py::arg("a"), py::arg("start_indices"),
245 py::arg("dimension_numbers"), py::arg("slice_sizes"),
246 py::arg("indices_are_sorted") = false);
247 ops.def("GetDimensionSize", &GetDimensionSize, py::arg("operand"),
248 py::arg("dimension"));
249 ops.def("GetTupleElement", &GetTupleElement, py::arg("tuple_data"),
250 py::arg("index"));
251 ops.def("InfeedWithToken", &InfeedWithToken, py::arg("token"),
252 py::arg("shape"), py::arg("config") = "");
253 ops.def("Iota",
254 static_cast<XlaOp (*)(XlaBuilder*, const Shape&, int64_t)>(&Iota),
255 py::arg("builder"), py::arg("shape"), py::arg("iota_dimension"));
256 ops.def("Iota",
257 static_cast<XlaOp (*)(XlaBuilder*, PrimitiveType, int64_t)>(&Iota),
258 py::arg("builder"), py::arg("type"), py::arg("size"));
259 ops.def(
260 "LU",
261 [](XlaOp a) -> StatusOr<std::tuple<XlaOp, XlaOp, XlaOp>> {
262 LuDecompositionResult lu = LuDecomposition(a);
263 return std::make_tuple(lu.lu, lu.pivots, lu.permutation);
264 },
265 py::arg("operand"));
266 ops.def("Map", &Map, py::arg("builder"), py::arg("operands"),
267 py::arg("computation"), py::arg("dimensions"),
268 py::arg("static_operands") = py::list());
269 ops.def("NextAfter", &NextAfter, py::arg("from"), py::arg("to"));
270 ops.def("OutfeedWithToken", &OutfeedWithToken, py::arg("operand"),
271 py::arg("token"), py::arg("shape_with_layout"),
272 py::arg("outfeed_config") = "");
273 ops.def("Pad", &Pad, py::arg("operand"), py::arg("padding_value"),
274 py::arg("padding_config"));
275 ops.def("Parameter",
276 static_cast<XlaOp (*)(XlaBuilder*, int64_t, const Shape&,
277 const std::string&, const std::vector<bool>&)>(
278 &Parameter),
279 py::arg("builder"), py::arg("parameter_number"), py::arg("shape"),
280 py::arg("name") = "",
281 py::arg("replicated_at_leaf_buffers") = std::vector<bool>());
282 ops.def("ProductOfElementaryHouseholderReflectors",
283 &ProductOfElementaryHouseholderReflectors, py::arg("a"),
284 py::arg("taus"));
285 ops.def(
286 "QR",
287 [](XlaOp a, bool full_matrices) -> StatusOr<std::pair<XlaOp, XlaOp>> {
288 XlaOp q, r;
289 QrExplicit(a, full_matrices, q, r);
290 return std::make_pair(q, r);
291 },
292 py::arg("operand"), py::arg("full_matrices"));
293 ops.def(
294 "QrDecomposition",
295 [](XlaOp a) -> StatusOr<std::pair<XlaOp, XlaOp>> {
296 QrDecomposition d = Qr(a);
297 return std::make_pair(d.q_and_r, d.taus);
298 },
299 py::arg("operand"));
300 ops.def("RecvFromHost", &RecvFromHost, py::arg("token"), py::arg("shape"),
301 py::arg("handle"));
302 ops.def("Reduce",
303 static_cast<XlaOp (*)(XlaBuilder*, absl::Span<const XlaOp>,
304 absl::Span<const XlaOp>, const XlaComputation&,
305 absl::Span<const int64_t>)>(&Reduce),
306 py::arg("builder"), py::arg("operands"), py::arg("init_values"),
307 py::arg("computation"), py::arg("dimensions_to_reduce"));
308 ops.def("ReducePrecision", &ReducePrecision, py::arg("operand"),
309 py::arg("exponent_bits"), py::arg("mantissa_bits"));
310 ops.def("ReduceWindowWithGeneralPadding",
311 static_cast<XlaOp (*)(
312 XlaOp, XlaOp, const XlaComputation&, absl::Span<const int64_t>,
313 absl::Span<const int64_t>, absl::Span<const int64_t>,
314 absl::Span<const int64_t>,
315 absl::Span<const std::pair<int64_t, int64_t>>)>(
316 &ReduceWindowWithGeneralPadding),
317 py::arg("operand"), py::arg("init_value"), py::arg("computation"),
318 py::arg("window_dimensions"), py::arg("window_strides"),
319 py::arg("base_dilations"), py::arg("window_dilations"),
320 py::arg("padding"));
321 ops.def("ReduceWindowWithGeneralPadding",
322 static_cast<XlaOp (*)(
323 absl::Span<const XlaOp>, absl::Span<const XlaOp>,
324 const XlaComputation&, absl::Span<const int64_t>,
325 absl::Span<const int64_t>, absl::Span<const int64_t>,
326 absl::Span<const int64_t>,
327 absl::Span<const std::pair<int64_t, int64_t>>)>(
328 &ReduceWindowWithGeneralPadding),
329 py::arg("operands"), py::arg("init_values"), py::arg("computation"),
330 py::arg("window_dimensions"), py::arg("window_strides"),
331 py::arg("base_dilations"), py::arg("window_dilations"),
332 py::arg("padding"));
333 ops.def("RemoveDynamicDimension", &RemoveDynamicDimension, py::arg("operand"),
334 py::arg("dimension"));
335 ops.def("ReplicaId", &ReplicaId, py::arg("builder"));
336 ops.def("Reshape",
337 static_cast<XlaOp (*)(XlaOp, absl::Span<const int64_t>,
338 absl::Span<const int64_t>)>(&Reshape),
339 py::arg("operand"), py::arg("dimensions"), py::arg("new_sizes"));
340 ops.def("Reshape",
341 static_cast<XlaOp (*)(XlaOp, absl::Span<const int64_t>)>(&Reshape),
342 py::arg("operand"), py::arg("new_sizes"));
343 ops.def("Rev", &Rev, py::arg("operand"), py::arg("dimensions"));
344 ops.def("RngBitGenerator", &RngBitGenerator, py::arg("algorithm"),
345 py::arg("initial_state"), py::arg("shape"));
346 ops.def("RngNormal", &RngNormal, py::arg("mu"), py::arg("sigma"),
347 py::arg("shape"));
348 ops.def("RngUniform", &RngUniform, py::arg("a"), py::arg("b"),
349 py::arg("shape"));
350 ops.def("Scatter",
351 static_cast<XlaOp (*)(XlaOp, XlaOp, XlaOp, const XlaComputation&,
352 const ScatterDimensionNumbers&, bool, bool)>(
353 &Scatter),
354 py::arg("input"), py::arg("scatter_indices"), py::arg("updates"),
355 py::arg("update_computation"), py::arg("dimension_numbers"),
356 py::arg("indices_are_sorted") = false,
357 py::arg("unique_indices") = false);
358 ops.def("Scatter",
359 static_cast<XlaOp (*)(absl::Span<const XlaOp>, XlaOp,
360 absl::Span<const XlaOp>, const XlaComputation&,
361 const ScatterDimensionNumbers&, bool, bool)>(
362 &Scatter),
363 py::arg("inputs"), py::arg("scatter_indices"), py::arg("updates"),
364 py::arg("update_computation"), py::arg("dimension_numbers"),
365 py::arg("indices_are_sorted") = false,
366 py::arg("unique_indices") = false);
367 ops.def("Select", &Select, py::arg("pred"), py::arg("on_true"),
368 py::arg("on_false"));
369 ops.def("SelectAndScatterWithGeneralPadding",
370 &SelectAndScatterWithGeneralPadding, py::arg("operand"),
371 py::arg("select"), py::arg("window_dimensions"),
372 py::arg("window_strides"), py::arg("padding"), py::arg("source"),
373 py::arg("init_value"), py::arg("scatter"));
374 ops.def("SendToHost", &SendToHost, py::arg("operand"), py::arg("token"),
375 py::arg("shape_with_layout"), py::arg("handle"));
376 ops.def("SetDimensionSize", &SetDimensionSize, py::arg("operand"),
377 py::arg("val"), py::arg("dimension"));
378 ops.def("Slice", &Slice, py::arg("operand"), py::arg("start_indices"),
379 py::arg("limit_indices"), py::arg("strides"));
380 ops.def("SliceInDim", &SliceInDim, py::arg("operand"), py::arg("start_index"),
381 py::arg("limit_index"), py::arg("stride"), py::arg("dimno"));
382 ops.def(
383 "Sort",
384 [](XlaBuilder* builder, absl::Span<const XlaOp> operands,
385 std::optional<const XlaComputation*> comparator, int64_t dimension,
386 bool is_stable) -> XlaOp {
387 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
388 std::vector<PrimitiveType> operand_types;
389 operand_types.reserve(operands.size());
390 for (const auto& operand : operands) {
391 TF_ASSIGN_OR_RETURN(auto operand_shape, builder->GetShape(operand));
392 operand_types.push_back(operand_shape.element_type());
393 }
394
395 if (comparator) {
396 return Sort(operands, **comparator, dimension, is_stable);
397 } else {
398 return Sort(operands,
399 CreateScalarLtComputation(operand_types, builder),
400 dimension, is_stable);
401 }
402 });
403 },
404 py::arg("builder"), py::arg("operands"),
405 py::arg("comparator") = std::nullopt, py::arg("dimension") = -1,
406 py::arg("is_stable") = false);
407 ops.def(
408 "SVD",
409 [](XlaOp a, int64_t max_iter,
410 float epsilon) -> std::tuple<XlaOp, XlaOp, XlaOp> {
411 auto svd = SVD(a, max_iter, epsilon);
412 return std::make_tuple(svd.u, svd.d, svd.v);
413 },
414 py::arg("a"), py::arg("max_iter") = 100, py::arg("epsilon") = 1e-6);
415 ops.def("TopK", &TopK, py::arg("input"), py::arg("k"));
416 ops.def("Transpose", &Transpose, py::arg("operand"), py::arg("permutation"));
417 ops.def("TriangularSolve", &TriangularSolve, py::arg("a"), py::arg("b"),
418 py::arg("left_side"), py::arg("lower"), py::arg("unit_diagonal"),
419 py::arg("transpose_a"));
420 ops.def("Tuple", &Tuple, py::arg("builder"), py::arg("elements"));
421 ops.def("While", &While, py::arg("condition"), py::arg("body"),
422 py::arg("init"));
423
424 ops.def("Igamma", &Igamma, py::arg("a"), py::arg("x"));
425 ops.def("Igammac", &Igammac, py::arg("a"), py::arg("x"));
426 ops.def("IgammaGradA", &IgammaGradA, py::arg("a"), py::arg("x"));
427 ops.def("RandomGammaGrad", &RandomGammaGrad, py::arg("a"), py::arg("x"));
428 ops.def("RegularizedIncompleteBeta", &RegularizedIncompleteBeta, py::arg("a"),
429 py::arg("b"), py::arg("x"));
430 ops.def("Zeta", &Zeta, py::arg("x"), py::arg("q"));
431
432 #define BINARY_OP(op) \
433 ops.def( \
434 #op, \
435 [](XlaOp a, XlaOp b, std::optional<std::vector<int64_t>> dims) { \
436 return dims ? op(a, b, *dims) : op(a, b); \
437 }, \
438 py::arg("lhs"), py::arg("rhs"), \
439 py::arg("broadcast_dimensions") = std::nullopt)
440 BINARY_OP(Eq);
441 BINARY_OP(Ne);
442 BINARY_OP(Ge);
443 BINARY_OP(Gt);
444 BINARY_OP(Lt);
445 BINARY_OP(Le);
446 BINARY_OP(Add);
447 BINARY_OP(Sub);
448 BINARY_OP(Mul);
449 BINARY_OP(Div);
450 BINARY_OP(Rem);
451 BINARY_OP(Max);
452 BINARY_OP(Min);
453 BINARY_OP(And);
454 BINARY_OP(Or);
455 BINARY_OP(Xor);
456 BINARY_OP(ShiftLeft);
457 BINARY_OP(ShiftRightArithmetic);
458 BINARY_OP(ShiftRightLogical);
459 BINARY_OP(Atan2);
460 BINARY_OP(Pow);
461 BINARY_OP(Complex);
462 #undef BINARY_OP
463
464 #define UNARY_OP(op) ops.def(#op, &op)
465 UNARY_OP(Not);
466 UNARY_OP(PopulationCount);
467 UNARY_OP(Clz);
468 UNARY_OP(Abs);
469 UNARY_OP(Exp);
470 UNARY_OP(Expm1);
471 UNARY_OP(Floor);
472 UNARY_OP(Ceil);
473 UNARY_OP(Round);
474 UNARY_OP(Log);
475 UNARY_OP(Log1p);
476 UNARY_OP(Sign);
477 UNARY_OP(Cos);
478 UNARY_OP(Sin);
479 UNARY_OP(Tanh);
480 UNARY_OP(IsFinite);
481 UNARY_OP(Neg);
482 UNARY_OP(Sqrt);
483 UNARY_OP(Rsqrt);
484 UNARY_OP(Cbrt);
485 UNARY_OP(Square);
486 UNARY_OP(Reciprocal);
487 UNARY_OP(Erfc);
488 UNARY_OP(Erf);
489 UNARY_OP(ErfInv);
490 UNARY_OP(Lgamma);
491 UNARY_OP(Digamma);
492 UNARY_OP(BesselI0e);
493 UNARY_OP(BesselI1e);
494 UNARY_OP(Acos);
495 UNARY_OP(Asin);
496 UNARY_OP(Atan);
497 UNARY_OP(Tan);
498 UNARY_OP(Acosh);
499 UNARY_OP(Asinh);
500 UNARY_OP(Atanh);
501 UNARY_OP(Cosh);
502 UNARY_OP(Sinh);
503 UNARY_OP(Real);
504 UNARY_OP(Imag);
505 UNARY_OP(Conj);
506 UNARY_OP(OptimizationBarrier);
507 #undef UNARY_OP
508 }
509
510 } // namespace xla
511