/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/xla_compiler.h"

#include <cstdint>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/hash/hash.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "pybind11/attr.h"
#include "pybind11/cast.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl_bind.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/python/py_client.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"

namespace xla {
namespace {

namespace py = pybind11;

struct Uniquer {
  absl::Mutex mu;
  NameUniquer name_uniquer ABSL_GUARDED_BY(mu);
};

Uniquer* GetUniquer() {
  static Uniquer* uniquer = new Uniquer;
  return uniquer;
}

static std::string UniquifyName(const std::string& name) {
  Uniquer* uniquer = GetUniquer();
  absl::MutexLock lock(&uniquer->mu);
  return uniquer->name_uniquer.GetUniqueName(name);
}

// Converts a computation to a serialized HloModuleProto.
StatusOr<py::bytes> GetComputationSerializedProto(
    const XlaComputation& computation) {
  std::string result;
  if (!tensorflow::SerializeToStringDeterministic(computation.proto(),
                                                  &result)) {
    return Unknown("Failed to serialize the HloModuleProto.");
  }
  return py::bytes(result);
}

// Converts an HLO module to a serialized HloModuleProto.
StatusOr<py::bytes> GetHloModuleSerializedProto(const HloModule& module) {
  std::string result;
  if (!tensorflow::SerializeToStringDeterministic(module.ToProto(), &result)) {
    return Unknown("Failed to serialize the HloModuleProto.");
  }
  return py::bytes(result);
}

// Converts a serialized HloModuleProto into an HloModule.
StatusOr<std::shared_ptr<HloModule>> HloModuleFromSerializedProto(
    const py::bytes& bytes) {
  HloModuleProto proto;
  proto.ParseFromString(bytes);
  TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config,
                      HloModule::CreateModuleConfigFromProto(
                          proto, GetDebugOptionsFromFlags()));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
                      HloModule::CreateFromProto(proto, module_config));
  return std::shared_ptr<HloModule>(std::move(module));
}

StatusOr<std::shared_ptr<HloModule>> GetHloModule(
    const XlaComputation& computation) {
  TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config,
                      HloModule::CreateModuleConfigFromProto(
                          computation.proto(), GetDebugOptionsFromFlags()));
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloModule> module,
      HloModule::CreateFromProto(computation.proto(), module_config));
  return std::shared_ptr<HloModule>(std::move(module));
}

// Converts a computation to textual HLO form.
StatusOr<std::string> GetComputationHloText(
    const XlaComputation& computation, bool print_large_constants = false) {
  TF_ASSIGN_OR_RETURN(std::shared_ptr<HloModule> hlo_module,
                      GetHloModule(computation));
  HloPrintOptions options;
  options = HloPrintOptions::ShortParsable();
  options.set_print_large_constants(print_large_constants);
  return hlo_module->ToString(options);
}

// Converts a computation to HLO dot graph form.
StatusOr<std::string> GetComputationHloDotGraph(
    const XlaComputation& computation) {
  TF_ASSIGN_OR_RETURN(std::shared_ptr<HloModule> hlo_module,
                      GetHloModule(computation));
  return RenderGraph(*hlo_module->entry_computation(), /*label=*/"",
                     hlo_module->config().debug_options(),
                     RenderedGraphFormat::kDot);
}

// Hashes the HLO module.
StatusOr<uint64_t> HashComputation(const XlaComputation& computation) {
  TF_ASSIGN_OR_RETURN(std::shared_ptr<HloModule> hlo_module,
                      GetHloModule(computation));
  return absl::HashOf(*hlo_module);
}
// Safe version of ShapeUtil::MakeShapeWithLayout that fails gracefully on
// invalid input.
StatusOr<Shape> MakeShapeWithLayout(
    PrimitiveType element_type, absl::Span<const int64_t> dims,
    std::optional<absl::Span<const int64_t>> minor_to_major,
    std::optional<const std::vector<bool>> dynamic_dimensions) {
  Shape shape;
  if (dynamic_dimensions) {
    TF_ASSIGN_OR_RETURN(
        shape, ShapeUtil::MakeValidatedShape(element_type, dims,
                                             dynamic_dimensions.value()));
  } else {
    TF_ASSIGN_OR_RETURN(shape,
                        ShapeUtil::MakeValidatedShape(element_type, dims));
  }
  if (minor_to_major) {
    *shape.mutable_layout() = LayoutUtil::MakeLayout(*minor_to_major);
    TF_RETURN_IF_ERROR(
        LayoutUtil::ValidateLayoutForShape(shape.layout(), shape));
  } else {
    shape.clear_layout();
  }
  return shape;
}
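
// Usage sketch: assuming a caller inside this namespace, the helper above can
// build an f32[2,3] array shape with an explicit {1,0} layout along these
// lines, returning an error Status rather than crashing on invalid input.
//
//   std::vector<int64_t> dims = {2, 3};
//   std::vector<int64_t> minor_to_major = {1, 0};
//   StatusOr<Shape> shape =
//       MakeShapeWithLayout(F32, dims, minor_to_major, std::nullopt);
//   // shape->ToString(/*print_layout=*/true) should print "f32[2,3]{1,0}".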

// Registers 'fn_capsule' as a custom call target for 'platform'.
// 'fn_capsule' must be a void* pointer encapsulated in a PyCapsule object,
// with name "xla._CUSTOM_CALL_TARGET".
// 'platform' is an XLA platform name, e.g., "Host" or "CUDA".
Status PyRegisterCustomCallTarget(const std::string& fn_name,
                                  py::capsule capsule,
                                  const std::string& platform) {
  static const char* const kName = "xla._CUSTOM_CALL_TARGET";
  // TODO(phawkins): remove old name after fixing users.
  static const char* const kOldCpuName = "xla._CPU_CUSTOM_CALL_TARGET";
  if (absl::string_view(capsule.name()) != kName &&
      absl::string_view(capsule.name()) != kOldCpuName) {
    return InvalidArgument(
        "Argument to RegisterCustomCallTargetRegistry was not a "
        "xla._CUSTOM_CALL_TARGET capsule.");
  }
  CustomCallTargetRegistry::Global()->Register(
      fn_name, static_cast<void*>(capsule), platform);
  return OkStatus();
}
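
// Usage sketch from the Python side (the helper and module names here are
// assumptions for illustration; the binding itself is exposed below as
// `register_custom_call_target`): a client builds a PyCapsule named
// "xla._CUSTOM_CALL_TARGET" around a function pointer and registers it for a
// platform, e.g.
//
//   capsule = my_extension.my_kernel_capsule()  # hypothetical capsule factory
//   xla_client.register_custom_call_target("my_kernel", capsule, "Host")
//
// where `xla_client` is whatever Python module these bindings are imported as.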

template <typename T, typename Container>
void DefRepeatedProperty(py::class_<T>& cls, const char* name,
                         Container* (T::*getter)()) {
  cls.def_property(
      name,
      [getter](T& obj) {
        Container* elems = (obj.*getter)();
        std::vector<typename Container::value_type> result;
        result.reserve(elems->size());
        std::copy(elems->begin(), elems->end(), std::back_inserter(result));
        return result;
      },
      [getter](T& obj, std::vector<typename Container::value_type> new_elems) {
        Container* elems = (obj.*getter)();
        elems->Clear();
        elems->Reserve(new_elems.size());
        for (typename Container::value_type& e : new_elems) {
          elems->Add(std::move(e));
        }
      });
}
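
// DefRepeatedProperty exposes a repeated proto field as a list-valued Python
// property: the getter copies the field into a std::vector (hence a Python
// list), and the setter clears the field and re-adds the new elements. A
// sketch of the resulting Python behavior, assuming the OpSharding bindings
// defined further below:
//
//   sharding = OpSharding()
//   sharding.tile_assignment_devices = [0, 1, 2, 3]   # setter replaces field
//   assert sharding.tile_assignment_devices == [0, 1, 2, 3]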

}  // namespace

void BuildXlaCompilerSubmodule(py::module& m) {
  // Shapes
  py::class_<Layout> layout_class(m, "Layout");
  layout_class
      .def("minor_to_major",
           [](Layout layout) { return SpanToTuple(layout.minor_to_major()); })
      .def("__eq__", [](const Layout& layout,
                        const Layout& other) { return layout == other; })
      .def("__ne__", [](const Layout& layout,
                        const Layout& other) { return layout != other; })
      .def("__hash__",
           [](const Layout& layout) { return absl::HashOf(layout); })
      .def("to_string", &Layout::ToString);

  py::class_<Shape> shape_class(m, "Shape");
  shape_class
      .def(py::init([](const std::string& s) {
        return std::make_unique<Shape>(ValueOrThrow(ParseShape(s)));
      }))
      .def_static(
          "tuple_shape",
          [](std::vector<Shape> shapes) -> Shape {
            return ShapeUtil::MakeTupleShape(shapes);
          },
          "Constructs a tuple shape.")
      .def_static(
          "array_shape",
          [](PrimitiveType type, py::object dims_seq,
             std::optional<py::object> layout_seq,
             std::optional<std::vector<bool>> dynamic_dimensions)
              -> StatusOr<Shape> {
            std::vector<int64_t> dims = SequenceToVector<int64_t>(dims_seq);
            if (layout_seq) {
              std::vector<int64_t> layout =
                  SequenceToVector<int64_t>(*layout_seq);
              return MakeShapeWithLayout(type, dims, layout,
                                         dynamic_dimensions);
            } else {
              return MakeShapeWithLayout(type, dims, std::nullopt,
                                         dynamic_dimensions);
            }
          },
          "Constructs an array shape.", py::arg("type"), py::arg("dims"),
          py::arg("layout") = std::nullopt,
          py::arg("dynamic_dimensions") = std::nullopt)
      .def_static(
          "array_shape",
          [](py::dtype dtype, py::object dims_seq,
             std::optional<py::object> layout_seq,
             std::optional<std::vector<bool>> dynamic_dimensions)
              -> StatusOr<Shape> {
            PrimitiveType type = ValueOrThrow(DtypeToPrimitiveType(dtype));
            std::vector<int64_t> dims = SequenceToVector<int64_t>(dims_seq);
            if (layout_seq) {
              std::vector<int64_t> layout =
                  SequenceToVector<int64_t>(*layout_seq);
              return MakeShapeWithLayout(type, dims, layout,
                                         dynamic_dimensions);
            } else {
              return MakeShapeWithLayout(type, dims, std::nullopt,
                                         dynamic_dimensions);
            }
          },
          "Constructs an array shape.", py::arg("type"), py::arg("dims"),
          py::arg("layout") = std::nullopt,
          py::arg("dynamic_dimensions") = std::nullopt)
      .def_static("token_shape", []() { return ShapeUtil::MakeTokenShape(); })
      .def_static(
          "scalar_shape",
          [](PrimitiveType type) -> Shape {
            return ShapeUtil::MakeScalarShape(type);
          },
          "Constructs a scalar shape.", py::arg("type"))
      .def_static(
          "scalar_shape",
          [](py::dtype dtype) -> StatusOr<Shape> {
            PrimitiveType type = ValueOrThrow(DtypeToPrimitiveType(dtype));
            return ShapeUtil::MakeScalarShape(type);
          },
          "Constructs a scalar shape.", py::arg("type"))
      .def("dimensions",
           [](const Shape& shape) -> py::tuple {
             return SpanToTuple(shape.dimensions());
           })
      .def("layout",
           [](const Shape& shape) -> Layout { return shape.layout(); })
      .def("xla_element_type", &Shape::element_type)
      .def("element_type",
           [](const Shape& shape) {
             return ValueOrThrow(PrimitiveTypeToDtype(shape.element_type()));
           })
      .def("numpy_dtype",
           [](const Shape& shape) {
             if (shape.IsTuple()) {
               return py::dtype("O");
             }
             return ValueOrThrow(PrimitiveTypeToDtype(shape.element_type()));
           })
      .def("is_tuple", &Shape::IsTuple)
      .def("is_array", &Shape::IsArray)
      .def("is_token", &Shape::IsToken)
      .def("is_static", &Shape::is_static)
      .def("is_dynamic", &Shape::is_dynamic)
      .def("is_dynamic_dimension", &Shape::is_dynamic_dimension,
           py::arg("dimension"))
      .def("set_dynamic_dimension", &Shape::set_dynamic_dimension,
           py::arg("dimension"), py::arg("is_dynamic"))
      .def("rank", &Shape::rank)
      .def("to_serialized_proto",
           [](const Shape& shape) {
             ShapeProto proto = shape.ToProto();
             return py::bytes(proto.SerializeAsString());
           })
      .def("tuple_shapes",
           [](const Shape& shape) {
             return std::vector<Shape>(shape.tuple_shapes());
           })
      .def("leaf_count",
           [](const Shape& shape) { return ShapeUtil::GetLeafCount(shape); })
      .def(
          "with_major_to_minor_layout_if_absent",
          [](const Shape& shape) {
            Shape out = shape;
            ShapeUtil::ForEachMutableSubshape(
                &out, [](Shape* subshape, const ShapeIndex&) {
                  if (!subshape->has_layout()) {
                    LayoutUtil::SetToDefaultLayout(subshape);
                  }
                });
            return out;
          },
          "Returns a copy of a shape with missing layouts set to "
          "major-to-minor.")
      .def("__eq__", [](const Shape& shape,
                        const Shape& other) { return shape == other; })
      .def("__ne__", [](const Shape& shape,
                        const Shape& other) { return shape != other; })
      .def("__hash__", [](const Shape& shape) { return absl::HashOf(shape); })
      .def("__repr__", [](const Shape& shape) {
        return shape.ToString(/*print_layout=*/true);
      });
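
  // Usage sketch for the Shape bindings above, from Python (the module alias
  // `xla_client` and the numpy import `np` are assumptions for illustration):
  //
  //   s = xla_client.Shape.array_shape(np.dtype("float32"), (2, 3), (1, 0))
  //   assert s.rank() == 2 and s.is_array() and not s.is_tuple()
  //   t = xla_client.Shape.tuple_shape([s, xla_client.Shape.token_shape()])
  //   u = s.with_major_to_minor_layout_if_absent()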

  py::class_<ProgramShape>(m, "ProgramShape")
      .def(py::init(
          [](absl::Span<const Shape> params, Shape result) -> ProgramShape {
            ProgramShape program_shape;
            for (const Shape& param : params) {
              *program_shape.add_parameters() = param;
            }
            *program_shape.mutable_result() = result;
            return program_shape;
          }))
      .def("parameter_shapes",
           static_cast<const std::vector<Shape>& (ProgramShape::*)() const>(
               &ProgramShape::parameters))
      .def("result_shape", &ProgramShape::result)
      .def("__repr__", &ProgramShape::ToString);

  py::class_<ShapeIndex>(m, "ShapeIndex")
      .def(py::init([](const std::vector<int64_t>& v) {
        return std::make_unique<ShapeIndex>(v.begin(), v.end());
      }))
      .def("__repr__", &ShapeIndex::ToString)
      .def("__eq__", [](const ShapeIndex& shape_ind,
                        const ShapeIndex& other) { return shape_ind == other; })
      .def("__ne__", [](const ShapeIndex& shape_ind,
                        const ShapeIndex& other) { return shape_ind != other; })
      .def("__hash__",
           [](const ShapeIndex& shape_ind) { return absl::HashOf(shape_ind); });

  // Literals
  py::class_<Literal, std::shared_ptr<Literal>>(m, "Literal")
      .def("__repr__", &Literal::ToString);

  py::class_<XlaComputation>(m, "XlaComputation")
      .def(py::init([](const py::bytes& serialized_hlo_module_proto)
                        -> std::unique_ptr<XlaComputation> {
        HloModuleProto proto;
        proto.ParseFromString(std::string(serialized_hlo_module_proto));
        return std::make_unique<XlaComputation>(proto);
      }))
      .def("get_hlo_module", &GetHloModule)
      .def("program_shape", &XlaComputation::GetProgramShape)
      .def("name", &XlaComputation::name)
      .def("as_serialized_hlo_module_proto", &GetComputationSerializedProto)
      .def("as_hlo_text", &GetComputationHloText,
           py::arg("print_large_constants") = false)
      .def("as_hlo_dot_graph", &GetComputationHloDotGraph)
      .def("hash", &HashComputation)
      .def("as_hlo_module", &GetHloModule);

  py::class_<HloPrintOptions> hlo_print_options_class(m, "HloPrintOptions");
  hlo_print_options_class.def(py::init<>())
      .def_static("short_parsable", &HloPrintOptions::ShortParsable)
      .def_static("canonical", &HloPrintOptions::Canonical)
      .def_static("fingerprint", &HloPrintOptions::Fingerprint)
      .def_property("print_large_constants",
                    &HloPrintOptions::print_large_constants,
                    &HloPrintOptions::set_print_large_constants)
      .def_property("print_metadata", &HloPrintOptions::print_metadata,
                    &HloPrintOptions::set_print_metadata)
      .def_property("print_backend_config",
                    &HloPrintOptions::print_backend_config,
                    &HloPrintOptions::set_print_backend_config)
      .def_property("print_result_shape", &HloPrintOptions::print_result_shape,
                    &HloPrintOptions::set_print_result_shape)
      .def_property("print_operand_shape",
                    &HloPrintOptions::print_operand_shape,
                    &HloPrintOptions::set_print_operand_shape)
      .def_property("print_operand_names",
                    &HloPrintOptions::print_operand_names,
                    &HloPrintOptions::set_print_operand_names)
      .def_property("print_ids", &HloPrintOptions::print_ids,
                    &HloPrintOptions::set_print_ids)
      .def_property("print_extra_attributes",
                    &HloPrintOptions::print_extra_attributes,
                    &HloPrintOptions::set_print_extra_attributes)
      .def_property("print_program_shape",
                    &HloPrintOptions::print_program_shape,
                    &HloPrintOptions::set_print_program_shape)
      .def_property("print_percent", &HloPrintOptions::print_percent,
                    &HloPrintOptions::set_print_percent)
      .def_property("print_control_dependencies",
                    &HloPrintOptions::print_control_dependencies,
                    &HloPrintOptions::set_print_control_dependencies)
      .def_property("compact_operands", &HloPrintOptions::compact_operands,
                    &HloPrintOptions::set_compact_operands)
      .def_property("include_layout_in_shapes",
                    &HloPrintOptions::include_layout_in_shapes,
                    &HloPrintOptions::set_include_layout_in_shapes)
      .def_property("canonicalize_instruction_names",
                    &HloPrintOptions::canonicalize_instruction_names,
                    &HloPrintOptions::set_canonicalize_instruction_names)
      .def_property("canonicalize_computations",
                    &HloPrintOptions::canonicalize_computations,
                    &HloPrintOptions::set_canonicalize_computations)
      .def_property("indent_amount", &HloPrintOptions::indent_amount,
                    &HloPrintOptions::set_indent_amount)
      .def_property("is_in_nested_computation",
                    &HloPrintOptions::is_in_nested_computation,
                    &HloPrintOptions::set_is_in_nested_computation);

  py::class_<HloModule, std::shared_ptr<HloModule>> hlo_module_class(
      m, "HloModule");
  hlo_module_class.def_property_readonly("name", &HloModule::name)
      .def(
          "to_string",
          static_cast<std::string (HloModule::*)(const HloPrintOptions&) const>(
              &HloModule::ToString),
          py::arg("options") = HloPrintOptions())
      .def("as_serialized_hlo_module_proto", &GetHloModuleSerializedProto)
      .def("from_serialized_hlo_module_proto", &HloModuleFromSerializedProto)
      .def_property_readonly(
          "spmd_output_sharding",
          [](const HloModule& m) -> std::optional<xla::OpSharding> {
            if (!m.has_spmd_output_sharding()) return std::nullopt;
            return m.spmd_output_sharding().ToProto();
          })
      .def_property_readonly(
          "spmd_parameters_shardings",
          [](const HloModule& m)
              -> std::optional<std::vector<xla::OpSharding>> {
            if (!m.has_spmd_parameters_shardings()) return std::nullopt;
            std::vector<xla::OpSharding> param_shardings;
            for (const auto& parameter_sharding :
                 m.spmd_parameters_shardings()) {
              param_shardings.push_back(parameter_sharding.ToProto());
            }
            return param_shardings;
          });

  m.def("hlo_module_to_dot_graph",
        [](const HloModule& hlo_module) -> StatusOr<std::string> {
          return RenderGraph(*hlo_module.entry_computation(), /*label=*/"",
                             hlo_module.config().debug_options(),
                             RenderedGraphFormat::kDot);
        });
  m.def(
      "hlo_module_cost_analysis",
      [](PyClient* client,
         const HloModule& module) -> StatusOr<HloCostAnalysis::Properties> {
        TF_ASSIGN_OR_RETURN(auto analysis,
                            client->pjrt_client()->GetHloCostAnalysis());
        TF_RETURN_IF_ERROR(module.entry_computation()->Accept(analysis.get()));
        return analysis->properties();
      });

  py::class_<XlaOp> xla_op_class(m, "XlaOp");

  py::class_<XlaBuilder>(m, "XlaBuilder")
      .def(py::init([](const std::string& name) -> std::unique_ptr<XlaBuilder> {
        return std::make_unique<XlaBuilder>(UniquifyName(name));
      }))
      // TODO(phawkins): delete capitalized names after updating callers.
      .def(
          "Build",
          [](XlaBuilder& builder, std::optional<XlaOp> root) {
            return root ? builder.Build(*root) : builder.Build();
          },
          "Builds a computation from the contents of the builder.",
          py::arg("root") = std::nullopt)
      .def("GetShape", &XlaBuilder::GetShape)
      .def(
          "build",
          [](XlaBuilder& builder, std::optional<XlaOp> root) {
            return root ? builder.Build(*root) : builder.Build();
          },
          "Builds a computation from the contents of the builder.",
          py::arg("root") = std::nullopt)
      .def("clear_op_metadata", &XlaBuilder::ClearOpMetadata)
      .def("get_shape", &XlaBuilder::GetShape)
      .def(
          "get_program_shape",
          [](const XlaBuilder& builder,
             std::optional<XlaOp> root) -> StatusOr<ProgramShape> {
            return root ? builder.GetProgramShape(*root)
                        : builder.GetProgramShape();
          },
          py::arg("root") = std::nullopt)
      .def("is_constant", &XlaBuilder::IsConstant)
      .def("set_op_metadata", &XlaBuilder::SetOpMetadata)
      .def("set_sharding", &XlaBuilder::SetSharding)
      .def("clear_sharding", &XlaBuilder::ClearSharding)
      .def("set_frontend_attributes", &XlaBuilder::SetFrontendAttributes)
      .def("clear_frontend_attributes", &XlaBuilder::ClearFrontendAttributes)
      .def("setup_alias",
           [](XlaBuilder& builder, const std::vector<int64_t>& output_index,
              int64_t param_number, const std::vector<int64_t>& param_index) {
             builder.SetUpAlias(
                 ShapeIndex(output_index.begin(), output_index.end()),
                 param_number,
                 ShapeIndex(param_index.begin(), param_index.end()));
           });
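
  // Usage sketch for the XlaBuilder bindings above, from Python (the op
  // constructors live in a separate ops submodule; the `ops`, `xla_client`,
  // and `np` names are assumptions for illustration):
  //
  //   builder = xla_client.XlaBuilder("add")
  //   x = ops.Parameter(builder, 0, xla_client.Shape.array_shape(
  //       np.dtype("float32"), (2,)))
  //   builder.setup_alias([], 0, [])  # alias the root output to parameter 0
  //   computation = builder.build(ops.Add(x, x))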

  // Device assignments
  py::class_<DeviceAssignment>(m, "DeviceAssignment")
      .def_static("create",
                  [](py::array_t<int> array) -> StatusOr<DeviceAssignment> {
                    if (array.ndim() != 2) {
                      return InvalidArgument(
                          "Argument to DeviceAssignment constructor must be a "
                          "2D array, received an %dD array.",
                          array.ndim());
                    }
                    DeviceAssignment result(array.shape(0), array.shape(1));
                    for (int i = 0; i < array.shape(0); ++i) {
                      for (int j = 0; j < array.shape(1); ++j) {
                        result(i, j) = array.at(i, j);
                      }
                    }
                    return result;
                  })
      .def("replica_count", &DeviceAssignment::replica_count)
      .def("computation_count", &DeviceAssignment::computation_count)
      .def("__repr__", &DeviceAssignment::ToString)
      .def("serialize", [](const DeviceAssignment& da) -> StatusOr<py::bytes> {
        DeviceAssignmentProto proto;
        TF_RETURN_IF_ERROR(da.Serialize(&proto));
        std::string result;
        if (!tensorflow::SerializeToStringDeterministic(proto, &result)) {
          return Unknown("Failed to serialize the DeviceAssignmentProto.");
        }
        return py::bytes(result);
      });
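
  // Usage sketch for the DeviceAssignment bindings above, from Python (the
  // `xla_client` and `np` aliases are assumptions): rows are replicas and
  // columns are computations/partitions.
  //
  //   da = xla_client.DeviceAssignment.create(np.array([[0, 1], [2, 3]]))
  //   assert da.replica_count() == 2 and da.computation_count() == 2
  //   blob = da.serialize()   # deterministic DeviceAssignmentProto bytes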

  py::class_<CompileOptions> compile_options(m, "CompileOptions");
  compile_options
      .def(py::init([]() -> CompileOptions {
        CompileOptions options;
        DebugOptions* debug_options =
            options.executable_build_options.mutable_debug_options();
        // Sets fast-math-disabling default options expected by JAX.
        debug_options->set_xla_cpu_enable_fast_min_max(false);
        debug_options->set_xla_gpu_enable_fast_min_max(false);
        return options;
      }))
      .def_readwrite("argument_layouts", &CompileOptions::argument_layouts)
      .def_readwrite("parameter_is_tupled_arguments",
                     &CompileOptions::parameter_is_tupled_arguments)
      .def_readwrite("compile_portable_executable",
                     &CompileOptions::compile_portable_executable)
      .def_readonly("executable_build_options",
                    &CompileOptions::executable_build_options)
      // TODO(phawkins): the following fields exist for backward compatibility.
      // Remove them after JAX has been updated not to use them.
      .def_readwrite("tuple_arguments",
                     &CompileOptions::parameter_is_tupled_arguments)
      .def_property(
          "num_replicas",
          [](const CompileOptions& options) {
            return options.executable_build_options.num_replicas();
          },
          [](CompileOptions& options, int num_replicas) {
            options.executable_build_options.set_num_replicas(num_replicas);
          })
      .def_property(
          "num_partitions",
          [](const CompileOptions& options) {
            return options.executable_build_options.num_partitions();
          },
          [](CompileOptions& options, int num_partitions) {
            options.executable_build_options.set_num_partitions(num_partitions);
          })
      .def_property(
          "profile_version",
          [](const CompileOptions& options) { return options.profile_version; },
          [](CompileOptions& options, int64_t profile_version) {
            options.profile_version = profile_version;
          })
      .def_property(
          "device_assignment",
          [](const CompileOptions& options) -> std::optional<DeviceAssignment> {
            return options.executable_build_options.has_device_assignment()
                       ? std::optional<DeviceAssignment>(
                             options.executable_build_options
                                 .device_assignment())
                       : std::nullopt;
          },
          [](CompileOptions& options,
             const DeviceAssignment& device_assignment) {
            options.executable_build_options.set_device_assignment(
                device_assignment);
          });
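
  // Usage sketch for the CompileOptions bindings above, from Python (the
  // `xla_client` and `np` aliases are assumptions): the default constructor
  // already disables the CPU/GPU fast-min-max debug options expected by JAX.
  //
  //   options = xla_client.CompileOptions()
  //   options.num_replicas = 2
  //   options.num_partitions = 1
  //   options.device_assignment = xla_client.DeviceAssignment.create(
  //       np.array([[0], [1]]))   # 2 replicas x 1 computation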

  // Custom-call targets.
  m.def("register_custom_call_target", &PyRegisterCustomCallTarget);

  py::class_<DebugOptions>(m, "DebugOptions")
      .def("__repr__", &DebugOptions::DebugString)
      .def_property("xla_backend_optimization_level",
                    &DebugOptions::xla_backend_optimization_level,
                    &DebugOptions::set_xla_backend_optimization_level)
      .def_property("xla_cpu_enable_fast_math",
                    &DebugOptions::xla_cpu_enable_fast_math,
                    &DebugOptions::set_xla_cpu_enable_fast_math)
      .def_property("xla_cpu_enable_xprof_traceme",
                    &DebugOptions::xla_cpu_enable_xprof_traceme,
                    &DebugOptions::set_xla_cpu_enable_xprof_traceme)
      .def_property("xla_cpu_fast_math_honor_infs",
                    &DebugOptions::xla_cpu_fast_math_honor_infs,
                    &DebugOptions::set_xla_cpu_fast_math_honor_infs)
      .def_property("xla_cpu_fast_math_honor_nans",
                    &DebugOptions::xla_cpu_fast_math_honor_nans,
                    &DebugOptions::set_xla_cpu_fast_math_honor_nans)
      .def_property("xla_cpu_fast_math_honor_division",
                    &DebugOptions::xla_cpu_fast_math_honor_division,
                    &DebugOptions::set_xla_cpu_fast_math_honor_division)
      .def_property("xla_cpu_fast_math_honor_functions",
                    &DebugOptions::xla_cpu_fast_math_honor_functions,
                    &DebugOptions::set_xla_cpu_fast_math_honor_functions)
      .def_property("xla_detailed_logging_and_dumping",
                    &DebugOptions::xla_detailed_logging_and_dumping,
                    &DebugOptions::set_xla_detailed_logging_and_dumping)
      .def_property("xla_gpu_enable_fast_min_max",
                    &DebugOptions::xla_gpu_enable_fast_min_max,
                    &DebugOptions::set_xla_gpu_enable_fast_min_max)
      .def_property("xla_gpu_cuda_data_dir",
                    &DebugOptions::xla_gpu_cuda_data_dir,
                    [](DebugOptions* self, std::string value) {
                      self->set_xla_gpu_cuda_data_dir(value);
                    })
      .def_property("xla_llvm_disable_expensive_passes",
                    &DebugOptions::xla_llvm_disable_expensive_passes,
                    &DebugOptions::set_xla_llvm_disable_expensive_passes)
      .def_property("xla_test_all_input_layouts",
                    &DebugOptions::xla_test_all_input_layouts,
                    &DebugOptions::set_xla_test_all_input_layouts);

  py::class_<ExecutableBuildOptions>(m, "ExecutableBuildOptions")
      .def(py::init<>())
      .def("__repr__", &ExecutableBuildOptions::ToString)
      .def_property(
          "result_layout",
          [](const ExecutableBuildOptions& options) -> std::optional<Shape> {
            return options.result_layout()
                       ? std::optional<Shape>(*options.result_layout())
                       : std::nullopt;
          },
          &ExecutableBuildOptions::set_result_layout)
      .def_property("num_replicas", &ExecutableBuildOptions::num_replicas,
                    &ExecutableBuildOptions::set_num_replicas)
      .def_property("num_partitions", &ExecutableBuildOptions::num_partitions,
                    &ExecutableBuildOptions::set_num_partitions)
      .def_property_readonly(
          "debug_options", &ExecutableBuildOptions::mutable_debug_options,
          py::return_value_policy::reference, py::keep_alive<1, 0>())
      .def_property(
          "device_assignment",
          [](const ExecutableBuildOptions& options)
              -> std::optional<DeviceAssignment> {
            return options.has_device_assignment()
                       ? std::optional<DeviceAssignment>(
                             options.device_assignment())
                       : std::nullopt;
          },
          &ExecutableBuildOptions::set_device_assignment)
      .def_property("use_spmd_partitioning",
                    &ExecutableBuildOptions::use_spmd_partitioning,
                    &ExecutableBuildOptions::set_use_spmd_partitioning)
      .def_property("use_auto_spmd_partitioning",
                    &ExecutableBuildOptions::use_auto_spmd_partitioning,
                    &ExecutableBuildOptions::set_use_auto_spmd_partitioning)
      .def_property(
          "auto_spmd_partitioning_mesh_shape",
          &ExecutableBuildOptions::auto_spmd_partitioning_mesh_shape,
          &ExecutableBuildOptions::set_auto_spmd_partitioning_mesh_shape)
      .def_property(
          "auto_spmd_partitioning_mesh_ids",
          &ExecutableBuildOptions::auto_spmd_partitioning_mesh_ids,
          &ExecutableBuildOptions::set_auto_spmd_partitioning_mesh_ids)
      .def_property(
          "allow_spmd_sharding_propagation_to_output",
          &ExecutableBuildOptions::allow_spmd_sharding_propagation_to_output,
          &ExecutableBuildOptions::
              set_allow_spmd_sharding_propagation_to_output);

  py::enum_<OpSharding::Type> op_sharding_type(m, "OpSharding_Type");
  op_sharding_type.value("REPLICATED", OpSharding::REPLICATED)
      .value("MAXIMAL", OpSharding::MAXIMAL)
      .value("MANUAL", OpSharding::MANUAL)
      .value("TUPLE", OpSharding::TUPLE)
      .value("OTHER", OpSharding::OTHER);

  py::class_<OpSharding> op_sharding(m, "OpSharding");
  op_sharding
      .def_property_readonly_static(
          "Type",
          [op_sharding_type](const py::object&) { return op_sharding_type; })
      .def(py::init<>())
      .def_property("type", &xla::OpSharding::type, &xla::OpSharding::set_type)
      .def_property("replicate_on_last_tile_dim",
                    &xla::OpSharding::replicate_on_last_tile_dim,
                    &xla::OpSharding::set_replicate_on_last_tile_dim)
      .def("__repr__", &xla::OpSharding::DebugString)
      .def("SerializeToString",
           [](const OpSharding& sharding) {
             return py::bytes(sharding.SerializeAsString());
           })
      .def("clone",
           [](const OpSharding& sharding) { return OpSharding(sharding); });
  DefRepeatedProperty(op_sharding, "tile_assignment_dimensions",
                      &xla::OpSharding::mutable_tile_assignment_dimensions);
  DefRepeatedProperty(op_sharding, "tile_assignment_devices",
                      &xla::OpSharding::mutable_tile_assignment_devices);
  DefRepeatedProperty(op_sharding, "tuple_shardings",
                      &xla::OpSharding::mutable_tuple_shardings);
  DefRepeatedProperty(op_sharding, "last_tile_dims",
                      &xla::OpSharding::mutable_last_tile_dims);

  py::class_<HloSharding> hlo_sharding(m, "HloSharding");
  hlo_sharding.def_static("from_proto", &xla::HloSharding::FromProto)
      .def("__eq__", [](const xla::HloSharding& a,
                        const xla::HloSharding& b) { return a == b; })
      .def("__hash__",
           [](const xla::HloSharding& self) { return absl::HashOf(self); })
      .def("is_replicated", &xla::HloSharding::IsReplicated)
      .def("to_proto", &xla::HloSharding::ToProto);

  py::class_<FrontendAttributes> frontend_attributes(m, "FrontendAttributes");
  frontend_attributes.def(py::init<>())
      .def("__setitem__",
           [](FrontendAttributes* attr, std::string key, std::string value) {
             (*attr->mutable_map())[key] = value;
           });

  py::enum_<PrecisionConfig::Precision>(m, "PrecisionConfig_Precision")
      .value("DEFAULT", PrecisionConfig::DEFAULT)
      .value("HIGH", PrecisionConfig::HIGH)
      .value("HIGHEST", PrecisionConfig::HIGHEST);

  py::enum_<ChannelHandle::ChannelType>(m, "ChannelHandle_ChannelType")
      .value("CHANNEL_TYPE_INVALID", ChannelHandle::CHANNEL_TYPE_INVALID)
      .value("DEVICE_TO_DEVICE", ChannelHandle::DEVICE_TO_DEVICE)
      .value("DEVICE_TO_HOST", ChannelHandle::DEVICE_TO_HOST)
      .value("HOST_TO_DEVICE", ChannelHandle::HOST_TO_DEVICE);

  py::class_<ChannelHandle>(m, "ChannelHandle")
      .def_property("type", &ChannelHandle::type,
                    [](ChannelHandle* h, ChannelHandle::ChannelType type) {
                      h->set_type(type);
                    })
      .def_property(
          "handle", &ChannelHandle::handle,
          [](ChannelHandle* h, int64_t handle) { h->set_handle(handle); })
      .def("__repr__", [](ChannelHandle* h) { return h->DebugString(); });

  py::enum_<FftType>(m, "FftType")
      .value("FFT", FftType::FFT)
      .value("IFFT", FftType::IFFT)
      .value("RFFT", FftType::RFFT)
      .value("IRFFT", FftType::IRFFT);
}  // NOLINT(readability/fn_size)
}  // namespace xla