xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/python/ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/python/ops.h"

#include <optional>
#include <string>
#include <vector>

#include "absl/types/span.h"
#include "pybind11/attr.h"
#include "pybind11/pybind11.h"
#include "tensorflow/compiler/xla/client/lib/approx_topk.h"
#include "tensorflow/compiler/xla/client/lib/approx_topk_shape.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/lib/lu_decomposition.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/qr.h"
#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h"
#include "tensorflow/compiler/xla/client/lib/sorting.h"
#include "tensorflow/compiler/xla/client/lib/svd.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
38 
39 namespace xla {
40 
41 namespace py = pybind11;
42 
BuildOpsSubmodule(py::module * m)43 void BuildOpsSubmodule(py::module* m) {
44   // ops submodule, containing free functions that add operators to an
45   // XlaBuilder.
46   py::module ops = m->def_submodule("ops", "XLA operations");
47 
48   py::enum_<TriangularSolveOptions::Transpose>(
49       ops, "TriangularSolveOptions_Transpose")
50       .value("TRANSPOSE_INVALID", TriangularSolveOptions::TRANSPOSE_INVALID)
51       .value("NO_TRANSPOSE", TriangularSolveOptions::NO_TRANSPOSE)
52       .value("TRANSPOSE", TriangularSolveOptions::TRANSPOSE)
53       .value("ADJOINT", TriangularSolveOptions::ADJOINT);
54 
55   py::enum_<RandomAlgorithm>(ops, "RandomAlgorithm")
56       .value("RNG_DEFAULT", RandomAlgorithm::RNG_DEFAULT)
57       .value("RNG_THREE_FRY", RandomAlgorithm::RNG_THREE_FRY)
58       .value("RNG_PHILOX", RandomAlgorithm::RNG_PHILOX);
59 
60   py::enum_<CustomCallSchedule>(ops, "CustomCallSchedule")
61       .value("SCHEDULE_NONE", CustomCallSchedule::SCHEDULE_NONE)
62       .value("SCHEDULE_LATEST", CustomCallSchedule::SCHEDULE_LATEST)
63       .value("SCHEDULE_EARLIEST", CustomCallSchedule::SCHEDULE_EARLIEST);
64 
65   py::enum_<CustomCallApiVersion>(ops, "CustomCallApiVersion")
66       .value("API_VERSION_ORIGINAL", CustomCallApiVersion::API_VERSION_ORIGINAL)
67       .value("API_VERSION_STATUS_RETURNING",
68              CustomCallApiVersion::API_VERSION_STATUS_RETURNING)
69       .value("API_VERSION_STATUS_RETURNING_UNIFIED",
70              CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED);
71 
72   ops.def("AfterAll", &AfterAll, py::arg("builder"), py::arg("tokens"));
73   ops.def("AllGather", &AllGather, py::arg("operand"),
74           py::arg("all_gather_dimension"), py::arg("shard_count"),
75           py::arg("replica_groups") = py::list(),
76           py::arg("channel_id") = std::nullopt,
77           py::arg("shape_with_layout") = std::nullopt,
78           py::arg("use_global_device_ids") = std::nullopt);
79   ops.def("AllReduce",
80           static_cast<XlaOp (*)(
81               XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>,
82               const std::optional<ChannelHandle>&, const std::optional<Shape>&,
83               const std::optional<bool>)>(&AllReduce),
84           py::arg("operand"), py::arg("computation"),
85           py::arg("replica_groups") = py::list(),
86           py::arg("channel_id") = std::nullopt,
87           py::arg("shape_with_layout") = std::nullopt,
88           py::arg("use_global_device_ids") = std::nullopt);
89   ops.def("ReduceScatter", &ReduceScatter, py::arg("operand"),
90           py::arg("computation"), py::arg("scatter_dimension"),
91           py::arg("shard_count"), py::arg("replica_groups") = py::list(),
92           py::arg("channel_id") = std::nullopt,
93           py::arg("layout") = std::nullopt,
94           py::arg("use_global_device_ids") = std::nullopt);
95   ops.def("AllToAll", &AllToAll, py::arg("operand"), py::arg("split_dimension"),
96           py::arg("concat_dimension"), py::arg("split_count"),
97           py::arg("replica_groups") = py::list(),
98           py::arg("layout") = std::nullopt);
99   ops.def("ApproxTopK", &ApproxTopK, py::arg("builder"), py::arg("operands"),
100           py::arg("init_values"), py::arg("top_k"), py::arg("reduction_dim"),
101           py::arg("comparator"), py::arg("recall_target") = 0.9,
102           py::arg("aggregate_to_topk") = true,
103           py::arg("reduction_input_size_override") = -1);
104   ops.def("ApproxTopKFallback", &ApproxTopKFallback, py::arg("builder"),
105           py::arg("operands"), py::arg("init_values"), py::arg("top_k"),
106           py::arg("reduction_dim"), py::arg("comparator"),
107           py::arg("recall_target") = 0.9, py::arg("aggregate_to_topk") = true,
108           py::arg("reduction_input_size_override") = -1);
109   ops.def("ApproxTopKReductionOutputSize", &ApproxTopKReductionOutputSize,
110           py::arg("input_size"), py::arg("rank"), py::arg("top_k"),
111           py::arg("recall_target"), py::arg("aggregate_to_topk") = true,
112           py::arg("input_size_override") = -1);
113   ops.def("BitcastConvertType", &BitcastConvertType, py::arg("operand"),
114           py::arg("new_element_type"));
115   ops.def("Broadcast", &Broadcast, py::arg("operand"), py::arg("sizes"));
116   ops.def("BroadcastInDim", &BroadcastInDim, py::arg("operand"),
117           py::arg("shape"), py::arg("broadcast_dimensions"));
118   ops.def("Call", &Call, py::arg("builder"), py::arg("computation"),
119           py::arg("operands"));
120   ops.def("Cholesky", &Cholesky, py::arg("a"), py::arg("lower") = true);
121   ops.def("Clamp", &Clamp, py::arg("min"), py::arg("operand"), py::arg("max"));
122   ops.def("Collapse", &Collapse, py::arg("operand"), py::arg("dimensions"));
123   ops.def("CollectivePermute", &CollectivePermute, py::arg("operand"),
124           py::arg("source_target_pairs"));
125   ops.def("ConcatInDim", &ConcatInDim, py::arg("builder"), py::arg("operands"),
126           py::arg("dimension"));
127   ops.def("Conditional",
128           static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaComputation* const>,
129                                 absl::Span<const XlaOp>)>(&Conditional),
130           py::arg("branch_index"), py::arg("branch_computations"),
131           py::arg("branch_operands"));
132   ops.def("Conditional",
133           static_cast<XlaOp (*)(XlaOp, XlaOp, const XlaComputation&, XlaOp,
134                                 const XlaComputation&)>(&Conditional),
135           py::arg("predicate"), py::arg("true_operand"),
136           py::arg("true_computation"), py::arg("false_operand"),
137           py::arg("false_computation"));
138   ops.def("Constant", &ConstantLiteral, py::arg("builder"), py::arg("literal"));
139   ops.def("ConstantLiteral", &ConstantLiteral, py::arg("builder"),
140           py::arg("literal"));
141   ops.def("ConvGeneralDilated", &ConvGeneralDilated, py::arg("lhs"),
142           py::arg("rhs"), py::arg("window_strides"), py::arg("padding"),
143           py::arg("lhs_dilation"), py::arg("rhs_dilation"),
144           py::arg("dimension_numbers"), py::arg("feature_group_count") = 1,
145           py::arg("batch_group_count") = 1,
146           py::arg("precision_config") = nullptr,
147           py::arg("preferred_element_type") = std::nullopt,
148           py::arg("window_reversal") = std::nullopt);
149   ops.def("ConvertElementType", &ConvertElementType, py::arg("operand"),
150           py::arg("new_element_type"));
151   ops.def("CreateToken", &CreateToken, py::arg("builder"));
152   ops.def("CrossReplicaSum",
153           static_cast<XlaOp (*)(XlaOp, absl::Span<const ReplicaGroup>)>(
154               &CrossReplicaSum),
155           py::arg("operand"), py::arg("replica_groups") = py::list());
156   ops.def(
157       "CustomCall",
158       [](XlaBuilder* builder, const py::bytes& call_target_name,
159          absl::Span<const XlaOp> operands, const Shape& shape,
160          const py::bytes& opaque, bool has_side_effect,
161          CustomCallSchedule schedule,
162          CustomCallApiVersion api_version) -> XlaOp {
163         return CustomCall(builder, call_target_name, operands, shape, opaque,
164                           has_side_effect, /*output_operand_aliasing=*/{},
165                           /*literal=*/nullptr, schedule, api_version);
166       },
167       py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
168       py::arg("shape"), py::arg("opaque") = py::bytes(""),
169       py::arg("has_side_effect") = false,
170       py::arg("schedule") = CustomCallSchedule::SCHEDULE_NONE,
171       py::arg("api_version") = CustomCallApiVersion::API_VERSION_ORIGINAL);
172   ops.def(
173       "CustomCallWithLayout",
174       [](XlaBuilder* builder, const py::bytes& call_target_name,
175          absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
176          absl::Span<const Shape> operand_shapes_with_layout,
177          const py::bytes& opaque, bool has_side_effect,
178          CustomCallSchedule schedule,
179          CustomCallApiVersion api_version) -> XlaOp {
180         return CustomCallWithLayout(
181             builder, call_target_name, operands, shape_with_layout,
182             operand_shapes_with_layout, opaque, has_side_effect,
183             /*output_operand_aliasing=*/{},
184             /*literal=*/nullptr, schedule, api_version);
185       },
186       py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
187       py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"),
188       py::arg("opaque") = py::bytes(""), py::arg("has_side_effect") = false,
189       py::arg("schedule") = CustomCallSchedule::SCHEDULE_NONE,
190       py::arg("api_version") = CustomCallApiVersion::API_VERSION_ORIGINAL);
191   ops.def(
192       "CustomCallWithAliasing",
193       [](XlaBuilder* builder, const py::bytes& call_target_name,
194          absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
195          absl::Span<const Shape> operand_shapes_with_layout,
196          const py::bytes& opaque, bool has_side_effect,
197          absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
198              output_operand_aliasing,
199          const Literal* literal, CustomCallSchedule schedule,
200          CustomCallApiVersion api_version) -> XlaOp {
201         return CustomCallWithLayout(
202             builder, call_target_name, operands, shape_with_layout,
203             operand_shapes_with_layout, opaque, has_side_effect,
204             output_operand_aliasing, literal, schedule, api_version);
205       },
206       py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
207       py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"),
208       py::arg("opaque") = py::bytes(""), py::arg("has_side_effect") = false,
209       py::arg("output_operand_aliasing"), py::arg("literal") = nullptr,
210       py::arg("schedule") = CustomCallSchedule::SCHEDULE_NONE,
211       py::arg("api_version") = CustomCallApiVersion::API_VERSION_ORIGINAL);
212   ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"),
213           py::arg("precision_config") = nullptr,
214           py::arg("preferred_element_type") = std::nullopt);
215   ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"),
216           py::arg("dimension_numbers"), py::arg("precision_config") = nullptr,
217           py::arg("preferred_element_type") = std::nullopt);
218   ops.def("DynamicReshape",
219           static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaOp>,
220                                 absl::Span<const int64_t>,
221                                 const std::vector<bool>&)>(&DynamicReshape),
222           py::arg("operand"), py::arg("dim_sizes"), py::arg("new_size_bounds"),
223           py::arg("dims_are_dynamic"));
224   ops.def("DynamicSlice",
225           static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaOp>,
226                                 absl::Span<const int64_t>)>(&DynamicSlice),
227           py::arg("operand"), py::arg("start_indices"), py::arg("slice_sizes"));
228   ops.def("DynamicUpdateSlice",
229           static_cast<XlaOp (*)(XlaOp, XlaOp, absl::Span<const XlaOp>)>(
230               &DynamicUpdateSlice),
231           py::arg("operand"), py::arg("update"), py::arg("start_indices"));
232   ops.def(
233       "Eigh",
234       [](XlaOp a, bool lower, int64_t max_iter, float epsilon,
235          bool sort_eigenvalues) -> std::pair<XlaOp, XlaOp> {
236         auto eigh =
237             SelfAdjointEig(a, lower, max_iter, epsilon, sort_eigenvalues);
238         return std::make_pair(eigh.v, eigh.w);
239       },
240       py::arg("a"), py::arg("lower") = true, py::arg("max_iter") = 15,
241       py::arg("epsilon") = 1e-5, py::arg("sort_eigenvalues") = true);
242   ops.def("Fft", &Fft, py::arg("operand"), py::arg("fft_type"),
243           py::arg("fft_length"));
244   ops.def("Gather", &Gather, py::arg("a"), py::arg("start_indices"),
245           py::arg("dimension_numbers"), py::arg("slice_sizes"),
246           py::arg("indices_are_sorted") = false);
247   ops.def("GetDimensionSize", &GetDimensionSize, py::arg("operand"),
248           py::arg("dimension"));
249   ops.def("GetTupleElement", &GetTupleElement, py::arg("tuple_data"),
250           py::arg("index"));
251   ops.def("InfeedWithToken", &InfeedWithToken, py::arg("token"),
252           py::arg("shape"), py::arg("config") = "");
253   ops.def("Iota",
254           static_cast<XlaOp (*)(XlaBuilder*, const Shape&, int64_t)>(&Iota),
255           py::arg("builder"), py::arg("shape"), py::arg("iota_dimension"));
256   ops.def("Iota",
257           static_cast<XlaOp (*)(XlaBuilder*, PrimitiveType, int64_t)>(&Iota),
258           py::arg("builder"), py::arg("type"), py::arg("size"));
259   ops.def(
260       "LU",
261       [](XlaOp a) -> StatusOr<std::tuple<XlaOp, XlaOp, XlaOp>> {
262         LuDecompositionResult lu = LuDecomposition(a);
263         return std::make_tuple(lu.lu, lu.pivots, lu.permutation);
264       },
265       py::arg("operand"));
266   ops.def("Map", &Map, py::arg("builder"), py::arg("operands"),
267           py::arg("computation"), py::arg("dimensions"),
268           py::arg("static_operands") = py::list());
269   ops.def("NextAfter", &NextAfter, py::arg("from"), py::arg("to"));
270   ops.def("OutfeedWithToken", &OutfeedWithToken, py::arg("operand"),
271           py::arg("token"), py::arg("shape_with_layout"),
272           py::arg("outfeed_config") = "");
273   ops.def("Pad", &Pad, py::arg("operand"), py::arg("padding_value"),
274           py::arg("padding_config"));
275   ops.def("Parameter",
276           static_cast<XlaOp (*)(XlaBuilder*, int64_t, const Shape&,
277                                 const std::string&, const std::vector<bool>&)>(
278               &Parameter),
279           py::arg("builder"), py::arg("parameter_number"), py::arg("shape"),
280           py::arg("name") = "",
281           py::arg("replicated_at_leaf_buffers") = std::vector<bool>());
282   ops.def("ProductOfElementaryHouseholderReflectors",
283           &ProductOfElementaryHouseholderReflectors, py::arg("a"),
284           py::arg("taus"));
285   ops.def(
286       "QR",
287       [](XlaOp a, bool full_matrices) -> StatusOr<std::pair<XlaOp, XlaOp>> {
288         XlaOp q, r;
289         QrExplicit(a, full_matrices, q, r);
290         return std::make_pair(q, r);
291       },
292       py::arg("operand"), py::arg("full_matrices"));
293   ops.def(
294       "QrDecomposition",
295       [](XlaOp a) -> StatusOr<std::pair<XlaOp, XlaOp>> {
296         QrDecomposition d = Qr(a);
297         return std::make_pair(d.q_and_r, d.taus);
298       },
299       py::arg("operand"));
300   ops.def("RecvFromHost", &RecvFromHost, py::arg("token"), py::arg("shape"),
301           py::arg("handle"));
302   ops.def("Reduce",
303           static_cast<XlaOp (*)(XlaBuilder*, absl::Span<const XlaOp>,
304                                 absl::Span<const XlaOp>, const XlaComputation&,
305                                 absl::Span<const int64_t>)>(&Reduce),
306           py::arg("builder"), py::arg("operands"), py::arg("init_values"),
307           py::arg("computation"), py::arg("dimensions_to_reduce"));
308   ops.def("ReducePrecision", &ReducePrecision, py::arg("operand"),
309           py::arg("exponent_bits"), py::arg("mantissa_bits"));
310   ops.def("ReduceWindowWithGeneralPadding",
311           static_cast<XlaOp (*)(
312               XlaOp, XlaOp, const XlaComputation&, absl::Span<const int64_t>,
313               absl::Span<const int64_t>, absl::Span<const int64_t>,
314               absl::Span<const int64_t>,
315               absl::Span<const std::pair<int64_t, int64_t>>)>(
316               &ReduceWindowWithGeneralPadding),
317           py::arg("operand"), py::arg("init_value"), py::arg("computation"),
318           py::arg("window_dimensions"), py::arg("window_strides"),
319           py::arg("base_dilations"), py::arg("window_dilations"),
320           py::arg("padding"));
321   ops.def("ReduceWindowWithGeneralPadding",
322           static_cast<XlaOp (*)(
323               absl::Span<const XlaOp>, absl::Span<const XlaOp>,
324               const XlaComputation&, absl::Span<const int64_t>,
325               absl::Span<const int64_t>, absl::Span<const int64_t>,
326               absl::Span<const int64_t>,
327               absl::Span<const std::pair<int64_t, int64_t>>)>(
328               &ReduceWindowWithGeneralPadding),
329           py::arg("operands"), py::arg("init_values"), py::arg("computation"),
330           py::arg("window_dimensions"), py::arg("window_strides"),
331           py::arg("base_dilations"), py::arg("window_dilations"),
332           py::arg("padding"));
333   ops.def("RemoveDynamicDimension", &RemoveDynamicDimension, py::arg("operand"),
334           py::arg("dimension"));
335   ops.def("ReplicaId", &ReplicaId, py::arg("builder"));
336   ops.def("Reshape",
337           static_cast<XlaOp (*)(XlaOp, absl::Span<const int64_t>,
338                                 absl::Span<const int64_t>)>(&Reshape),
339           py::arg("operand"), py::arg("dimensions"), py::arg("new_sizes"));
340   ops.def("Reshape",
341           static_cast<XlaOp (*)(XlaOp, absl::Span<const int64_t>)>(&Reshape),
342           py::arg("operand"), py::arg("new_sizes"));
343   ops.def("Rev", &Rev, py::arg("operand"), py::arg("dimensions"));
344   ops.def("RngBitGenerator", &RngBitGenerator, py::arg("algorithm"),
345           py::arg("initial_state"), py::arg("shape"));
346   ops.def("RngNormal", &RngNormal, py::arg("mu"), py::arg("sigma"),
347           py::arg("shape"));
348   ops.def("RngUniform", &RngUniform, py::arg("a"), py::arg("b"),
349           py::arg("shape"));
350   ops.def("Scatter",
351           static_cast<XlaOp (*)(XlaOp, XlaOp, XlaOp, const XlaComputation&,
352                                 const ScatterDimensionNumbers&, bool, bool)>(
353               &Scatter),
354           py::arg("input"), py::arg("scatter_indices"), py::arg("updates"),
355           py::arg("update_computation"), py::arg("dimension_numbers"),
356           py::arg("indices_are_sorted") = false,
357           py::arg("unique_indices") = false);
358   ops.def("Scatter",
359           static_cast<XlaOp (*)(absl::Span<const XlaOp>, XlaOp,
360                                 absl::Span<const XlaOp>, const XlaComputation&,
361                                 const ScatterDimensionNumbers&, bool, bool)>(
362               &Scatter),
363           py::arg("inputs"), py::arg("scatter_indices"), py::arg("updates"),
364           py::arg("update_computation"), py::arg("dimension_numbers"),
365           py::arg("indices_are_sorted") = false,
366           py::arg("unique_indices") = false);
367   ops.def("Select", &Select, py::arg("pred"), py::arg("on_true"),
368           py::arg("on_false"));
369   ops.def("SelectAndScatterWithGeneralPadding",
370           &SelectAndScatterWithGeneralPadding, py::arg("operand"),
371           py::arg("select"), py::arg("window_dimensions"),
372           py::arg("window_strides"), py::arg("padding"), py::arg("source"),
373           py::arg("init_value"), py::arg("scatter"));
374   ops.def("SendToHost", &SendToHost, py::arg("operand"), py::arg("token"),
375           py::arg("shape_with_layout"), py::arg("handle"));
376   ops.def("SetDimensionSize", &SetDimensionSize, py::arg("operand"),
377           py::arg("val"), py::arg("dimension"));
378   ops.def("Slice", &Slice, py::arg("operand"), py::arg("start_indices"),
379           py::arg("limit_indices"), py::arg("strides"));
380   ops.def("SliceInDim", &SliceInDim, py::arg("operand"), py::arg("start_index"),
381           py::arg("limit_index"), py::arg("stride"), py::arg("dimno"));
382   ops.def(
383       "Sort",
384       [](XlaBuilder* builder, absl::Span<const XlaOp> operands,
385          std::optional<const XlaComputation*> comparator, int64_t dimension,
386          bool is_stable) -> XlaOp {
387         return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
388           std::vector<PrimitiveType> operand_types;
389           operand_types.reserve(operands.size());
390           for (const auto& operand : operands) {
391             TF_ASSIGN_OR_RETURN(auto operand_shape, builder->GetShape(operand));
392             operand_types.push_back(operand_shape.element_type());
393           }
394 
395           if (comparator) {
396             return Sort(operands, **comparator, dimension, is_stable);
397           } else {
398             return Sort(operands,
399                         CreateScalarLtComputation(operand_types, builder),
400                         dimension, is_stable);
401           }
402         });
403       },
404       py::arg("builder"), py::arg("operands"),
405       py::arg("comparator") = std::nullopt, py::arg("dimension") = -1,
406       py::arg("is_stable") = false);
407   ops.def(
408       "SVD",
409       [](XlaOp a, int64_t max_iter,
410          float epsilon) -> std::tuple<XlaOp, XlaOp, XlaOp> {
411         auto svd = SVD(a, max_iter, epsilon);
412         return std::make_tuple(svd.u, svd.d, svd.v);
413       },
414       py::arg("a"), py::arg("max_iter") = 100, py::arg("epsilon") = 1e-6);
415   ops.def("TopK", &TopK, py::arg("input"), py::arg("k"));
416   ops.def("Transpose", &Transpose, py::arg("operand"), py::arg("permutation"));
417   ops.def("TriangularSolve", &TriangularSolve, py::arg("a"), py::arg("b"),
418           py::arg("left_side"), py::arg("lower"), py::arg("unit_diagonal"),
419           py::arg("transpose_a"));
420   ops.def("Tuple", &Tuple, py::arg("builder"), py::arg("elements"));
421   ops.def("While", &While, py::arg("condition"), py::arg("body"),
422           py::arg("init"));
423 
424   ops.def("Igamma", &Igamma, py::arg("a"), py::arg("x"));
425   ops.def("Igammac", &Igammac, py::arg("a"), py::arg("x"));
426   ops.def("IgammaGradA", &IgammaGradA, py::arg("a"), py::arg("x"));
427   ops.def("RandomGammaGrad", &RandomGammaGrad, py::arg("a"), py::arg("x"));
428   ops.def("RegularizedIncompleteBeta", &RegularizedIncompleteBeta, py::arg("a"),
429           py::arg("b"), py::arg("x"));
430   ops.def("Zeta", &Zeta, py::arg("x"), py::arg("q"));
431 
432 #define BINARY_OP(op)                                                  \
433   ops.def(                                                             \
434       #op,                                                             \
435       [](XlaOp a, XlaOp b, std::optional<std::vector<int64_t>> dims) { \
436         return dims ? op(a, b, *dims) : op(a, b);                      \
437       },                                                               \
438       py::arg("lhs"), py::arg("rhs"),                                  \
439       py::arg("broadcast_dimensions") = std::nullopt)
440   BINARY_OP(Eq);
441   BINARY_OP(Ne);
442   BINARY_OP(Ge);
443   BINARY_OP(Gt);
444   BINARY_OP(Lt);
445   BINARY_OP(Le);
446   BINARY_OP(Add);
447   BINARY_OP(Sub);
448   BINARY_OP(Mul);
449   BINARY_OP(Div);
450   BINARY_OP(Rem);
451   BINARY_OP(Max);
452   BINARY_OP(Min);
453   BINARY_OP(And);
454   BINARY_OP(Or);
455   BINARY_OP(Xor);
456   BINARY_OP(ShiftLeft);
457   BINARY_OP(ShiftRightArithmetic);
458   BINARY_OP(ShiftRightLogical);
459   BINARY_OP(Atan2);
460   BINARY_OP(Pow);
461   BINARY_OP(Complex);
462 #undef BINARY_OP
463 
464 #define UNARY_OP(op) ops.def(#op, &op)
465   UNARY_OP(Not);
466   UNARY_OP(PopulationCount);
467   UNARY_OP(Clz);
468   UNARY_OP(Abs);
469   UNARY_OP(Exp);
470   UNARY_OP(Expm1);
471   UNARY_OP(Floor);
472   UNARY_OP(Ceil);
473   UNARY_OP(Round);
474   UNARY_OP(Log);
475   UNARY_OP(Log1p);
476   UNARY_OP(Sign);
477   UNARY_OP(Cos);
478   UNARY_OP(Sin);
479   UNARY_OP(Tanh);
480   UNARY_OP(IsFinite);
481   UNARY_OP(Neg);
482   UNARY_OP(Sqrt);
483   UNARY_OP(Rsqrt);
484   UNARY_OP(Cbrt);
485   UNARY_OP(Square);
486   UNARY_OP(Reciprocal);
487   UNARY_OP(Erfc);
488   UNARY_OP(Erf);
489   UNARY_OP(ErfInv);
490   UNARY_OP(Lgamma);
491   UNARY_OP(Digamma);
492   UNARY_OP(BesselI0e);
493   UNARY_OP(BesselI1e);
494   UNARY_OP(Acos);
495   UNARY_OP(Asin);
496   UNARY_OP(Atan);
497   UNARY_OP(Tan);
498   UNARY_OP(Acosh);
499   UNARY_OP(Asinh);
500   UNARY_OP(Atanh);
501   UNARY_OP(Cosh);
502   UNARY_OP(Sinh);
503   UNARY_OP(Real);
504   UNARY_OP(Imag);
505   UNARY_OP(Conj);
506   UNARY_OP(OptimizationBarrier);
507 #undef UNARY_OP
508 }
509 
510 }  // namespace xla
511