# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An XLA client in Python."""

import atexit
import contextlib
import enum  # pylint: disable=g-bad-import-order
import gzip
import inspect
import os
from typing import List, Sequence, Tuple, Union

from . import xla_extension as _xla

import numpy as np

# Note this module does *not* depend on any Python protocol buffers. The XLA
# Python bindings are currently packaged both as part of jaxlib and as part
# of TensorFlow. If we use protocol buffers here, then importing both jaxlib
# and TensorFlow may fail with duplicate protocol buffer message definitions.

# Most functions are snake_case for consistency with other modules; some
# method names are CamelCase for consistency with XLA.
# pylint: disable=invalid-name

# Pylint has false positives for type annotations.
# pylint: disable=invalid-sequence-index

ops = _xla.ops
profiler = _xla.profiler

# An internal, arbitrarily increasing number used to help with
# backward-compatible changes.
_version = 89

# Version number for MLIR:Python components.
mlir_api_version = 32

xla_platform_names = {
    'cpu': 'Host',
    'gpu': 'CUDA',
}


def make_interpreter_client():
  return _xla.get_interpreter_client()


def make_cpu_client(*, use_tfrt: bool = True) -> ...:
  if use_tfrt:
    return _xla.get_tfrt_cpu_client(asynchronous=True)
  else:
    return _xla.get_cpu_client(asynchronous=True)

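# A hedged usage sketch (not in the original file): constructing a CPU client
# and inspecting its devices. Assumes the xla_extension bindings are built and
# importable; the Device attributes below are illustrative.
#
#   client = make_cpu_client()
#   for device in client.local_devices():
#     print(device.id, device.platform)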


def make_gpu_client(distributed_client=None, node_id=0, platform_name=None,
                    allowed_devices=None):
  """Returns a GPU client. BFC allocator is used by default."""
  allocator = os.getenv('XLA_PYTHON_CLIENT_ALLOCATOR', 'default').lower()
  memory_fraction = os.getenv('XLA_PYTHON_CLIENT_MEM_FRACTION')
  preallocate = os.getenv('XLA_PYTHON_CLIENT_PREALLOCATE')
  if allocator not in ('default', 'platform', 'bfc', 'cuda_async'):
    raise ValueError(
        'XLA_PYTHON_CLIENT_ALLOCATOR env var must be "default", "platform", '
        '"bfc", or "cuda_async", got "%s"' % allocator)
  config = _xla.GpuAllocatorConfig()
  if allocator == 'default':
    config.kind = _xla.GpuAllocatorConfig.Kind.DEFAULT
  if allocator == 'platform':
    config.kind = _xla.GpuAllocatorConfig.Kind.PLATFORM
  if allocator == 'bfc':
    config.kind = _xla.GpuAllocatorConfig.Kind.BFC
  if allocator == 'cuda_async':
    config.kind = _xla.GpuAllocatorConfig.Kind.CUDA_ASYNC
  if memory_fraction:
    config.memory_fraction = float(memory_fraction)
  config.preallocate = preallocate not in ('0', 'false', 'False')

  return _xla.get_gpu_client(
      asynchronous=True,
      allocator_config=config,
      distributed_client=distributed_client,
      node_id=node_id,
      platform_name=platform_name,
      allowed_devices=allowed_devices)

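# A hedged configuration sketch (not in the original file): the allocator
# behavior above is driven by environment variables, which must be set before
# make_gpu_client() is called. The values shown are illustrative.
#
#   os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'bfc'
#   os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.75'
#   os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
#   client = make_gpu_client()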

def make_tpu_client():
  """Returns a TPU client. Defaults to allowing 32 in-flight computations."""
  max_inflight_computations = os.getenv(
      'JAX_TPU_MAX_INFLIGHT_COMPUTATIONS', '32')
  try:
    max_inflight_computations = int(max_inflight_computations)
  except ValueError as e:
    raise ValueError(
        'JAX_TPU_MAX_INFLIGHT_COMPUTATIONS env var must be an int, '
        f'got {max_inflight_computations}') from e
  return _xla.get_tpu_client(
      max_inflight_computations=max_inflight_computations)

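# A hedged configuration sketch (not in the original file): overriding the
# in-flight computation limit before creating the client.
#
#   os.environ['JAX_TPU_MAX_INFLIGHT_COMPUTATIONS'] = '16'
#   client = make_tpu_client()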

class OpMetadata:
  """Python representation of a xla.OpMetadata protobuf."""
  __slots__ = ('op_type', 'op_name', 'source_file', 'source_line')

  def __init__(self, op_type='', op_name='', source_file='', source_line=0):
    self.op_type = op_type
    self.op_name = op_name
    self.source_file = source_file
    self.source_line = source_line


def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1):
  """Helper for use in source mapping that returns an OpMetadata object."""
  full_filename, lineno = inspect.stack()[skip_frames][1:3]
  filename = os.path.basename(full_filename)
  return OpMetadata(
      op_type=op_type,
      op_name=op_name,
      source_file=filename,
      source_line=lineno)

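# A hedged usage sketch (not in the original file): attaching source metadata
# to an op. With the default skip_frames=1, the recorded file and line refer
# to the caller of CurrentSourceInfoMetadata.
#
#   metadata = CurrentSourceInfoMetadata(op_type='Add', op_name='my_add')
#   # metadata.source_file / metadata.source_line now describe the call site.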

PrimitiveType = _xla.PrimitiveType

bfloat16 = _xla.bfloat16_dtype()

XLA_ELEMENT_TYPE_TO_DTYPE = {
    PrimitiveType.PRED: np.dtype('bool'),
    PrimitiveType.S8: np.dtype('int8'),
    PrimitiveType.S16: np.dtype('int16'),
    PrimitiveType.S32: np.dtype('int32'),
    PrimitiveType.S64: np.dtype('int64'),
    PrimitiveType.U8: np.dtype('uint8'),
    PrimitiveType.U16: np.dtype('uint16'),
    PrimitiveType.U32: np.dtype('uint32'),
    PrimitiveType.U64: np.dtype('uint64'),
    PrimitiveType.BF16: np.dtype(bfloat16),
    PrimitiveType.F16: np.dtype('float16'),
    PrimitiveType.F32: np.dtype('float32'),
    PrimitiveType.F64: np.dtype('float64'),
    PrimitiveType.C64: np.dtype('complex64'),
    PrimitiveType.C128: np.dtype('complex128'),
    PrimitiveType.TUPLE: np.dtype(np.object_),
    PrimitiveType.TOKEN: np.dtype(np.object_),
}

# Note the conversion on the key. Numpy has a known issue wherein dtype hashing
# doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus,
# when keying by dtype in this dict, we use the string form of dtypes.
DTYPE_TO_XLA_ELEMENT_TYPE = {
    str(dt): et for et, dt in XLA_ELEMENT_TYPE_TO_DTYPE.items()
}


def dtype_to_etype(dtype):
  """Convenience function for reading DTYPE_TO_XLA_ELEMENT_TYPE."""
  return DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))]

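# A hedged usage sketch (not in the original file): round-tripping between
# numpy dtypes and XLA element types via the tables above.
#
#   dtype_to_etype(np.float32)                     # -> PrimitiveType.F32
#   XLA_ELEMENT_TYPE_TO_DTYPE[PrimitiveType.S32]   # -> np.dtype('int32')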

Shape = _xla.Shape
Shape.__doc__ = """
A Shape is an object defined in C++ that duck types like the following class:

class Shape:
  '''Represents an XLA shape.

  A shape is either an array shape, having rank-many integer
  dimensions and an element type (represented by a Numpy dtype), or it
  is a tuple shape, having a shape for every tuple component:

    type shape =
        TupleShape of shape list
      | ArrayShape of { dimensions: int list; element_type: dtype }
  '''

  @staticmethod
  def tuple_shape(tuple_shapes) -> Shape:
    "Construct a tuple shape."

  @staticmethod
  def array_shape(element_type, dimensions, minor_to_major=None) -> Shape:
    "Construct an array shape."

  @staticmethod
  def from_pyval(pyval) -> Shape:
    "Returns a Shape that describes a tuple-tree of Numpy arrays."

  def __init__(self, str) -> Shape:
    "Parses a shape string."
  def __eq__(self, other: Shape) -> bool:
  def __ne__(self, other: Shape) -> bool:
  def __hash__(self):
  def __repr__(self):
  def is_tuple(self) -> bool:
  def is_array(self) -> bool:
  def tuple_shapes(self) -> [Shape]:
  def numpy_dtype(self) -> np.dtype:
    "Like element_type(), but returns dtype('O') for a tuple shape."
  def xla_element_type(self) -> PrimitiveType:
  def element_type(self) -> np.dtype:
  def dimensions(self) -> (int, int, ...):
  def rank(self) -> int:
  def with_major_to_minor_layout_if_absent(self) -> Shape:
    "Returns a copy with missing layouts set to major-to-minor."

  def to_serialized_proto(self) -> bytes:
    "Returns 'shape' as a serialized proto."
"""

ProgramShape = _xla.ProgramShape
ProgramShape.__doc__ = """
A ProgramShape is a C++ object that duck types like the following class.

class ProgramShape:
  def __init__(self, parameter_shapes, result_shape):
  def parameter_shapes(self) -> [Shape]:
  def result_shape(self) -> Shape:
  def __repr__(self):
"""

ShapeIndex = _xla.ShapeIndex
ShapeIndex.__doc__ = """
A ShapeIndex is an object defined in C++ that duck types like the following
class:

class ShapeIndex:
  '''Represents an XLA ShapeIndex.

  An index for specifying a particular nested subshape within a shape. Used in
  ShapeUtil::GetSubshape and other interfaces. ShapeIndex defines a path through
  the Shape tree where each element of ShapeIndex indexes into a tuple (or
  nested tuple) within the shape. For a non-nested tuple, an index has a single
  element.
  '''

  def __init__(self, List[int]) -> ShapeIndex:
  def __eq__(self, other: ShapeIndex) -> bool:
  def __ne__(self, other: ShapeIndex) -> bool:
  def __hash__(self):
  def __repr__(self):
"""


def shape_from_pyval(pyval):
  """Returns a Shape that describes a tuple-tree of Numpy arrays."""

  def convert(pyval):
    if isinstance(pyval, tuple):
      return Shape.tuple_shape(tuple(convert(elt) for elt in pyval))
    else:
      return Shape.array_shape(pyval.dtype, np.shape(pyval))

  return convert(pyval)

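# A hedged usage sketch (not in the original file): a tuple-tree of numpy
# values maps to a tuple shape; scalars become rank-0 array shapes.
#
#   shape_from_pyval((np.zeros((2, 3), np.float32), np.int32(1)))
#   # -> a tuple shape: (f32[2,3], s32[])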

DeviceAssignment = _xla.DeviceAssignment
DeviceAssignment.__doc__ = """
A DeviceAssignment is a C++ object with the following signature.

def create(assignment):
  '''Builds a device assignment.

   Args:
     assignment: a 2D numpy array of device ordinal integers, indexed by
       [replica][computation_in_replica].
   Returns:
     A device assignment.
  '''

def replica_count():
  '''Returns the number of replicas.'''
def computation_count():
  '''Returns the number of computations per replica.'''
"""

Device = _xla.Device
CompileOptions = _xla.CompileOptions

HostBufferSemantics = _xla.HostBufferSemantics

# An Executable is a C++ class that duck types with the following API:
# class Executable:
#   def local_devices(self) -> [Device]:
#   def execute(self, arguments: [Buffer]) -> [Buffer]:
#     """Execute on one replica with Buffer arguments and return values."""
#
#   def size_of_generated_code_in_bytes(self) -> int:
#     """Return generated binary size, or -1 if not known."""
#
#   def execute_sharded_on_local_devices(self, arguments: [[Buffer]])
#       -> [[Buffer]]:
#     """Execute on many replicas with Buffer arguments and return values.
#
#     Args:
#       arguments: A sequence of sequences of Buffers. The i'th element of each
#         sequence comprises the arguments for execution on the i'th local
#         device.
#
#     Returns:
#       A list of the computation's outputs as a list of Buffers for each
#       device.
#     """
#
# There are different implementations of Executable for different backends.


def execute_with_python_values(executable, arguments, backend):
  """Execute on one replica with Python values as arguments and output."""

  def put(arg):
    return backend.buffer_from_pyval(arg, device=executable.local_devices()[0])

  arguments = [put(arg) for arg in arguments]
  outputs = executable.execute(arguments)
  return [x.to_py() for x in outputs]


def execute_with_python_values_replicated(executable, arguments, backend):
  """Execute on many replicas with Python values as arguments and output.

  Args:
    executable: the program to run.
    arguments: a list of lists of Python values indexed by `[replica][arg_num]`
      to pass as inputs.
    backend: the backend we are targeting.

  Returns:
    A list of python values, one per replica.
  """
  devices = executable.local_devices()

  # pylint: disable=g-complex-comprehension
  def copy_to_devices(pyvals):
    return [backend.buffer_from_pyval(v, d) for v, d in zip(pyvals, devices)]

  inputs = [copy_to_devices(pyvals) for pyvals in zip(*arguments)]
  outputs = executable.execute_sharded_on_local_devices(inputs)
  return [[x.to_py() for x in xs] for xs in zip(*outputs)]

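# A hedged usage sketch (not in the original file): for two replicas that each
# take two arguments, `arguments` is indexed [replica][arg_num]. The names
# a0, a1, b0, b1 below are placeholders.
#
#   results = execute_with_python_values_replicated(
#       executable, [[a0, a1], [b0, b1]], backend)
#   # results[0] holds replica 0's outputs; results[1] holds replica 1's.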

class PaddingType(enum.Enum):
  VALID = 1
  SAME = 2


def window_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims,
                                      window_strides):
  """Maps PaddingType or string to pad values (list of pairs of ints)."""
  if not isinstance(padding_type, (str, PaddingType)):
    msg = 'padding_type must be str or PaddingType, got {}.'
    raise TypeError(msg.format(type(padding_type)))

  if isinstance(padding_type, str):
    if padding_type.upper() == 'VALID':
      padding_type = PaddingType.VALID
    elif padding_type.upper() == 'SAME':
      padding_type = PaddingType.SAME
    else:
      msg = 'Unknown padding type string: expected "VALID" or "SAME", got {}.'
      raise ValueError(msg.format(padding_type))

  if padding_type == PaddingType.VALID:
    return [(0, 0)] * len(window_strides)
  elif padding_type == PaddingType.SAME:
    out_shape = np.ceil(np.true_divide(lhs_dims, window_strides)).astype(int)
    pad_sizes = [
        max((out_size - 1) * stride + filter_size - in_size, 0)
        for out_size, stride, filter_size, in_size in zip(
            out_shape, window_strides, rhs_dims, lhs_dims)
    ]
    return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes]
  else:
    msg = 'Unexpected PaddingType value: {}'
    raise ValueError(msg.format(padding_type))

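# A hedged worked example (not in the original file): for SAME padding with
# input length 10, window 3, and stride 2, the output length is
# ceil(10 / 2) = 5, so the total pad is max((5 - 1) * 2 + 3 - 10, 0) = 1,
# split low/high as (0, 1).
#
#   window_padding_type_to_pad_values('SAME', [10], [3], [2])  # -> [(0, 1)]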

XlaBuilder = _xla.XlaBuilder
XlaComputation = _xla.XlaComputation
XlaOp = _xla.XlaOp
FftType = _xla.FftType
Client = _xla.Client
Buffer = _xla.Buffer
DeviceArrayBase = _xla.DeviceArrayBase
Executable = _xla.Executable
OpSharding = _xla.OpSharding
HloSharding = _xla.HloSharding


def register_custom_call_target(name, fn, platform='cpu'):
  """Registers a custom call target.

  Args:
    name: bytes containing the name of the function.
    fn: a PyCapsule object containing the function pointer.
    platform: the target platform.
  """
  # To support AMD GPUs, we need xla_platform_names["gpu"] to be "ROCM".
  # Since it is hardcoded to "CUDA" above, the lookup below is used as a
  # workaround.
  _xla.register_custom_call_target(name, fn,
                                   xla_platform_names.get(platform, platform))


# Deprecated. Use register_custom_call_target instead.
register_cpu_custom_call_target = register_custom_call_target


class PaddingConfigDimension:
  """Python representation of a xla.PaddingConfigDimension protobuf."""
  __slots__ = ('edge_padding_low', 'edge_padding_high', 'interior_padding')

  edge_padding_low: int
  edge_padding_high: int
  interior_padding: int

  def __init__(self):
    self.edge_padding_low = 0
    self.edge_padding_high = 0
    self.interior_padding = 0


class PaddingConfig:
  """Python representation of a xla.PaddingConfig protobuf."""
  __slots__ = ('dimensions',)

  def __init__(self):
    self.dimensions = []


def make_padding_config(
    padding_config: Union[PaddingConfig, Sequence[Tuple[int, int, int]]]
) -> PaddingConfig:
  """Creates a PaddingConfig proto from a list of triples of integers.

  Args:
    padding_config: either a PaddingConfig or a list of integer triples
      (edge_padding_low, edge_padding_high, interior_padding) representing the
      configuration of the padding operation.

  Returns:
    A `PaddingConfig` object.
  """
  if not isinstance(padding_config, PaddingConfig):
    triples = padding_config
    padding_config = PaddingConfig()
    for lo, hi, interior in triples:
      dimension = PaddingConfigDimension()
      dimension.edge_padding_low = lo
      dimension.edge_padding_high = hi
      dimension.interior_padding = interior
      padding_config.dimensions.append(dimension)
  return padding_config

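# A hedged usage sketch (not in the original file): one triple per dimension,
# as (edge_padding_low, edge_padding_high, interior_padding).
#
#   config = make_padding_config([(1, 2, 0), (0, 0, 1)])
#   # config.dimensions[0].edge_padding_low == 1, etc.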

class DotDimensionNumbers:
  """Python representation of a xla.DotDimensionNumbers protobuf."""
  __slots__ = ('lhs_contracting_dimensions', 'rhs_contracting_dimensions',
               'lhs_batch_dimensions', 'rhs_batch_dimensions')

  def __init__(self):
    self.lhs_contracting_dimensions = []
    self.rhs_contracting_dimensions = []
    self.lhs_batch_dimensions = []
    self.rhs_batch_dimensions = []


def make_dot_dimension_numbers(
    dimension_numbers: Union[DotDimensionNumbers,
                             Tuple[Tuple[List[int], List[int]],
                                   Tuple[List[int], List[int]]]]
) -> DotDimensionNumbers:
  """Builds a DotDimensionNumbers object from a specification.

  Args:
    dimension_numbers: either a `DotDimensionNumbers` or a nested tuple
      `((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))` of lists of
      integers representing the dimensions to treat as contracting dimensions
      and batch dimensions on each input operand.

  Returns:
    A `DotDimensionNumbers` object.
  """
  if isinstance(dimension_numbers, (list, tuple)):
    (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
    dot_dims_proto = DotDimensionNumbers()
    dot_dims_proto.lhs_contracting_dimensions.extend(lhs_contract)
    dot_dims_proto.rhs_contracting_dimensions.extend(rhs_contract)
    dot_dims_proto.lhs_batch_dimensions.extend(lhs_batch)
    dot_dims_proto.rhs_batch_dimensions.extend(rhs_batch)
    return dot_dims_proto
  else:
    return dimension_numbers

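# A hedged usage sketch (not in the original file): a plain matrix multiply
# contracts the last dimension of lhs with the first dimension of rhs and has
# no batch dimensions.
#
#   dnums = make_dot_dimension_numbers((([1], [0]), ([], [])))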

class ConvolutionDimensionNumbers:
  """Python representation of a xla.ConvolutionDimensionNumbers protobuf."""
  __slots__ = ('input_batch_dimension', 'input_feature_dimension',
               'input_spatial_dimensions', 'kernel_input_feature_dimension',
               'kernel_output_feature_dimension', 'kernel_spatial_dimensions',
               'output_batch_dimension', 'output_feature_dimension',
               'output_spatial_dimensions')

  def __init__(self):
    self.input_batch_dimension = 0
    self.input_feature_dimension = 0
    self.input_spatial_dimensions = []
    self.kernel_input_feature_dimension = 0
    self.kernel_output_feature_dimension = 0
    self.kernel_spatial_dimensions = []
    self.output_batch_dimension = 0
    self.output_feature_dimension = 0
    self.output_spatial_dimensions = []


def make_convolution_dimension_numbers(
    dimension_numbers: Union[None, ConvolutionDimensionNumbers, Tuple[str, str,
                                                                      str]],
    num_spatial_dimensions: int) -> ConvolutionDimensionNumbers:
  """Builds a ConvolutionDimensionNumbers object from a specification.

  Args:
    dimension_numbers: optional, either a ConvolutionDimensionNumbers object or
      a tuple (lhs_spec, rhs_spec, out_spec). Each element is a string of
      length N+2 identifying by position: (1) batch dimensions in lhs, rhs, and
      the output with the character 'N', (2) feature dimensions in lhs and the
      output with the character 'C', (3) input and output feature dimensions
      in rhs with the characters 'I' and 'O' respectively, and (4) spatial
      dimension correspondences between lhs, rhs, and the output using any
      distinct characters. For example, to indicate dimension numbers
      consistent with the Conv operation with two spatial dimensions, one
      could use ('NCHW', 'OIHW', 'NCHW'). As another example, to indicate
      dimension numbers consistent with the TensorFlow Conv2D operation, one
      could use ('NHWC', 'HWIO', 'NHWC'). When using the latter form of
      convolution dimension specification, window strides are associated with
      spatial dimension character labels according to the order in which the
      labels appear in the rhs_spec string, so that window_strides[0] is
      matched with the dimension corresponding to the first character
      appearing in rhs_spec that is not 'I' or 'O'. By default, uses the same
      dimension numbering as Conv and ConvWithGeneralPadding.
    num_spatial_dimensions: the number of spatial dimensions.

  Returns:
    A `ConvolutionDimensionNumbers` object.
  """
  if dimension_numbers is None:
    nd = num_spatial_dimensions
    dimension_numbers = ConvolutionDimensionNumbers()
    dimension_numbers.input_batch_dimension = 0
    dimension_numbers.input_feature_dimension = 1
    dimension_numbers.output_batch_dimension = 0
    dimension_numbers.output_feature_dimension = 1
    dimension_numbers.kernel_output_feature_dimension = 0
    dimension_numbers.kernel_input_feature_dimension = 1
    dimension_numbers.input_spatial_dimensions.extend(range(2, 2 + nd))
    dimension_numbers.kernel_spatial_dimensions.extend(range(2, 2 + nd))
    dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd))
  elif isinstance(dimension_numbers, tuple):
    lhs_spec, rhs_spec, out_spec = dimension_numbers
    dimension_numbers = ConvolutionDimensionNumbers()

    dimension_numbers.input_batch_dimension = lhs_spec.index('N')
    dimension_numbers.input_feature_dimension = lhs_spec.index('C')
    dimension_numbers.output_batch_dimension = out_spec.index('N')
    dimension_numbers.output_feature_dimension = out_spec.index('C')
    dimension_numbers.kernel_output_feature_dimension = rhs_spec.index('O')
    dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I')

    dimension_numbers.kernel_spatial_dimensions.extend(
        i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'})
    dimension_numbers.input_spatial_dimensions.extend(
        sorted((i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}),
               key=lambda i: rhs_spec.index(lhs_spec[i])))
    dimension_numbers.output_spatial_dimensions.extend(
        sorted((i for i, c in enumerate(out_spec) if c not in {'N', 'C'}),
               key=lambda i: rhs_spec.index(out_spec[i])))
  return dimension_numbers

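# A hedged usage sketch (not in the original file): dimension numbers matching
# the TensorFlow Conv2D convention, as described in the docstring above.
#
#   dnums = make_convolution_dimension_numbers(('NHWC', 'HWIO', 'NHWC'), 2)
#   # dnums.input_spatial_dimensions == [1, 2] (the 'H' and 'W' positions)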

class PrecisionConfig:
  """Python representation of a xla.PrecisionConfig protobuf."""
  __slots__ = ('operand_precision',)

  Precision = _xla.PrecisionConfig_Precision

  def __init__(self):
    self.operand_precision = []


class GatherDimensionNumbers:
  """Python representation of a xla.GatherDimensionNumbers protobuf."""
  __slots__ = ('offset_dims', 'collapsed_slice_dims', 'start_index_map',
               'index_vector_dim')

  def __init__(self):
    self.offset_dims = []
    self.collapsed_slice_dims = []
    self.start_index_map = []
    self.index_vector_dim = 0


class ScatterDimensionNumbers:
  """Python representation of a xla.ScatterDimensionNumbers protobuf."""
  __slots__ = ('update_window_dims', 'inserted_window_dims',
               'scatter_dims_to_operand_dims', 'index_vector_dim')

  def __init__(self):
    self.update_window_dims = []
    self.inserted_window_dims = []
    self.scatter_dims_to_operand_dims = []
    self.index_vector_dim = 0


class ReplicaGroup:
  """Python representation of a xla.ReplicaGroup protobuf."""
  __slots__ = ('replica_ids',)

  def __init__(self):
    self.replica_ids = []


def _make_replica_group_proto(replica_group):
  replica_group_proto = ReplicaGroup()
  replica_group_proto.replica_ids.extend(replica_group)
  return replica_group_proto


def make_replica_groups(replica_groups):
  if replica_groups is None:
    replica_groups_protos = []  # special value for XLA API
  else:
    replica_groups = list(replica_groups)
    replica_groups_protos = [
        _make_replica_group_proto(group) for group in replica_groups
    ]
  return replica_groups_protos

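# A hedged usage sketch (not in the original file): grouping four replicas
# into two pairs, e.g. for a cross-replica reduction.
#
#   groups = make_replica_groups([[0, 1], [2, 3]])
#   # Passing None instead yields [], the "all replicas" default for the API.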

Traceback = _xla.Traceback
Frame = _xla.Frame


@contextlib.contextmanager
def tracebacks(enabled=True):
  """Context manager that enables or disables traceback collection."""
  saved = Traceback.enabled
  Traceback.enabled = enabled
  try:
    yield
  finally:
    Traceback.enabled = saved

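# A hedged usage sketch (not in the original file): temporarily disabling
# traceback collection around a block of work.
#
#   with tracebacks(enabled=False):
#     run_computation()  # hypothetical workload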

def heap_profile(client: Client) -> bytes:
  """Returns a gzipped pprof protocol buffer containing a heap profile."""
  return gzip.compress(client.heap_profile())


XlaRuntimeError = _xla.XlaRuntimeError

# Perform one last garbage collection of deferred Python references. This is
# mostly to keep ASAN happy.
atexit.register(_xla.collect_garbage)

weakref_lru_cache = _xla.weakref_lru_cache