1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""An XLA client in Python.""" 16 17import atexit 18import contextlib 19import enum # pylint: disable=g-bad-import-order 20import gzip 21import inspect 22import os 23from typing import List, Sequence, Tuple, Union 24 25from . import xla_extension as _xla 26 27import numpy as np 28 29# Note this module does *not* depend on any Python protocol buffers. The XLA 30# Python bindings are currently packaged both as part of jaxlib and as part 31# of TensorFlow. If we use protocol buffers here, then importing both jaxlib 32# and TensorFlow may fail with duplicate protocol buffer message definitions. 33 34# Most functions are snake_case for consistency with other modules, some 35# method names are CamelCase for consistency with XLA. 36# pylint: disable=invalid-name 37 38# Pylint has false positives for type annotations. 39# pylint: disable=invalid-sequence-index 40 41ops = _xla.ops 42profiler = _xla.profiler 43 44# Just an internal arbitrary increasing number to help with backward-compatible 45# changes. 46_version = 89 47 48# Version number for MLIR:Python components. 
mlir_api_version = 32

# Maps short platform names to the XLA platform names passed to the extension
# (used when registering custom call targets).
xla_platform_names = {
    'cpu': 'Host',
    'gpu': 'CUDA',
}


def make_interpreter_client():
  """Returns an XLA interpreter client from the native extension."""
  return _xla.get_interpreter_client()


def make_cpu_client(*, use_tfrt: bool = True) -> ...:
  """Returns an asynchronous CPU client.

  Args:
    use_tfrt: if True (the default) the client is built with
      `_xla.get_tfrt_cpu_client`; otherwise with `_xla.get_cpu_client`.
  """
  if use_tfrt:
    return _xla.get_tfrt_cpu_client(asynchronous=True)
  else:
    return _xla.get_cpu_client(asynchronous=True)


def make_gpu_client(distributed_client=None, node_id=0, platform_name=None,
                    allowed_devices=None):
  """Returns a GPU client. BFC allocator is used by default.

  The allocator is configured from environment variables:

    XLA_PYTHON_CLIENT_ALLOCATOR: "default", "platform", "bfc" or "cuda_async"
      (case-insensitive); defaults to "default".
    XLA_PYTHON_CLIENT_MEM_FRACTION: if set and non-empty, parsed as a float and
      stored as the allocator's memory fraction.
    XLA_PYTHON_CLIENT_PREALLOCATE: preallocation is disabled only when this is
      "0", "false" or "False"; otherwise (including when unset) it is enabled.

  Args:
    distributed_client: forwarded to `_xla.get_gpu_client`.
    node_id: forwarded to `_xla.get_gpu_client`.
    platform_name: forwarded to `_xla.get_gpu_client`.
    allowed_devices: forwarded to `_xla.get_gpu_client`.

  Raises:
    ValueError: if XLA_PYTHON_CLIENT_ALLOCATOR is not one of the accepted
      values, or XLA_PYTHON_CLIENT_MEM_FRACTION is not a valid float.
  """
  allocator = os.getenv('XLA_PYTHON_CLIENT_ALLOCATOR', 'default').lower()
  memory_fraction = os.getenv('XLA_PYTHON_CLIENT_MEM_FRACTION')
  preallocate = os.getenv('XLA_PYTHON_CLIENT_PREALLOCATE')
  if allocator not in ('default', 'platform', 'bfc', 'cuda_async'):
    raise ValueError(
        'XLA_PYTHON_CLIENT_ALLOCATOR env var must be "default", "platform", '
        '"bfc", or "cuda_async", got "%s"' % allocator)
  config = _xla.GpuAllocatorConfig()
  if allocator == 'default':
    config.kind = _xla.GpuAllocatorConfig.Kind.DEFAULT
  elif allocator == 'platform':
    config.kind = _xla.GpuAllocatorConfig.Kind.PLATFORM
  elif allocator == 'bfc':
    config.kind = _xla.GpuAllocatorConfig.Kind.BFC
  else:  # 'cuda_async'; every other value was rejected above.
    config.kind = _xla.GpuAllocatorConfig.Kind.CUDA_ASYNC
  if memory_fraction:
    # Name the offending env var instead of surfacing a bare float() error,
    # matching how make_tpu_client reports a malformed env var.
    try:
      fraction = float(memory_fraction)
    except ValueError as e:
      raise ValueError(
          'XLA_PYTHON_CLIENT_MEM_FRACTION env var must be a float, '
          f'got {memory_fraction}') from e
    config.memory_fraction = fraction
  config.preallocate = preallocate not in ('0', 'false', 'False')

  return _xla.get_gpu_client(
      asynchronous=True,
      allocator_config=config,
      distributed_client=distributed_client,
      node_id=node_id,
      platform_name=platform_name,
      allowed_devices=allowed_devices)


def make_tpu_client():
  """Returns a TPU client. Defaults to allowing 32 in-flight computations.

  The limit is read from the JAX_TPU_MAX_INFLIGHT_COMPUTATIONS env var.

  Raises:
    ValueError: if JAX_TPU_MAX_INFLIGHT_COMPUTATIONS is not an int.
  """
  max_inflight_computations = os.getenv(
      'JAX_TPU_MAX_INFLIGHT_COMPUTATIONS', '32')
  try:
    max_inflight_computations = int(max_inflight_computations)
  except ValueError as e:
    raise ValueError(
        f'JAX_TPU_MAX_INFLIGHT_COMPUTATIONS env var must be an int, '
        f'got {max_inflight_computations}') from e
  return _xla.get_tpu_client(
      max_inflight_computations=max_inflight_computations)


class OpMetadata:
  """Python representation of a xla.OpMetadata protobuf."""
  __slots__ = ('op_type', 'op_name', 'source_file', 'source_line')

  def __init__(self, op_type='', op_name='', source_file='', source_line=0):
    self.op_type = op_type
    self.op_name = op_name
    self.source_file = source_file
    self.source_line = source_line


def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1):
  """Helper for use in source mapping that returns an OpMetadata object.

  Args:
    op_type: stored as the returned metadata's `op_type`.
    op_name: stored as the returned metadata's `op_name`.
    skip_frames: index into the call stack of the frame to report, where 0 is
      this function's own frame and the default 1 is its immediate caller.

  Returns:
    An `OpMetadata` whose `source_file` is the base name of the selected
    frame's file and whose `source_line` is that frame's line number.
  """
  # context=0: only the filename and line number are needed, so skip reading
  # source-context lines for every frame on the stack.
  full_filename, lineno = inspect.stack(0)[skip_frames][1:3]
  filename = os.path.basename(full_filename)
  return OpMetadata(
      op_type=op_type,
      op_name=op_name,
      source_file=filename,
      source_line=lineno)


PrimitiveType = _xla.PrimitiveType

# bfloat16 type supplied by the extension; np.dtype(bfloat16) below turns it
# into a NumPy dtype.
bfloat16 = _xla.bfloat16_dtype()

XLA_ELEMENT_TYPE_TO_DTYPE = {
    PrimitiveType.PRED: np.dtype('bool'),
    PrimitiveType.S8: np.dtype('int8'),
    PrimitiveType.S16: np.dtype('int16'),
    PrimitiveType.S32: np.dtype('int32'),
    PrimitiveType.S64: np.dtype('int64'),
    PrimitiveType.U8: np.dtype('uint8'),
    PrimitiveType.U16: np.dtype('uint16'),
    PrimitiveType.U32: np.dtype('uint32'),
    PrimitiveType.U64: np.dtype('uint64'),
    PrimitiveType.BF16: np.dtype(bfloat16),
    PrimitiveType.F16: np.dtype('float16'),
    PrimitiveType.F32: np.dtype('float32'),
    PrimitiveType.F64: np.dtype('float64'),
    PrimitiveType.C64: np.dtype('complex64'),
    PrimitiveType.C128: np.dtype('complex128'),
    PrimitiveType.TUPLE: np.dtype(np.object_),
    PrimitiveType.TOKEN: np.dtype(np.object_),
}

# Note the conversion on the key. Numpy has a known issue wherein dtype hashing
# doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus,
# when keying by dtype in this dict, we use the string form of dtypes.
# TUPLE and TOKEN both map to the object dtype, so the later entry (TOKEN) is
# the one kept under that key.
DTYPE_TO_XLA_ELEMENT_TYPE = {
    str(dt): et for et, dt in XLA_ELEMENT_TYPE_TO_DTYPE.items()
}


def dtype_to_etype(dtype):
  """Convenience function for reading DTYPE_TO_XLA_ELEMENT_TYPE.

  Args:
    dtype: anything accepted by `np.dtype`.

  Returns:
    The matching `PrimitiveType`.

  Raises:
    KeyError: if the dtype has no XLA element type in the table.
  """
  return DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))]


Shape = _xla.Shape
Shape.__doc__ = """
A Shape is an object defined in C++ that duck types like the following class:

class Shape:
  '''Represents an XLA shape.

  A shape is either an array shape, having rank-many integer
  dimensions and an element type (represented by a Numpy dtype), or it
  is a tuple shape, having a shape for every tuple component:

    type shape =
        TupleShape of shape list
      | ArrayShape of { dimensions: int list; element_type: dtype }
  '''

  @staticmethod
  def tuple_shape(tuple_shapes) -> Shape:
    "Construct a tuple shape."

  @staticmethod
  def array_shape(element_type, dimensions, minor_to_major=None) -> Shape:

  @staticmethod
  def from_pyval(pyval) -> Shape:
    "Returns a Shape that describes a tuple-tree of Numpy arrays."

  def __init__(self, str) -> Shape:
    "Parses a shape string."
  def __eq__(self, other: Shape) -> bool:
  def __ne__(self, other: Shape) -> bool:
  def __hash__(self):
  def __repr__(self):
  def is_tuple(self) -> bool:
  def is_array(self) -> bool:
  def tuple_shapes(self) -> [Shape]:
  def numpy_dtype(self) -> np.dtype:
    "Like element_type(), but returns dtype('O') for a tuple shape."
  def xla_element_type(self) -> PrimitiveType:
  def element_type(self) -> np.dtype:
  def dimensions(self) -> (int, int, ...):
  def rank(self) -> int:
  def with_major_to_minor_layout_if_absent(self) -> Shape:
    "Returns a copy with missing layouts set to major-to-minor."

  def to_serialized_proto(self) -> bytes:
    "Returns 'shape' as a serialized proto."
"""

ProgramShape = _xla.ProgramShape
ProgramShape.__doc__ = """
A ProgramShape is a C++ object that duck types like the following class.

class ProgramShape:
  def __init__(self, parameter_shapes, result_shape):
  def parameter_shapes(self) -> [Shape]:
  def result_shape(self) -> Shape:
  def __repr__(self):
"""

ShapeIndex = _xla.ShapeIndex
ShapeIndex.__doc__ = """
A ShapeIndex is an object defined in C++ that duck types like the following
class:

class ShapeIndex:
  '''Represents an XLA ShapeIndex.

  An index for specifying a particular nested subshape within a shape. Used in
  ShapeUtil::GetSubshape and other interfaces. ShapeIndex defines a path through
  the Shape tree where each element of ShapeIndex indexes into a tuple (or
  nested tuple) within the shape. For a non-nested tuple, an index has a single
  element.
  '''

  def __init__(self, List[int]) -> ShapeIndex:
  def __eq__(self, other: ShapeIndex) -> bool:
  def __ne__(self, other: ShapeIndex) -> bool:
  def __hash__(self):
  def __repr__(self):
"""


def shape_from_pyval(pyval):
  """Returns a Shape that describes a tuple-tree of Numpy arrays."""

  def convert(pyval):
    # Python tuples become tuple shapes; every other leaf is described by its
    # dtype and np.shape.
    if isinstance(pyval, tuple):
      return Shape.tuple_shape(tuple(convert(elt) for elt in pyval))
    else:
      return Shape.array_shape(pyval.dtype, np.shape(pyval))

  return convert(pyval)


DeviceAssignment = _xla.DeviceAssignment
DeviceAssignment.__doc__ = """
A DeviceAssignment is a C++ object with the following signature.

def create(assignment):
  '''Builds a device assignment.

   Args:
     assignment: a 2D numpy array of device ordinal integers, indexed by
       [replica][computation_in_replica].
   Returns:
     A device assignment.
  '''

def replica_count():
  '''Returns the number of replicas.'''
def computation_count():
  '''Returns the number of computations per replica.'''
"""

Device = _xla.Device
CompileOptions = _xla.CompileOptions

HostBufferSemantics = _xla.HostBufferSemantics

# An Executable is a C++ class that duck types with the following API:
# class Executable:
#   def local_devices(self) -> [Device]:
#   def execute(self, arguments : [Buffer]) -> [Buffer]:
#     """Execute on one replica with Buffer arguments and return value."""
#
#   def size_of_generated_code_in_bytes(self) -> int:
#     """Return generated binary size, or -1 if not known."""
#
#   def execute_sharded_on_local_devices(self, arguments: [[Buffer]])
#       -> [[Buffer]]:
#     """Execute on many replicas with Buffer arguments and return value.
#
#     Args:
#       arguments: A sequence of sequences of Buffers. The i'th element of each
#         sequence comprises the arguments for execution on the i'th local
#         device.
#
#     Returns:
#       A list of the computation's outputs as a list of Buffers for each
#       device.
#     """
#
# There are different implementations of Executable for different backends.
def execute_with_python_values(executable, arguments, backend):
  """Execute on one replica with Python values as arguments and output.

  Args:
    executable: the program to run; arguments are placed on its first local
      device.
    arguments: Python values, each converted with `backend.buffer_from_pyval`.
    backend: the backend we are targeting.

  Returns:
    A list holding `to_py()` of each output of `executable.execute`.
  """
  # The target device does not depend on the argument, so look it up once.
  device = executable.local_devices()[0]
  buffers = [backend.buffer_from_pyval(arg, device=device) for arg in arguments]
  outputs = executable.execute(buffers)
  return [x.to_py() for x in outputs]


def execute_with_python_values_replicated(executable, arguments, backend):
  """Execute on many replicas with Python values as arguments and output.

  Args:
    executable: the program to run.
    arguments: a list of lists of Python values indexed by `[replica][arg_num]`
      to pass as inputs.
    backend: the backend we are targeting.

  Returns:
    A list of python values, one per replica.
  """
  devices = executable.local_devices()

  # pylint: disable=g-complex-comprehension
  def copy_to_devices(pyvals):
    # Pairs the i'th value with the i'th local device.
    return [backend.buffer_from_pyval(v, d) for v, d in zip(pyvals, devices)]

  # Transpose [replica][arg_num] into [arg_num][device] for the executable,
  # then transpose the outputs back into one list per replica.
  inputs = [copy_to_devices(pyvals) for pyvals in zip(*arguments)]
  outputs = executable.execute_sharded_on_local_devices(inputs)
  return [[x.to_py() for x in xs] for xs in zip(*outputs)]


class PaddingType(enum.Enum):
  """Window padding modes understood by window_padding_type_to_pad_values."""
  VALID = 1
  SAME = 2


def window_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims,
                                      window_strides):
  """Maps PaddingType or string to pad values (list of pairs of ints).

  Args:
    padding_type: a `PaddingType`, or the case-insensitive string "VALID" or
      "SAME".
    lhs_dims: input sizes of the windowed dimensions.
    rhs_dims: window (filter) sizes of those dimensions.
    window_strides: stride of each windowed dimension.

  Returns:
    One `(low, high)` pair per windowed dimension. VALID padding is all zeros.
    SAME padding makes each output size `ceil(in_size / stride)`, putting any
    odd unit of padding on the high side.

  Raises:
    TypeError: if `padding_type` is neither a str nor a `PaddingType`.
    ValueError: if a string `padding_type` is not "VALID" or "SAME".
  """
  if not isinstance(padding_type, (str, PaddingType)):
    msg = 'padding_type must be str or PaddingType, got {}.'
    raise TypeError(msg.format(type(padding_type)))

  if isinstance(padding_type, str):
    if padding_type.upper() == 'VALID':
      padding_type = PaddingType.VALID
    elif padding_type.upper() == 'SAME':
      padding_type = PaddingType.SAME
    else:
      msg = 'Unknown padding type string: expected "VALID" or "SAME", got {}.'
      raise ValueError(msg.format(padding_type))

  if padding_type == PaddingType.VALID:
    return [(0, 0)] * len(window_strides)
  elif padding_type == PaddingType.SAME:
    pads = []
    for in_size, stride, filter_size in zip(lhs_dims, window_strides,
                                            rhs_dims):
      # Exact integer ceil(in_size / stride); avoids a float round-trip that
      # loses precision for large sizes.
      out_size = -(-in_size // stride)
      pad_size = max((out_size - 1) * stride + filter_size - in_size, 0)
      pads.append((pad_size // 2, pad_size - pad_size // 2))
    return pads
  else:
    msg = 'Unexpected PaddingType value: {}'
    raise ValueError(msg.format(padding_type))


XlaBuilder = _xla.XlaBuilder
XlaComputation = _xla.XlaComputation
XlaOp = _xla.XlaOp
FftType = _xla.FftType
Client = _xla.Client
Buffer = _xla.Buffer
DeviceArrayBase = _xla.DeviceArrayBase
Executable = _xla.Executable
OpSharding = _xla.OpSharding
HloSharding = _xla.HloSharding


def register_custom_call_target(name, fn, platform='cpu'):
  """Registers a custom call target.

  Args:
    name: bytes containing the name of the function.
    fn: a PyCapsule object containing the function pointer.
    platform: the target platform. Names found in `xla_platform_names` are
      translated; any other value is passed through unchanged.
  """
  # To support AMD GPUs, we need to have xla_platform_names["gpu"] == "ROCM"
  # Since that is hardcoded to CUDA, we are using the following as workaround.
  _xla.register_custom_call_target(name, fn,
                                   xla_platform_names.get(platform, platform))


# Deprecated. Use register_custom_call_target instead.
register_cpu_custom_call_target = register_custom_call_target


class PaddingConfigDimension:
  """Python representation of a xla.PaddingConfigDimension protobuf."""
  __slots__ = ('edge_padding_low', 'edge_padding_high', 'interior_padding')

  edge_padding_low: int
  edge_padding_high: int
  interior_padding: int

  def __init__(self):
    self.edge_padding_low = 0
    self.edge_padding_high = 0
    self.interior_padding = 0


class PaddingConfig:
  """Python representation of a xla.PaddingConfig protobuf."""
  __slots__ = ('dimensions',)

  def __init__(self):
    # PaddingConfigDimension entries; make_padding_config appends one per
    # triple it is given.
    self.dimensions = []


def make_padding_config(
    padding_config: Union[PaddingConfig, Sequence[Tuple[int, int, int]]]
) -> PaddingConfig:
  """Create PaddingConfig proto from list of triples of integers.

  Args:
    padding_config: either a PaddingConfig or a list of integer triples
      (edge_padding_low, edge_padding_high, interior_padding) representing the
      configuration of the padding operation.

  Returns:
    A `PaddingConfig` object. An input that is already a `PaddingConfig` is
    returned unchanged.
  """
  if not isinstance(padding_config, PaddingConfig):
    triples = padding_config
    padding_config = PaddingConfig()
    for lo, hi, interior in triples:
      dimension = PaddingConfigDimension()
      dimension.edge_padding_low = lo
      dimension.edge_padding_high = hi
      dimension.interior_padding = interior
      padding_config.dimensions.append(dimension)
  return padding_config


class DotDimensionNumbers:
  """Python representation of a xla.DotDimensionNumbers protobuf."""
  __slots__ = ('lhs_contracting_dimensions', 'rhs_contracting_dimensions',
               'lhs_batch_dimensions', 'rhs_batch_dimensions')

  def __init__(self):
    self.lhs_contracting_dimensions = []
    self.rhs_contracting_dimensions = []
    self.lhs_batch_dimensions = []
    self.rhs_batch_dimensions = []


def make_dot_dimension_numbers(
    dimension_numbers: Union[DotDimensionNumbers,
                             Tuple[Tuple[List[int], List[int]],
                                   Tuple[List[int], List[int]]]]
) -> DotDimensionNumbers:
  """Builds a DotDimensionNumbers object from a specification.

  Args:
    dimension_numbers: either a `DotDimensionNumbers` or a nested tuple
      `((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))` of lists of
      integers representing the dimensions to treat as contracting dimensions
      and batch dimensions on each input operand.

  Returns:
    A `DotDimensionNumbers` object. Any input that is not a list or tuple is
    returned unchanged.
  """
  if isinstance(dimension_numbers, (list, tuple)):
    (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
    dot_dims_proto = DotDimensionNumbers()
    dot_dims_proto.lhs_contracting_dimensions.extend(lhs_contract)
    dot_dims_proto.rhs_contracting_dimensions.extend(rhs_contract)
    dot_dims_proto.lhs_batch_dimensions.extend(lhs_batch)
    dot_dims_proto.rhs_batch_dimensions.extend(rhs_batch)
    return dot_dims_proto
  else:
    return dimension_numbers


class ConvolutionDimensionNumbers:
  """Python representation of a xla.ConvolutionDimensionNumbers protobuf."""
  __slots__ = ('input_batch_dimension', 'input_feature_dimension',
               'input_spatial_dimensions', 'kernel_input_feature_dimension',
               'kernel_output_feature_dimension', 'kernel_spatial_dimensions',
               'output_batch_dimension', 'output_feature_dimension',
               'output_spatial_dimensions')

  def __init__(self):
    self.input_batch_dimension = 0
    self.input_feature_dimension = 0
    self.input_spatial_dimensions = []
    self.kernel_input_feature_dimension = 0
    self.kernel_output_feature_dimension = 0
    self.kernel_spatial_dimensions = []
    self.output_batch_dimension = 0
    self.output_feature_dimension = 0
    self.output_spatial_dimensions = []


def make_convolution_dimension_numbers(
    dimension_numbers: Union[None, ConvolutionDimensionNumbers,
                             Tuple[str, str, str], List[str]],
    num_spatial_dimensions: int) -> ConvolutionDimensionNumbers:
  """Builds a ConvolutionDimensionNumbers object from a specification.

  Args:
    dimension_numbers: optional, either a ConvolutionDimensionNumbers object or
      a tuple (lhs_spec, rhs_spec, out_spec); a list of the same three strings
      is also accepted. Each element is a string of
      length N+2 identifying by position: (1) batch dimensions in lhs, rhs, and
      the output with the character 'N', (2) feature dimensions in lhs and the
      output with the character 'C', (3) input and output feature dimensions
      in rhs with the characters 'I' and 'O' respectively, and (4) spatial
      dimension correspondences between lhs, rhs, and the output using any
      distinct characters. For example, to indicate dimension numbers
      consistent with the Conv operation with two spatial dimensions, one
      could use ('NCHW', 'OIHW', 'NCHW'). As another example, to indicate
      dimension numbers consistent with the TensorFlow Conv2D operation, one
      could use ('NHWC', 'HWIO', 'NHWC'). When using the latter form of
      convolution dimension specification, window strides are associated with
      spatial dimension character labels according to the order in which the
      labels appear in the rhs_spec string, so that window_strides[0] is
      matched with the dimension corresponding to the first character
      appearing in rhs_spec that is not 'I' or 'O'. By default, use the same
      dimension numbering as Conv and ConvWithGeneralPadding.
    num_spatial_dimensions: the number of spatial dimensions.

  Returns:
    A `ConvolutionDimensionNumbers` object. An input that is already a
    `ConvolutionDimensionNumbers` is returned unchanged.
  """
  if dimension_numbers is None:
    # Default layout: batch, feature, then spatial dims ("NC..." / "OI...").
    nd = num_spatial_dimensions
    dimension_numbers = ConvolutionDimensionNumbers()
    dimension_numbers.input_batch_dimension = 0
    dimension_numbers.input_feature_dimension = 1
    dimension_numbers.output_batch_dimension = 0
    dimension_numbers.output_feature_dimension = 1
    dimension_numbers.kernel_output_feature_dimension = 0
    dimension_numbers.kernel_input_feature_dimension = 1
    dimension_numbers.input_spatial_dimensions.extend(range(2, 2 + nd))
    dimension_numbers.kernel_spatial_dimensions.extend(range(2, 2 + nd))
    dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd))
  elif isinstance(dimension_numbers, (list, tuple)):
    # Lists are accepted as well as tuples, matching make_dot_dimension_numbers;
    # previously a list spec was passed through unconverted.
    lhs_spec, rhs_spec, out_spec = dimension_numbers
    dimension_numbers = ConvolutionDimensionNumbers()

    dimension_numbers.input_batch_dimension = lhs_spec.index('N')
    dimension_numbers.input_feature_dimension = lhs_spec.index('C')
    dimension_numbers.output_batch_dimension = out_spec.index('N')
    dimension_numbers.output_feature_dimension = out_spec.index('C')
    dimension_numbers.kernel_output_feature_dimension = rhs_spec.index('O')
    dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I')

    dimension_numbers.kernel_spatial_dimensions.extend(
        i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'})
    # Spatial dims of lhs/out are ordered by where their label appears in
    # rhs_spec, so that they line up with kernel_spatial_dimensions.
    dimension_numbers.input_spatial_dimensions.extend(
        sorted((i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}),
               key=lambda i: rhs_spec.index(lhs_spec[i])))
    dimension_numbers.output_spatial_dimensions.extend(
        sorted((i for i, c in enumerate(out_spec) if c not in {'N', 'C'}),
               key=lambda i: rhs_spec.index(out_spec[i])))
  return dimension_numbers


class PrecisionConfig:
  """Python representation of a xla.PrecisionConfig protobuf."""
  __slots__ = ('operand_precision',)

  Precision = _xla.PrecisionConfig_Precision

  def __init__(self):
    self.operand_precision = []
class GatherDimensionNumbers:
  """Pure-Python stand-in for the xla.GatherDimensionNumbers protobuf."""
  __slots__ = ('offset_dims', 'collapsed_slice_dims', 'start_index_map',
               'index_vector_dim')

  def __init__(self):
    # The scalar field starts at zero; the repeated fields start out empty.
    self.index_vector_dim = 0
    self.offset_dims = []
    self.collapsed_slice_dims = []
    self.start_index_map = []


class ScatterDimensionNumbers:
  """Pure-Python stand-in for the xla.ScatterDimensionNumbers protobuf."""
  __slots__ = ('update_window_dims', 'inserted_window_dims',
               'scatter_dims_to_operand_dims', 'index_vector_dim')

  def __init__(self):
    # The scalar field starts at zero; the repeated fields start out empty.
    self.index_vector_dim = 0
    self.update_window_dims = []
    self.inserted_window_dims = []
    self.scatter_dims_to_operand_dims = []


class ReplicaGroup:
  """Pure-Python stand-in for the xla.ReplicaGroup protobuf."""
  __slots__ = ('replica_ids',)

  def __init__(self):
    self.replica_ids = []


def _make_replica_group_proto(replica_group):
  """Wraps an iterable of replica ids in a fresh `ReplicaGroup`."""
  group_proto = ReplicaGroup()
  group_proto.replica_ids.extend(replica_group)
  return group_proto


def make_replica_groups(replica_groups):
  """Converts an iterable of replica-id iterables into `ReplicaGroup` objects.

  Args:
    replica_groups: None, or an iterable whose items are iterables of ids.

  Returns:
    A list with one `ReplicaGroup` per item; an empty list when
    `replica_groups` is None.
  """
  if replica_groups is None:
    # An empty list is the special value the XLA API expects here.
    return []
  protos = []
  for group in list(replica_groups):
    protos.append(_make_replica_group_proto(group))
  return protos


Traceback = _xla.Traceback
Frame = _xla.Frame


@contextlib.contextmanager
def tracebacks(enabled=True):
  """Temporarily sets `Traceback.enabled`, restoring the old value on exit."""
  previous_setting = Traceback.enabled
  Traceback.enabled = enabled
  try:
    yield
  finally:
    Traceback.enabled = previous_setting


def heap_profile(client: Client) -> bytes:
  """Returns the client's heap profile (a pprof protocol buffer), gzipped."""
  raw_profile = client.heap_profile()
  return gzip.compress(raw_profile)


XlaRuntimeError = _xla.XlaRuntimeError

# Perform one last garbage collection of deferred Python references. This is
# mostly to keep ASAN happy.
atexit.register(_xla.collect_garbage)

weakref_lru_cache = _xla.weakref_lru_cache