/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/xla_compiler.h"

#include <cstdint>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/hash/hash.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "pybind11/attr.h"
#include "pybind11/cast.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl_bind.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/python/py_client.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"

namespace xla {
namespace {

namespace py = pybind11;

struct Uniquer {
  absl::Mutex mu;
  NameUniquer name_uniquer ABSL_GUARDED_BY(mu);
};

Uniquer* GetUniquer() {
  static Uniquer* uniquer = new Uniquer;
  return uniquer;
}

static std::string UniquifyName(const std::string& name) {
  Uniquer* uniquer = GetUniquer();
  absl::MutexLock lock(&uniquer->mu);
  return uniquer->name_uniquer.GetUniqueName(name);
}
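
// For illustration only: with the process-wide uniquer above, a first call to
// UniquifyName("builder") returns "builder", and a repeated request returns a
// suffixed variant such as "builder__1" (the exact suffix format depends on
// NameUniquer's separator).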

// Converts a computation to a serialized HloModuleProto.
StatusOr<py::bytes> GetComputationSerializedProto(
    const XlaComputation& computation) {
  std::string result;
  if (!tensorflow::SerializeToStringDeterministic(computation.proto(),
                                                  &result)) {
    return Unknown("Failed to serialize the HloModuleProto.");
  }
  return py::bytes(result);
}
// Converts an HloModule to a serialized HloModuleProto.
StatusOr<py::bytes> GetHloModuleSerializedProto(const HloModule& module) {
  std::string result;
  if (!tensorflow::SerializeToStringDeterministic(module.ToProto(), &result)) {
    return Unknown("Failed to serialize the HloModuleProto.");
  }
  return py::bytes(result);
}

// Converts a serialized HloModuleProto into an HloModule.
StatusOr<std::shared_ptr<HloModule>> HloModuleFromSerializedProto(
    const py::bytes& bytes) {
  HloModuleProto proto;
  proto.ParseFromString(bytes);
  TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config,
                      HloModule::CreateModuleConfigFromProto(
                          proto, GetDebugOptionsFromFlags()));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
                      HloModule::CreateFromProto(proto, module_config));
  return std::shared_ptr<HloModule>(std::move(module));
}

StatusOr<std::shared_ptr<HloModule>> GetHloModule(
    const XlaComputation& computation) {
  TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config,
                      HloModule::CreateModuleConfigFromProto(
                          computation.proto(), GetDebugOptionsFromFlags()));
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloModule> module,
      HloModule::CreateFromProto(computation.proto(), module_config));
  return std::shared_ptr<HloModule>(std::move(module));
}

// Converts a computation to textual HLO form.
StatusOr<std::string> GetComputationHloText(
    const XlaComputation& computation, bool print_large_constants = false) {
  TF_ASSIGN_OR_RETURN(std::shared_ptr<HloModule> hlo_module,
                      GetHloModule(computation));
  HloPrintOptions options = HloPrintOptions::ShortParsable();
  options.set_print_large_constants(print_large_constants);
  return hlo_module->ToString(options);
}

// Converts a computation to HLO dot graph form.
StatusOr<std::string> GetComputationHloDotGraph(
    const XlaComputation& computation) {
  TF_ASSIGN_OR_RETURN(std::shared_ptr<HloModule> hlo_module,
                      GetHloModule(computation));
  return RenderGraph(*hlo_module->entry_computation(), /*label=*/"",
                     hlo_module->config().debug_options(),
                     RenderedGraphFormat::kDot);
}

// Hashes the HLO module.
StatusOr<uint64_t> HashComputation(const XlaComputation& computation) {
  TF_ASSIGN_OR_RETURN(std::shared_ptr<HloModule> hlo_module,
                      GetHloModule(computation));
  return absl::HashOf(*hlo_module);
}

// Safe version of ShapeUtil::MakeShapeWithLayout that fails gracefully on
// invalid input.
StatusOr<Shape> MakeShapeWithLayout(
    PrimitiveType element_type, absl::Span<const int64_t> dims,
    std::optional<absl::Span<const int64_t>> minor_to_major,
    std::optional<const std::vector<bool>> dynamic_dimensions) {
  Shape shape;
  if (dynamic_dimensions) {
    TF_ASSIGN_OR_RETURN(
        shape, ShapeUtil::MakeValidatedShape(element_type, dims,
                                             dynamic_dimensions.value()));
  } else {
    TF_ASSIGN_OR_RETURN(shape,
                        ShapeUtil::MakeValidatedShape(element_type, dims));
  }
  if (minor_to_major) {
    *shape.mutable_layout() = LayoutUtil::MakeLayout(*minor_to_major);
    TF_RETURN_IF_ERROR(
        LayoutUtil::ValidateLayoutForShape(shape.layout(), shape));
  } else {
    shape.clear_layout();
  }
  return shape;
}
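
// Illustrative behavior, with hypothetical values: element_type=F32,
// dims={2, 3}, minor_to_major={0, 1} yields a validated f32[2,3] shape with a
// column-major layout, whereas a bad permutation such as {0, 3} surfaces as a
// Status error instead of a CHECK failure.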

// Registers 'fn_capsule' as a custom-call target for 'platform'.
// 'fn_capsule' must be a void* pointer encapsulated in a PyCapsule object,
// with name "xla._CUSTOM_CALL_TARGET".
// 'platform' is an XLA platform name, e.g., "Host" or "CUDA".
Status PyRegisterCustomCallTarget(const std::string& fn_name,
                                  py::capsule capsule,
                                  const std::string& platform) {
  static const char* const kName = "xla._CUSTOM_CALL_TARGET";
  // TODO(phawkins): remove old name after fixing users.
  static const char* const kOldCpuName = "xla._CPU_CUSTOM_CALL_TARGET";
  if (absl::string_view(capsule.name()) != kName &&
      absl::string_view(capsule.name()) != kOldCpuName) {
    return InvalidArgument(
        "Argument to RegisterCustomCallTargetRegistry was not a "
        "xla._CUSTOM_CALL_TARGET capsule.");
  }
  CustomCallTargetRegistry::Global()->Register(
      fn_name, static_cast<void*>(capsule), platform);
  return OkStatus();
}
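
// Hypothetical Python-side usage (the module path and capsule construction are
// illustrative; the binding is registered below as
// "register_custom_call_target"):
//   capsule = ...  # a PyCapsule wrapping a void*, named
//                  # "xla._CUSTOM_CALL_TARGET"
//   xla_client.register_custom_call_target("my_kernel", capsule, "Host")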

// Exposes a repeated proto field of T as a mutable Python list property named
// 'name' on 'cls': reading copies the field into a list, and assignment
// replaces the field's contents wholesale.
template <typename T, typename Container>
void DefRepeatedProperty(py::class_<T>& cls, const char* name,
                         Container* (T::*getter)()) {
  cls.def_property(
      name,
      [getter](T& obj) {
        Container* elems = (obj.*getter)();
        std::vector<typename Container::value_type> result;
        result.reserve(elems->size());
        std::copy(elems->begin(), elems->end(), std::back_inserter(result));
        return result;
      },
      [getter](T& obj, std::vector<typename Container::value_type> new_elems) {
        Container* elems = (obj.*getter)();
        elems->Clear();
        elems->Reserve(new_elems.size());
        for (typename Container::value_type& e : new_elems) {
          elems->Add(std::move(e));
        }
      });
}
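
// A consequence worth noting for callers (illustrative Python): the getter
// returns a copy, so "op_sharding.tile_assignment_devices.append(0)" mutates
// only the temporary list; write "op_sharding.tile_assignment_devices = [...]"
// to update the proto.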

}  // namespace

void BuildXlaCompilerSubmodule(py::module& m) {
  // Shapes
  py::class_<Layout> layout_class(m, "Layout");
  layout_class
      .def("minor_to_major",
           [](Layout layout) { return SpanToTuple(layout.minor_to_major()); })
      .def("__eq__", [](const Layout& layout,
                        const Layout& other) { return layout == other; })
      .def("__ne__", [](const Layout& layout,
                        const Layout& other) { return layout != other; })
      .def("__hash__",
           [](const Layout& layout) { return absl::HashOf(layout); })
      .def("to_string", &Layout::ToString);

  py::class_<Shape> shape_class(m, "Shape");
  shape_class
      .def(py::init([](const std::string& s) {
        return std::make_unique<Shape>(ValueOrThrow(ParseShape(s)));
      }))
      .def_static(
          "tuple_shape",
          [](std::vector<Shape> shapes) -> Shape {
            return ShapeUtil::MakeTupleShape(shapes);
          },
          "Constructs a tuple shape.")
      .def_static(
          "array_shape",
          [](PrimitiveType type, py::object dims_seq,
             std::optional<py::object> layout_seq,
             std::optional<std::vector<bool>> dynamic_dimensions)
              -> StatusOr<Shape> {
            std::vector<int64_t> dims = SequenceToVector<int64_t>(dims_seq);
            if (layout_seq) {
              std::vector<int64_t> layout =
                  SequenceToVector<int64_t>(*layout_seq);
              return MakeShapeWithLayout(type, dims, layout,
                                         dynamic_dimensions);
            } else {
              return MakeShapeWithLayout(type, dims, std::nullopt,
                                         dynamic_dimensions);
            }
          },
          "Constructs an array shape.", py::arg("type"), py::arg("dims"),
          py::arg("layout") = std::nullopt,
          py::arg("dynamic_dimensions") = std::nullopt)
      .def_static(
          "array_shape",
          [](py::dtype dtype, py::object dims_seq,
             std::optional<py::object> layout_seq,
             std::optional<std::vector<bool>> dynamic_dimensions)
              -> StatusOr<Shape> {
            PrimitiveType type = ValueOrThrow(DtypeToPrimitiveType(dtype));
            std::vector<int64_t> dims = SequenceToVector<int64_t>(dims_seq);
            if (layout_seq) {
              std::vector<int64_t> layout =
                  SequenceToVector<int64_t>(*layout_seq);
              return MakeShapeWithLayout(type, dims, layout,
                                         dynamic_dimensions);
            } else {
              return MakeShapeWithLayout(type, dims, std::nullopt,
                                         dynamic_dimensions);
            }
          },
          "Constructs an array shape.", py::arg("type"), py::arg("dims"),
          py::arg("layout") = std::nullopt,
          py::arg("dynamic_dimensions") = std::nullopt)
      .def_static("token_shape", []() { return ShapeUtil::MakeTokenShape(); })
      .def_static(
          "scalar_shape",
          [](PrimitiveType type) -> Shape {
            return ShapeUtil::MakeScalarShape(type);
          },
          "Constructs a scalar shape.", py::arg("type"))
      .def_static(
          "scalar_shape",
          [](py::dtype dtype) -> StatusOr<Shape> {
            PrimitiveType type = ValueOrThrow(DtypeToPrimitiveType(dtype));
            return ShapeUtil::MakeScalarShape(type);
          },
          "Constructs a scalar shape.", py::arg("type"))
      .def("dimensions",
           [](const Shape& shape) -> py::tuple {
             return SpanToTuple(shape.dimensions());
           })
      .def("layout",
           [](const Shape& shape) -> Layout { return shape.layout(); })
      .def("xla_element_type", &Shape::element_type)
      .def("element_type",
           [](const Shape& shape) {
             return ValueOrThrow(PrimitiveTypeToDtype(shape.element_type()));
           })
      .def("numpy_dtype",
           [](const Shape& shape) {
             if (shape.IsTuple()) {
               return py::dtype("O");
             }
             return ValueOrThrow(PrimitiveTypeToDtype(shape.element_type()));
           })
      .def("is_tuple", &Shape::IsTuple)
      .def("is_array", &Shape::IsArray)
      .def("is_token", &Shape::IsToken)
      .def("is_static", &Shape::is_static)
      .def("is_dynamic", &Shape::is_dynamic)
      .def("is_dynamic_dimension", &Shape::is_dynamic_dimension,
           py::arg("dimension"))
      .def("set_dynamic_dimension", &Shape::set_dynamic_dimension,
           py::arg("dimension"), py::arg("is_dynamic"))
      .def("rank", &Shape::rank)
      .def("to_serialized_proto",
           [](const Shape& shape) {
             ShapeProto proto = shape.ToProto();
             return py::bytes(proto.SerializeAsString());
           })
      .def("tuple_shapes",
           [](const Shape& shape) {
             return std::vector<Shape>(shape.tuple_shapes());
           })
      .def("leaf_count",
           [](const Shape& shape) { return ShapeUtil::GetLeafCount(shape); })
      .def(
          "with_major_to_minor_layout_if_absent",
          [](const Shape& shape) {
            Shape out = shape;
            ShapeUtil::ForEachMutableSubshape(
                &out, [](Shape* subshape, const ShapeIndex&) {
                  if (!subshape->has_layout()) {
                    LayoutUtil::SetToDefaultLayout(subshape);
                  }
                });
            return out;
          },
          "Returns a copy of a shape with missing layouts set to "
          "major-to-minor.")
      .def("__eq__", [](const Shape& shape,
                        const Shape& other) { return shape == other; })
      .def("__ne__", [](const Shape& shape,
                        const Shape& other) { return shape != other; })
      .def("__hash__", [](const Shape& shape) { return absl::HashOf(shape); })
      .def("__repr__", [](const Shape& shape) {
        return shape.ToString(/*print_layout=*/true);
      });
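
  // Illustrative Python usage of the bindings above (the module name
  // "xla_client" is an assumption):
  //   s = xla_client.Shape.array_shape(np.dtype("float32"), (2, 3), (0, 1))
  //   s.dimensions() == (2, 3); s.rank() == 2; s.is_array() is True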

  py::class_<ProgramShape>(m, "ProgramShape")
      .def(py::init(
          [](absl::Span<const Shape> params, Shape result) -> ProgramShape {
            ProgramShape program_shape;
            for (const Shape& param : params) {
              *program_shape.add_parameters() = param;
            }
            *program_shape.mutable_result() = result;
            return program_shape;
          }))
      .def("parameter_shapes",
           static_cast<const std::vector<Shape>& (ProgramShape::*)() const>(
               &ProgramShape::parameters))
      .def("result_shape", &ProgramShape::result)
      .def("__repr__", &ProgramShape::ToString);

  py::class_<ShapeIndex>(m, "ShapeIndex")
      .def(py::init([](const std::vector<int64_t>& v) {
        return std::make_unique<ShapeIndex>(v.begin(), v.end());
      }))
      .def("__repr__", &ShapeIndex::ToString)
      .def("__eq__", [](const ShapeIndex& shape_ind,
                        const ShapeIndex& other) { return shape_ind == other; })
      .def("__ne__", [](const ShapeIndex& shape_ind,
                        const ShapeIndex& other) { return shape_ind != other; })
      .def("__hash__",
           [](const ShapeIndex& shape_ind) { return absl::HashOf(shape_ind); });

  // Literals
  py::class_<Literal, std::shared_ptr<Literal>>(m, "Literal")
      .def("__repr__", &Literal::ToString);

  py::class_<XlaComputation>(m, "XlaComputation")
      .def(py::init([](const py::bytes& serialized_hlo_module_proto)
                        -> std::unique_ptr<XlaComputation> {
        HloModuleProto proto;
        proto.ParseFromString(std::string(serialized_hlo_module_proto));
        return std::make_unique<XlaComputation>(proto);
      }))
      .def("get_hlo_module", &GetHloModule)
      .def("program_shape", &XlaComputation::GetProgramShape)
      .def("name", &XlaComputation::name)
      .def("as_serialized_hlo_module_proto", &GetComputationSerializedProto)
      .def("as_hlo_text", &GetComputationHloText,
           py::arg("print_large_constants") = false)
      .def("as_hlo_dot_graph", &GetComputationHloDotGraph)
      .def("hash", &HashComputation)
      .def("as_hlo_module", &GetHloModule);
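
  // Illustrative round trip through the bindings above (names as bound here):
  //   c = XlaComputation(serialized_hlo_module_proto)
  //   c2 = XlaComputation(c.as_serialized_hlo_module_proto())
  //   assert c2.name() == c.name()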

  py::class_<HloPrintOptions> hlo_print_options_class(m, "HloPrintOptions");
  hlo_print_options_class.def(py::init<>())
      .def_static("short_parsable", &HloPrintOptions::ShortParsable)
      .def_static("canonical", &HloPrintOptions::Canonical)
      .def_static("fingerprint", &HloPrintOptions::Fingerprint)
      .def_property("print_large_constants",
                    &HloPrintOptions::print_large_constants,
                    &HloPrintOptions::set_print_large_constants)
      .def_property("print_metadata", &HloPrintOptions::print_metadata,
                    &HloPrintOptions::set_print_metadata)
      .def_property("print_backend_config",
                    &HloPrintOptions::print_backend_config,
                    &HloPrintOptions::set_print_backend_config)
      .def_property("print_result_shape", &HloPrintOptions::print_result_shape,
                    &HloPrintOptions::set_print_result_shape)
      .def_property("print_operand_shape",
                    &HloPrintOptions::print_operand_shape,
                    &HloPrintOptions::set_print_operand_shape)
      .def_property("print_operand_names",
                    &HloPrintOptions::print_operand_names,
                    &HloPrintOptions::set_print_operand_names)
      .def_property("print_ids", &HloPrintOptions::print_ids,
                    &HloPrintOptions::set_print_ids)
      .def_property("print_extra_attributes",
                    &HloPrintOptions::print_extra_attributes,
                    &HloPrintOptions::set_print_extra_attributes)
      .def_property("print_program_shape",
                    &HloPrintOptions::print_program_shape,
                    &HloPrintOptions::set_print_program_shape)
      .def_property("print_percent", &HloPrintOptions::print_percent,
                    &HloPrintOptions::set_print_percent)
      .def_property("print_control_dependencies",
                    &HloPrintOptions::print_control_dependencies,
                    &HloPrintOptions::set_print_control_dependencies)
      .def_property("compact_operands", &HloPrintOptions::compact_operands,
                    &HloPrintOptions::set_compact_operands)
      .def_property("include_layout_in_shapes",
                    &HloPrintOptions::include_layout_in_shapes,
                    &HloPrintOptions::set_include_layout_in_shapes)
      .def_property("canonicalize_instruction_names",
                    &HloPrintOptions::canonicalize_instruction_names,
                    &HloPrintOptions::set_canonicalize_instruction_names)
      .def_property("canonicalize_computations",
                    &HloPrintOptions::canonicalize_computations,
                    &HloPrintOptions::set_canonicalize_computations)
      .def_property("indent_amount", &HloPrintOptions::indent_amount,
                    &HloPrintOptions::set_indent_amount)
      .def_property("is_in_nested_computation",
                    &HloPrintOptions::is_in_nested_computation,
                    &HloPrintOptions::set_is_in_nested_computation);

  py::class_<HloModule, std::shared_ptr<HloModule>> hlo_module_class(
      m, "HloModule");
  hlo_module_class.def_property_readonly("name", &HloModule::name)
      .def(
          "to_string",
          static_cast<std::string (HloModule::*)(const HloPrintOptions&) const>(
              &HloModule::ToString),
          py::arg("options") = HloPrintOptions())
      .def("as_serialized_hlo_module_proto", &GetHloModuleSerializedProto)
      .def("from_serialized_hlo_module_proto", &HloModuleFromSerializedProto)
      .def_property_readonly(
          "spmd_output_sharding",
          [](const HloModule& m) -> std::optional<xla::OpSharding> {
            if (!m.has_spmd_output_sharding()) return std::nullopt;
            return m.spmd_output_sharding().ToProto();
          })
      .def_property_readonly(
          "spmd_parameters_shardings",
          [](const HloModule& m)
              -> std::optional<std::vector<xla::OpSharding>> {
            if (!m.has_spmd_parameters_shardings()) return std::nullopt;
            std::vector<xla::OpSharding> param_shardings;
            for (const auto& parameter_sharding :
                 m.spmd_parameters_shardings()) {
              param_shardings.push_back(parameter_sharding.ToProto());
            }
            return param_shardings;
          });

  m.def("hlo_module_to_dot_graph",
        [](const HloModule& hlo_module) -> StatusOr<std::string> {
          return RenderGraph(*hlo_module.entry_computation(), /*label=*/"",
                             hlo_module.config().debug_options(),
                             RenderedGraphFormat::kDot);
        });
  m.def(
      "hlo_module_cost_analysis",
      [](PyClient* client,
         const HloModule& module) -> StatusOr<HloCostAnalysis::Properties> {
        TF_ASSIGN_OR_RETURN(auto analysis,
                            client->pjrt_client()->GetHloCostAnalysis());
        TF_RETURN_IF_ERROR(module.entry_computation()->Accept(analysis.get()));
        return analysis->properties();
      });

  py::class_<XlaOp> xla_op_class(m, "XlaOp");

  py::class_<XlaBuilder>(m, "XlaBuilder")
      .def(py::init([](const std::string& name) -> std::unique_ptr<XlaBuilder> {
        return std::make_unique<XlaBuilder>(UniquifyName(name));
      }))
      // TODO(phawkins): delete capitalized names after updating callers.
      .def(
          "Build",
          [](XlaBuilder& builder, std::optional<XlaOp> root) {
            return root ? builder.Build(*root) : builder.Build();
          },
          "Builds a computation from the contents of the builder.",
          py::arg("root") = std::nullopt)
      .def("GetShape", &XlaBuilder::GetShape)
      .def(
          "build",
          [](XlaBuilder& builder, std::optional<XlaOp> root) {
            return root ? builder.Build(*root) : builder.Build();
          },
          "Builds a computation from the contents of the builder.",
          py::arg("root") = std::nullopt)
      .def("clear_op_metadata", &XlaBuilder::ClearOpMetadata)
      .def("get_shape", &XlaBuilder::GetShape)
      .def(
          "get_program_shape",
          [](const XlaBuilder& builder,
             std::optional<XlaOp> root) -> StatusOr<ProgramShape> {
            return root ? builder.GetProgramShape(*root)
                        : builder.GetProgramShape();
          },
          py::arg("root") = std::nullopt)
      .def("is_constant", &XlaBuilder::IsConstant)
      .def("set_op_metadata", &XlaBuilder::SetOpMetadata)
      .def("set_sharding", &XlaBuilder::SetSharding)
      .def("clear_sharding", &XlaBuilder::ClearSharding)
      .def("set_frontend_attributes", &XlaBuilder::SetFrontendAttributes)
      .def("clear_frontend_attributes", &XlaBuilder::ClearFrontendAttributes)
      .def("setup_alias",
           [](XlaBuilder& builder, const std::vector<int64_t>& output_index,
              int64_t param_number, const std::vector<int64_t>& param_index) {
             builder.SetUpAlias(
                 ShapeIndex(output_index.begin(), output_index.end()),
                 param_number,
                 ShapeIndex(param_index.begin(), param_index.end()));
           });
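
  // Illustrative Python usage: builder.setup_alias([0], 1, []) aliases tuple
  // element 0 of the computation's output with (all of) parameter 1, the usual
  // way to express input/output buffer aliasing from Python.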

  // Device assignments
  py::class_<DeviceAssignment>(m, "DeviceAssignment")
      .def_static("create",
                  [](py::array_t<int> array) -> StatusOr<DeviceAssignment> {
                    if (array.ndim() != 2) {
                      return InvalidArgument(
                          "Argument to DeviceAssignment constructor must be a "
                          "2D array, received an %dD array.",
                          array.ndim());
                    }
                    DeviceAssignment result(array.shape(0), array.shape(1));
                    for (int i = 0; i < array.shape(0); ++i) {
                      for (int j = 0; j < array.shape(1); ++j) {
                        result(i, j) = array.at(i, j);
                      }
                    }
                    return result;
                  })
      .def("replica_count", &DeviceAssignment::replica_count)
      .def("computation_count", &DeviceAssignment::computation_count)
      .def("__repr__", &DeviceAssignment::ToString)
      .def("serialize", [](const DeviceAssignment& da) -> StatusOr<py::bytes> {
        DeviceAssignmentProto proto;
        TF_RETURN_IF_ERROR(da.Serialize(&proto));
        std::string result;
        if (!tensorflow::SerializeToStringDeterministic(proto, &result)) {
          return Unknown("Failed to serialize the DeviceAssignmentProto.");
        }
        return py::bytes(result);
      });
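
  // Illustrative: DeviceAssignment.create(np.array([[0, 1], [2, 3]])) yields a
  // 2-replica x 2-computation assignment; any non-2D input is rejected with
  // InvalidArgument rather than crashing.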

  py::class_<CompileOptions> compile_options(m, "CompileOptions");
  compile_options
      .def(py::init([]() -> CompileOptions {
        CompileOptions options;
        DebugOptions* debug_options =
            options.executable_build_options.mutable_debug_options();
        // Sets fast-math-disabling default options expected by JAX.
        debug_options->set_xla_cpu_enable_fast_min_max(false);
        debug_options->set_xla_gpu_enable_fast_min_max(false);
        return options;
      }))
      .def_readwrite("argument_layouts", &CompileOptions::argument_layouts)
      .def_readwrite("parameter_is_tupled_arguments",
                     &CompileOptions::parameter_is_tupled_arguments)
      .def_readwrite("compile_portable_executable",
                     &CompileOptions::compile_portable_executable)
      .def_readonly("executable_build_options",
                    &CompileOptions::executable_build_options)
      // TODO(phawkins): the following fields exist for backward compatibility.
      // Remove them after JAX has been updated not to use them.
      .def_readwrite("tuple_arguments",
                     &CompileOptions::parameter_is_tupled_arguments)
      .def_property(
          "num_replicas",
          [](const CompileOptions& options) {
            return options.executable_build_options.num_replicas();
          },
          [](CompileOptions& options, int num_replicas) {
            options.executable_build_options.set_num_replicas(num_replicas);
          })
      .def_property(
          "num_partitions",
          [](const CompileOptions& options) {
            return options.executable_build_options.num_partitions();
          },
          [](CompileOptions& options, int num_partitions) {
            options.executable_build_options.set_num_partitions(num_partitions);
          })
      .def_property(
          "profile_version",
          [](const CompileOptions& options) { return options.profile_version; },
          [](CompileOptions& options, int64_t profile_version) {
            options.profile_version = profile_version;
          })
      .def_property(
          "device_assignment",
          [](const CompileOptions& options) -> std::optional<DeviceAssignment> {
            return options.executable_build_options.has_device_assignment()
                       ? std::optional<DeviceAssignment>(
                             options.executable_build_options
                                 .device_assignment())
                       : std::nullopt;
          },
          [](CompileOptions& options,
             const DeviceAssignment& device_assignment) {
            options.executable_build_options.set_device_assignment(
                device_assignment);
          });
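
  // Illustrative: opts = CompileOptions(); opts.num_replicas = 2 writes
  // through to opts.executable_build_options, and the constructor above has
  // already disabled the fast-min-max fast-math shortcuts as JAX expects.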

  // Custom-call targets.
  m.def("register_custom_call_target", &PyRegisterCustomCallTarget);

  py::class_<DebugOptions>(m, "DebugOptions")
      .def("__repr__", &DebugOptions::DebugString)
      .def_property("xla_backend_optimization_level",
                    &DebugOptions::xla_backend_optimization_level,
                    &DebugOptions::set_xla_backend_optimization_level)
      .def_property("xla_cpu_enable_fast_math",
                    &DebugOptions::xla_cpu_enable_fast_math,
                    &DebugOptions::set_xla_cpu_enable_fast_math)
      .def_property("xla_cpu_enable_xprof_traceme",
                    &DebugOptions::xla_cpu_enable_xprof_traceme,
                    &DebugOptions::set_xla_cpu_enable_xprof_traceme)
      .def_property("xla_cpu_fast_math_honor_infs",
                    &DebugOptions::xla_cpu_fast_math_honor_infs,
                    &DebugOptions::set_xla_cpu_fast_math_honor_infs)
      .def_property("xla_cpu_fast_math_honor_nans",
                    &DebugOptions::xla_cpu_fast_math_honor_nans,
                    &DebugOptions::set_xla_cpu_fast_math_honor_nans)
      .def_property("xla_cpu_fast_math_honor_division",
                    &DebugOptions::xla_cpu_fast_math_honor_division,
                    &DebugOptions::set_xla_cpu_fast_math_honor_division)
      .def_property("xla_cpu_fast_math_honor_functions",
                    &DebugOptions::xla_cpu_fast_math_honor_functions,
                    &DebugOptions::set_xla_cpu_fast_math_honor_functions)
      .def_property("xla_detailed_logging_and_dumping",
                    &DebugOptions::xla_detailed_logging_and_dumping,
                    &DebugOptions::set_xla_detailed_logging_and_dumping)
      .def_property("xla_gpu_enable_fast_min_max",
                    &DebugOptions::xla_gpu_enable_fast_min_max,
                    &DebugOptions::set_xla_gpu_enable_fast_min_max)
      .def_property("xla_gpu_cuda_data_dir",
                    &DebugOptions::xla_gpu_cuda_data_dir,
                    [](DebugOptions* self, std::string value) {
                      self->set_xla_gpu_cuda_data_dir(value);
                    })
      .def_property("xla_llvm_disable_expensive_passes",
                    &DebugOptions::xla_llvm_disable_expensive_passes,
                    &DebugOptions::set_xla_llvm_disable_expensive_passes)
      .def_property("xla_test_all_input_layouts",
                    &DebugOptions::xla_test_all_input_layouts,
                    &DebugOptions::set_xla_test_all_input_layouts);

  py::class_<ExecutableBuildOptions>(m, "ExecutableBuildOptions")
      .def(py::init<>())
      .def("__repr__", &ExecutableBuildOptions::ToString)
      .def_property(
          "result_layout",
          [](const ExecutableBuildOptions& options) -> std::optional<Shape> {
            return options.result_layout()
                       ? std::optional<Shape>(*options.result_layout())
                       : std::nullopt;
          },
          &ExecutableBuildOptions::set_result_layout)
      .def_property("num_replicas", &ExecutableBuildOptions::num_replicas,
                    &ExecutableBuildOptions::set_num_replicas)
      .def_property("num_partitions", &ExecutableBuildOptions::num_partitions,
                    &ExecutableBuildOptions::set_num_partitions)
      .def_property_readonly(
          "debug_options", &ExecutableBuildOptions::mutable_debug_options,
          py::return_value_policy::reference, py::keep_alive<1, 0>())
      .def_property(
          "device_assignment",
          [](const ExecutableBuildOptions& options)
              -> std::optional<DeviceAssignment> {
            return options.has_device_assignment()
                       ? std::optional<DeviceAssignment>(
                             options.device_assignment())
                       : std::nullopt;
          },
          &ExecutableBuildOptions::set_device_assignment)
      .def_property("use_spmd_partitioning",
                    &ExecutableBuildOptions::use_spmd_partitioning,
                    &ExecutableBuildOptions::set_use_spmd_partitioning)
      .def_property("use_auto_spmd_partitioning",
                    &ExecutableBuildOptions::use_auto_spmd_partitioning,
                    &ExecutableBuildOptions::set_use_auto_spmd_partitioning)
      .def_property(
          "auto_spmd_partitioning_mesh_shape",
          &ExecutableBuildOptions::auto_spmd_partitioning_mesh_shape,
          &ExecutableBuildOptions::set_auto_spmd_partitioning_mesh_shape)
      .def_property(
          "auto_spmd_partitioning_mesh_ids",
          &ExecutableBuildOptions::auto_spmd_partitioning_mesh_ids,
          &ExecutableBuildOptions::set_auto_spmd_partitioning_mesh_ids)
      .def_property(
          "allow_spmd_sharding_propagation_to_output",
          &ExecutableBuildOptions::allow_spmd_sharding_propagation_to_output,
          &ExecutableBuildOptions::
              set_allow_spmd_sharding_propagation_to_output);

  py::enum_<OpSharding::Type> op_sharding_type(m, "OpSharding_Type");
  op_sharding_type.value("REPLICATED", OpSharding::REPLICATED)
      .value("MAXIMAL", OpSharding::MAXIMAL)
      .value("MANUAL", OpSharding::MANUAL)
      .value("TUPLE", OpSharding::TUPLE)
      .value("OTHER", OpSharding::OTHER);

  py::class_<OpSharding> op_sharding(m, "OpSharding");
  op_sharding
      .def_property_readonly_static(
          "Type",
          [op_sharding_type](const py::object&) { return op_sharding_type; })
      .def(py::init<>())
      .def_property("type", &xla::OpSharding::type, &xla::OpSharding::set_type)
      .def_property("replicate_on_last_tile_dim",
                    &xla::OpSharding::replicate_on_last_tile_dim,
                    &xla::OpSharding::set_replicate_on_last_tile_dim)
      .def("__repr__", &xla::OpSharding::DebugString)
      .def("SerializeToString",
           [](const OpSharding& sharding) {
             return py::bytes(sharding.SerializeAsString());
           })
      .def("clone",
           [](const OpSharding& sharding) { return OpSharding(sharding); });
  DefRepeatedProperty(op_sharding, "tile_assignment_dimensions",
                      &xla::OpSharding::mutable_tile_assignment_dimensions);
  DefRepeatedProperty(op_sharding, "tile_assignment_devices",
                      &xla::OpSharding::mutable_tile_assignment_devices);
  DefRepeatedProperty(op_sharding, "tuple_shardings",
                      &xla::OpSharding::mutable_tuple_shardings);
  DefRepeatedProperty(op_sharding, "last_tile_dims",
                      &xla::OpSharding::mutable_last_tile_dims);
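
  // Illustrative Python construction of a 2x2 tiled sharding over four
  // devices, using the properties defined just above:
  //   s = OpSharding(); s.type = OpSharding.Type.OTHER
  //   s.tile_assignment_dimensions = [2, 2]
  //   s.tile_assignment_devices = [0, 1, 2, 3]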

  py::class_<HloSharding> hlo_sharding(m, "HloSharding");
  hlo_sharding.def_static("from_proto", &xla::HloSharding::FromProto)
      .def("__eq__", [](const xla::HloSharding& a,
                        const xla::HloSharding& b) { return a == b; })
      .def("__hash__",
           [](const xla::HloSharding& self) { return absl::HashOf(self); })
      .def("is_replicated", &xla::HloSharding::IsReplicated)
      .def("to_proto", &xla::HloSharding::ToProto);

  py::class_<FrontendAttributes> frontend_attributes(m, "FrontendAttributes");
  frontend_attributes.def(py::init<>())
      .def("__setitem__",
           [](FrontendAttributes* attr, std::string key, std::string value) {
             (*attr->mutable_map())[key] = value;
           });
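
  // Illustrative: attrs = FrontendAttributes(); attrs["some_key"] = "v"
  // inserts into the underlying map ("some_key" is a made-up attribute name).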

  py::enum_<PrecisionConfig::Precision>(m, "PrecisionConfig_Precision")
      .value("DEFAULT", PrecisionConfig::DEFAULT)
      .value("HIGH", PrecisionConfig::HIGH)
      .value("HIGHEST", PrecisionConfig::HIGHEST);

  py::enum_<ChannelHandle::ChannelType>(m, "ChannelHandle_ChannelType")
      .value("CHANNEL_TYPE_INVALID", ChannelHandle::CHANNEL_TYPE_INVALID)
      .value("DEVICE_TO_DEVICE", ChannelHandle::DEVICE_TO_DEVICE)
      .value("DEVICE_TO_HOST", ChannelHandle::DEVICE_TO_HOST)
      .value("HOST_TO_DEVICE", ChannelHandle::HOST_TO_DEVICE);

  py::class_<ChannelHandle>(m, "ChannelHandle")
      .def_property("type", &ChannelHandle::type,
                    [](ChannelHandle* h, ChannelHandle::ChannelType type) {
                      h->set_type(type);
                    })
      .def_property(
          "handle", &ChannelHandle::handle,
          [](ChannelHandle* h, int64_t handle) { h->set_handle(handle); })
      .def("__repr__", [](ChannelHandle* h) { return h->DebugString(); });

  py::enum_<FftType>(m, "FftType")
      .value("FFT", FftType::FFT)
      .value("IFFT", FftType::IFFT)
      .value("RFFT", FftType::RFFT)
      .value("IRFFT", FftType::IRFFT);
}  // NOLINT(readability/fn_size)
}  // namespace xla