1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ====================================== 15 16"""Library of TPU helper functions.""" 17 18import collections 19import enum 20import typing 21from typing import Any, Callable, Iterable, List, Optional, Text, Tuple, Union 22 23from absl import logging 24import numpy as np 25 26from tensorflow.compiler.tf2xla.python import xla as tf2xla 27from tensorflow.core.framework import attr_value_pb2 28from tensorflow.core.protobuf.tpu import dynamic_padding_pb2 as dynamic_padding 29from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 as embedding_pb2 30from tensorflow.python import tf2 31from tensorflow.python.compiler.xla import xla 32from tensorflow.python.distribute import device_util 33from tensorflow.python.distribute import distribution_strategy_context 34from tensorflow.python.framework import auto_control_deps 35from tensorflow.python.framework import c_api_util 36from tensorflow.python.framework import composite_tensor 37from tensorflow.python.framework import config 38from tensorflow.python.framework import constant_op 39from tensorflow.python.framework import device as pydev 40from tensorflow.python.framework import dtypes 41from tensorflow.python.framework import errors 42from tensorflow.python.framework import func_graph 43from tensorflow.python.framework import function 44from tensorflow.python.framework import ops 45from tensorflow.python.framework import tensor_shape 46from tensorflow.python.ops import array_ops 47from tensorflow.python.ops import control_flow_ops 48from tensorflow.python.ops import math_ops 49from tensorflow.python.ops import variable_scope 50from tensorflow.python.ops import variables 51from tensorflow.python.tpu import device_assignment as device_assignment_lib 52from tensorflow.python.tpu import tpu_feed 53from tensorflow.python.tpu import tpu_function 54from tensorflow.python.tpu import tpu_name_util 55from tensorflow.python.tpu.ops import tpu_ops 56from tensorflow.python.types import core as core_types 57from tensorflow.python.util import compat 58from tensorflow.python.util import nest 59from tensorflow.python.util import object_identity 60from tensorflow.python.util import traceback_utils 61from tensorflow.python.util import variable_utils 62from tensorflow.python.util.tf_export import tf_export 63 64 65ops.NotDifferentiable("TPUReplicatedInput") 66 67# Operations that indicate some error in the users graph, e.g. a placeholder 68# that's introduced outside of the infeed. 69_DENYLISTED_OPS = set([ 70 "Placeholder", 71]) 72 73# XLA doesn't currently support reading of intermediate tensors, thus some ops 74# are not supported. 75_UNSUPPORTED_OPS = set([ 76 "AudioSummary", 77 "AudioSummaryV2", 78 "HistogramSummary", 79 "ImageSummary", 80 "MergeSummary", 81 "Print", 82 "ScalarSummary", 83 "TensorSummary", 84 "TensorSummaryV2", 85 ]) 86 87# Ops which can be safely pruned from XLA compile if they have no consumers. 88# These ops should also have no inputs. 89_UNCONNECTED_OPS_TO_PRUNE = set(["Placeholder", "VarHandleOp"]) 90 91_MAX_WARNING_LINES = 5 92 93_TPU_REPLICATE_ATTR = "_tpu_replicate" 94_POST_DEVICE_REWRITE_ATTR = "_post_device_rewrite" 95_TPU_COMPILATION_STATUS_ATTR = "_tpu_compilation_status" 96_OUTSIDE_COMPILATION_ATTR = "_xla_outside_compilation" 97_PIVOT_FOR_CLUSTER = "_pivot_for_cluster" 98 99 100core = tpu_name_util.core 101 102 103def _tpu_system_device_name(job: Optional[Text]) -> Text: 104 """Returns the device name for the TPU_SYSTEM device of `job`.""" 105 if job is None: 106 return "/device:TPU_SYSTEM:0" 107 else: 108 return "/job:%s/device:TPU_SYSTEM:0" % job 109 110 111@tf_export(v1=["tpu.initialize_system"]) 112def initialize_system( 113 embedding_config: Optional[embedding_pb2.TPUEmbeddingConfiguration] = None, 114 job: Optional[Text] = None, 115 compilation_failure_closes_chips: bool = True, 116 tpu_cancellation_closes_chips: Optional[bool] = None, 117) -> core_types.Tensor: 118 """Initializes a distributed TPU system for use with TensorFlow. 119 120 Args: 121 embedding_config: If not None, a `TPUEmbeddingConfiguration` proto 122 describing the desired configuration of the hardware embedding lookup 123 tables. If embedding_config is None, no hardware embeddings can be used. 124 job: The job (the XXX in TensorFlow device specification /job:XXX) that 125 contains the TPU devices that will be initialized. If job=None it is 126 assumed there is only one job in the TensorFlow flock, and an error will 127 be returned if this assumption does not hold. 128 compilation_failure_closes_chips: Set the configuration whether 129 we want to close TPU chips when there is a compilation failure. 130 tpu_cancellation_closes_chips: Set the configuration whether 131 we want to close TPU chips when a TPU execution is cancelled. If the value 132 is None, the behavior will be determined by the command line flag 133 `tpu_cancellation_closes_chips` for the TPU worker. WARNING: this argument 134 only applies to TFRT TPU runtime. 135 Returns: 136 A serialized `TopologyProto` that describes the TPU system. Note: 137 the topology must be evaluated using `Session.run` before it can be used. 138 """ 139 config_string = ("" if embedding_config is None else 140 embedding_config.SerializeToString()) 141 142 # The enum is defined in core/tpu/kernels/tpu_execute_op_options.h. 143 tpu_cancellation_closes_chips_enum = 0 144 if tpu_cancellation_closes_chips is not None: 145 if tpu_cancellation_closes_chips: 146 tpu_cancellation_closes_chips_enum = 1 147 else: 148 tpu_cancellation_closes_chips_enum = 2 149 150 with ops.device(_tpu_system_device_name(job)): 151 topology = tpu_ops.configure_distributed_tpu( 152 compilation_failure_closes_chips=compilation_failure_closes_chips, 153 tpu_cancellation_closes_chips=tpu_cancellation_closes_chips_enum, 154 ) 155 156 if embedding_config is None: 157 return topology 158 159 # This set of control dependencies is needed as this function is expected to 160 # return an op which will return the topology when executed, but we need to 161 # call the embedding initialization op between initializing the TPU and 162 # returning the topology. 163 with ops.control_dependencies([topology]): 164 embedding_init = tpu_ops.configure_tpu_embedding(config=config_string) 165 with ops.control_dependencies([embedding_init]): 166 return array_ops.identity(topology, name="tpu_init_identity") 167 168 169def initialize_system_for_tpu_embedding( 170 embedding_config: embedding_pb2.TPUEmbeddingConfiguration, 171 job: Optional[Text] = None, 172) -> ops.Operation: 173 """Initializes a distributed TPU Embedding system for use with TensorFlow. 174 175 The following two are equivalent: 176 1. initialize_system() with embedding_config. 177 2. initialize_system() without embedding_config, then 178 initialize_system_for_tpu_embedding(). 179 initialize_system() should not be called with embedding_config if 180 initialize_system_for_tpu_embedding() is meant to be called later. 181 182 Args: 183 embedding_config: a `TPUEmbeddingConfiguration` proto describing the desired 184 configuration of the hardware embedding lookup tables. 185 job: The job (the XXX in TensorFlow device specification /job:XXX) that 186 contains the TPU devices that will be initialized. If job=None it is 187 assumed there is only one job in the TensorFlow flock, and an error will 188 be returned if this assumption does not hold. 189 190 Returns: 191 A no-op. 192 """ 193 config_string = embedding_config.SerializeToString() 194 with ops.device(_tpu_system_device_name(job)): 195 return tpu_ops.configure_tpu_embedding(config=config_string) 196 197 198@tf_export(v1=["tpu.shutdown_system"]) 199def shutdown_system(job: Optional[Text] = None) -> ops.Operation: 200 """Shuts down a running a distributed TPU system. 201 202 Args: 203 job: The job (the XXX in TensorFlow device specification /job:XXX) that 204 contains the TPU devices that will be shutdown. If job=None it is 205 assumed there is only one job in the TensorFlow flock, and an error will 206 be returned if this assumption does not hold. 207 """ 208 with ops.device(_tpu_system_device_name(job)): 209 shutdown_distributed_tpu = tpu_ops.shutdown_distributed_tpu() 210 return shutdown_distributed_tpu 211 212 213def _enclosing_tpu_context_and_graph() -> Tuple[Any, Any]: 214 """Returns the TPUReplicateContext and its associated graph.""" 215 graph = ops.get_default_graph() 216 while graph is not None: 217 # pylint: disable=protected-access 218 context_ = graph._get_control_flow_context() 219 # pylint: enable=protected-access 220 while context_ is not None: 221 if isinstance(context_, TPUReplicateContext): 222 return context_, graph 223 context_ = context_.outer_context 224 graph = getattr(graph, "outer_graph", None) 225 raise ValueError("get_replicated_var_handle() called without " 226 "TPUReplicateContext. This shouldn't happen. Please file " 227 "a bug.") 228 229 230def is_tpu_strategy(strategy: Any) -> bool: 231 is_tpu_strat = lambda k: k.__name__.startswith("TPUStrategy") 232 clz = strategy.__class__ 233 return is_tpu_strat(clz) or any(map(is_tpu_strat, clz.__bases__)) 234 235 236def _enclosing_tpu_device_assignment( 237) -> Optional[device_assignment_lib.DeviceAssignment]: 238 if not distribution_strategy_context.has_strategy(): 239 return None 240 strategy = distribution_strategy_context.get_strategy() 241 if not is_tpu_strategy(strategy): 242 return None 243 return strategy.extended._device_assignment # pylint: disable=protected-access 244 245 246@auto_control_deps.register_acd_resource_resolver 247def tpu_replicated_input_resolver( 248 op: ops.Operation, 249 resource_reads: object_identity.ObjectIdentitySet, 250 resource_writes: object_identity.ObjectIdentitySet) -> bool: 251 """Replaces TPUReplicatedInput outputs with its inputs in resource_inputs.""" 252 # Ignore TPUReplicatedInput for ACD purposes since we will be directly adding 253 # control deps on the replicated inputs. 254 if op.type == "TPUReplicatedInput": 255 if resource_reads or resource_writes: 256 resource_reads.clear() 257 resource_writes.clear() 258 return True 259 else: 260 return False 261 # Replace tensors in `resource_inputs` which are outputs of TPUReplicatedInput 262 # with the actual replicated inputs. This allows ACD to correct add control 263 # deps when there are multiple calls to `run` in a 264 # `tf.function`. 265 def replace_with_unreplicated_resources(resource_inputs): 266 """Replaces handles in `resource_inputs` with their unreplicated inputs.""" 267 to_remove = [] 268 to_add = [] 269 for resource in resource_inputs: 270 if resource.op.type == "TPUReplicatedInput": 271 to_remove.append(resource) 272 to_add.extend(resource.op.inputs) 273 for t in to_remove: 274 resource_inputs.discard(t) 275 resource_inputs.update(to_add) 276 return to_add or to_remove 277 278 return bool(replace_with_unreplicated_resources(resource_reads) or 279 replace_with_unreplicated_resources(resource_writes)) 280 281 282class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): 283 """A `ControlFlowContext` for nodes inside a TPU computation. 284 285 The primary role of `TPUReplicateContext` is to mark operators inside a 286 tpu.replicate() computation with the attribute "_tpu_replicate=XYZ", where XYZ 287 is a unique name. 288 289 We use a `ControlFlowContext` to perform the annotation since it integrates 290 with Tensorflow constructs like ResourceVariables. For example, if a 291 `ResourceVariable` is constructed inside a tpu.replicate() block, the 292 `ResourceVariable` implementation can use 293 `with ops.control_dependencies(None)` to build the variable's definition 294 outside the replicated computation. 295 """ 296 297 def __init__(self, name: Text, num_replicas: int, pivot: ops.Operation): 298 """Builds a new TPUReplicateContext. 299 300 Args: 301 name: a unique name for the context, used to populate the `_tpu_replicate` 302 attribute. 303 num_replicas: an integer that gives the number of replicas for the 304 computation. 305 pivot: a pivot node. Nodes in the TPUReplicateContext that do not have any 306 inputs will have a control dependency on the pivot node. This ensures 307 that nodes are correctly included in any enclosing control flow 308 contexts. 309 """ 310 super(TPUReplicateContext, self).__init__() 311 self._num_replicas = num_replicas 312 self._outer_device_function_stack = None 313 self._oc_dev_fn_stack = None 314 self._outside_compilation_cluster = None 315 self._outside_compilation_v2_context = None 316 self._outside_compilation_counter = 0 317 self._in_gradient_colocation = None 318 self._gradient_colocation_stack = [] 319 self._host_compute_core = [] 320 self._name = name 321 self._name_as_bytes = compat.as_bytes(name) 322 self._tpu_relicate_attr_buf = c_api_util.ScopedTFBuffer( 323 attr_value_pb2.AttrValue(s=self._name_as_bytes).SerializeToString()) 324 self._unsupported_ops = [] 325 self._pivot = pivot 326 self._replicated_vars = {} 327 328 def get_replicated_var_handle(self, 329 name: Text, 330 handle_id: Text, 331 vars_: Union[List[core_types.Tensor], 332 List[variables.Variable]], 333 is_mirrored: bool = False, 334 is_packed: bool = False) -> core_types.Tensor: 335 """Returns a variable handle for replicated TPU variable 'var'. 336 337 This is a method used by an experimental replicated variable implementation 338 and is not intended as a public API. 339 340 Args: 341 name: The common name of the variable. 342 handle_id: Unique ID of the variable handle, used as the cache key. 343 vars_: The replicated TPU variables or handles. 344 is_mirrored: Whether the variables are mirrored, which guarantees the 345 values in each replica are always the same. 346 is_packed: Whether the replicated variables are packed into one variable. 347 348 Returns: 349 The handle of the TPU replicated input node. 350 """ 351 device_assignment = _enclosing_tpu_device_assignment() 352 # We don't need to put device assignment as part of the replicated_vars key 353 # because each TPUReplicateContext will only have one device assignment. 354 handle = self._replicated_vars.get(handle_id) 355 if handle is not None: 356 return handle 357 358 if device_assignment is not None and not is_packed: 359 # Find a variable copy for each replica in the device assignment. 360 # Note that the order of devices for replicas for the variable and the 361 # device assignment might not match. 362 job_name = pydev.DeviceSpec.from_string(vars_[0].device).job 363 devices_to_vars = {device_util.canonicalize(v.device): v for v in vars_} 364 replicated_vars = [] 365 for replica_id in range(device_assignment.num_replicas): 366 for logical_core in range(device_assignment.num_cores_per_replica): 367 device = device_util.canonicalize( 368 device_assignment.tpu_device( 369 replica=replica_id, logical_core=logical_core, job=job_name)) 370 if device in devices_to_vars: 371 replicated_vars.append(devices_to_vars[device]) 372 break 373 else: 374 raise ValueError( 375 "Failed to find a variable on any device in replica {} for " 376 "current device assignment".format(replica_id)) 377 else: 378 replicated_vars = vars_ 379 380 # Builds a TPUReplicatedInput node for the variable, if one does not already 381 # exist. The TPUReplicatedInput node must belong to the enclosing 382 # control-flow scope of the TPUReplicateContext. 383 # TODO(phawkins): consider changing the contract of the TPU encapsulation 384 # so the TPUReplicatedInput nodes go inside the TPUReplicateContext scope 385 # instead. 386 387 _, graph = _enclosing_tpu_context_and_graph() 388 with graph.as_default(): 389 # If replicated_vars are variables, get the handles. Note that this can be 390 # done inside TPUReplicateContext because replicated_vars.handle may 391 # create new ops. 392 if isinstance(replicated_vars[0], variables.Variable): 393 replicated_vars = [v.handle for v in replicated_vars] 394 # pylint: disable=protected-access 395 saved_context = graph._get_control_flow_context() 396 graph._set_control_flow_context(self.outer_context) 397 handle = tpu_ops.tpu_replicated_input(replicated_vars, 398 name=name + "/handle", 399 is_mirrored_variable=is_mirrored, 400 is_packed=is_packed) 401 graph._set_control_flow_context(saved_context) 402 # pylint: enable=protected-access 403 self._replicated_vars[handle_id] = handle 404 return handle 405 406 def report_unsupported_operations(self) -> None: 407 if self._unsupported_ops: 408 op_str = "\n".join(" %s (%s)" % (op.type, op.name) 409 for op in self._unsupported_ops[:_MAX_WARNING_LINES]) 410 logging.warning("%d unsupported operations found: \n%s", 411 len(self._unsupported_ops), op_str) 412 if len(self._unsupported_ops) > _MAX_WARNING_LINES: 413 logging.warning("... and %d more" % 414 (len(self._unsupported_ops) - _MAX_WARNING_LINES)) 415 416 def EnterGradientColocation(self, op: ops.Operation, gradient_uid: Text): 417 if op is not None: 418 if ops.get_default_graph()._control_flow_context is None: # pylint: disable=protected-access 419 # If we are in TF 2 functions (control flow V2 functions, or 420 # tf.function()), we need to attach _xla_outside_compilation attribute 421 # directly because we are not in TPUReplicateContext. 422 try: 423 outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR).decode("ascii") 424 except ValueError: 425 # The attr was not present: do nothing. 426 return 427 parts = outside_attr.split(".") 428 cluster = parts[0] + "." + gradient_uid 429 self._outside_compilation_v2_context = OutsideCompilationV2Context( 430 cluster) 431 self._outside_compilation_v2_context.Enter() 432 return 433 self._gradient_colocation_stack.append(op) 434 if not self._outside_compilation_cluster: 435 try: 436 outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR).decode("ascii") 437 if self._in_gradient_colocation: 438 raise NotImplementedError( 439 "Cannot nest gradient colocation operations outside compilation" 440 ) 441 if gradient_uid == "__unsupported__": 442 raise NotImplementedError( 443 "No gradient_uid calling gradient within outside_compilation") 444 # When we take the gradient of an op X in an outside_compilation 445 # cluster C in a forward computation we would like to put the ops 446 # corresponding to the gradient of X into a new outside_compilation 447 # cluster C'. However, if we take the gradient of X twice, the second 448 # one should get yet another new outside_compilation cluster C''. 449 # 450 # The mechanism we adopt is to use a 'root_cluster' which is the 451 # cluster that X was in before we took gradients, and a 'gradient_uid' 452 # which is different for every invocation of gradients, and put the 453 # gradient of X in cluster 'root_cluster.gradient_uid'. 454 # 455 # When taking a gradient of a gradient, some ops will be colocated 456 # with Op in the forward pass (e.g., cluster root_cluster) and some in 457 # the backward pass (e.g., cluster root_cluster.initial_gradient_uid). 458 # We need all of the grad-of-grad ops to be in the same cluster to 459 # avoid cyclic dependencies between clusters. We adopt a heuristic 460 # that puts any op clustered with root_cluster.<xxx> in 461 # root_cluster.gradient_uid, even if xxx was initial_gradient_uid. 462 self._in_gradient_colocation = op 463 parts = outside_attr.split(".") 464 cluster = parts[0] + "." + gradient_uid 465 self._EnterOutsideCompilationScope(cluster=cluster) 466 except ValueError: 467 # The attr was not present: do nothing. 468 pass 469 470 def ExitGradientColocation(self, op: ops.Operation, gradient_uid: Text): 471 if op is not None: 472 if ops.get_default_graph()._control_flow_context is None: # pylint: disable=protected-access 473 # Inside a TF2 tf.function or control flow graph and `op` was not 474 # marked to be outside compiled. 475 assert self._outside_compilation_v2_context is None 476 return 477 if self._outside_compilation_v2_context is not None: 478 # Inside a TF2 tf.function or control flow graph and `op` was 479 # marked to be outside compiled. 480 self._outside_compilation_v2_context.Exit() 481 self._outside_compilation_v2_context = None 482 return 483 if not self._gradient_colocation_stack: 484 raise errors.InternalError( 485 op.node_def, op, 486 f"Badly nested gradient colocation: empty stack when popping Op {op.name}" 487 ) 488 last_op = self._gradient_colocation_stack.pop() 489 if op is last_op: 490 if op is self._in_gradient_colocation: 491 self._in_gradient_colocation = None 492 self._ExitOutsideCompilationScope() 493 else: 494 raise errors.InternalError( 495 op.node_def, op, 496 f"Badly nested gradient colocation, expected {last_op}, got {op.name}" 497 ) 498 499 def _EnterOutsideCompilationScope(self, cluster: Optional[Text] = None): 500 501 class FakeOp(object): 502 """A helper class to determine the current device. 503 504 Supports only the type and device set/get methods needed to run the 505 graph's _apply_device_function method. 506 """ 507 508 def __init__(self): 509 self._device = "" 510 511 @property 512 def type(self): 513 return "FakeOp" 514 515 @property 516 def device(self): 517 return self._device 518 519 def _set_device(self, device): 520 if isinstance(device, pydev.DeviceSpec): 521 self._device = device.to_string() 522 else: 523 self._device = device 524 525 def _set_device_from_string(self, device_str): 526 self._device = device_str 527 528 if self._outside_compilation_cluster: 529 raise NotImplementedError("Cannot nest outside_compilation clusters") 530 if cluster: 531 self._outside_compilation_cluster = cluster 532 else: 533 self._outside_compilation_cluster = str(self._outside_compilation_counter) 534 self._outside_compilation_counter += 1 535 graph = ops.get_default_graph() 536 fake_op = FakeOp() 537 graph._apply_device_functions(fake_op) # pylint: disable=protected-access 538 device = pydev.DeviceSpec.from_string(fake_op.device) 539 if (device.device_type == "TPU_REPLICATED_CORE" and 540 device.device_index is not None): 541 self._host_compute_core.append(self._outside_compilation_cluster + ":" + 542 str(device.device_index)) 543 self._oc_dev_fn_stack = graph._device_function_stack # pylint: disable=protected-access 544 graph._device_function_stack = self._outer_device_function_stack # pylint: disable=protected-access 545 546 def _ExitOutsideCompilationScope(self): 547 if not self._outside_compilation_cluster: 548 raise ValueError( 549 "Attempted to exit outside_compilation scope when not in scope") 550 self._outside_compilation_cluster = None 551 graph = ops.get_default_graph() 552 graph._device_function_stack = self._oc_dev_fn_stack # pylint: disable=protected-access 553 554 def Enter(self) -> None: 555 if not self._outer_device_function_stack: 556 # Capture the device function stack at the time of first entry 557 # since that is the stack that will be used outside_compilation. 558 graph = ops.get_default_graph() 559 # pylint: disable=protected-access 560 self._outer_device_function_stack = graph._device_function_stack.copy() 561 # pylint: enable=protected-access 562 super(TPUReplicateContext, self).Enter() 563 564 def HostComputeCore(self) -> List[Text]: 565 return self._host_compute_core 566 567 def _RemoveExternalControlEdges( 568 self, op: ops.Operation 569 ) -> Tuple[List[ops.Operation], List[ops.Operation]]: 570 """Remove any external control dependency on this op.""" 571 internal_control_inputs = [] 572 external_control_inputs = [] 573 for x in op.control_inputs: 574 # pylint: disable=protected-access 575 is_internal_op = False 576 ctxt = x._get_control_flow_context() 577 while ctxt is not None: 578 if ctxt == self: 579 is_internal_op = True 580 break 581 ctxt = ctxt._outer_context 582 if is_internal_op: 583 internal_control_inputs.append(x) 584 else: 585 external_control_inputs.append(x) 586 # pylint: enable=protected-access 587 # pylint: disable=protected-access 588 op._remove_all_control_inputs() 589 op._add_control_inputs(internal_control_inputs) 590 # pylint: enable=protected-access 591 return internal_control_inputs, external_control_inputs 592 593 def AddOp(self, op: ops.Operation) -> None: 594 # pylint: disable=protected-access 595 if op.type in _DENYLISTED_OPS: 596 logging.error("Operation of type %s (%s) is not supported on the TPU. " 597 "Execution will fail if this op is used in the graph. ", 598 op.type, op.name) 599 600 if op.type in _UNSUPPORTED_OPS: 601 self._unsupported_ops.append(op) 602 603 if any(x.dtype._is_ref_dtype for x in op.inputs): 604 raise NotImplementedError( 605 f"Non-resource Variables are not supported inside TPU computations " 606 f"(operator name: {op.name})") 607 608 # TensorFlowOpLayer may clone nodes that are in tpu.rewrite()s. It'll add 609 # the "_cloned" attribute and we should continue in that case. 610 if (_TPU_REPLICATE_ATTR in op.node_def.attr and 611 "_cloned" not in op.node_def.attr): 612 raise ValueError(f"TPU computations cannot be nested on op ({op})") 613 op._set_attr_with_buf(_TPU_REPLICATE_ATTR, 614 self._tpu_relicate_attr_buf.buffer) 615 if self._outside_compilation_cluster: 616 op._set_attr( 617 _OUTSIDE_COMPILATION_ATTR, 618 attr_value_pb2.AttrValue( 619 s=compat.as_bytes(self._outside_compilation_cluster))) 620 if self._num_replicas > 1 or not self._outside_compilation_cluster: 621 # Prevent feeding or fetching anything that is being compiled, 622 # and any replicated outside_compilation Op. 623 op.graph.prevent_feeding(op) 624 op.graph.prevent_fetching(op) 625 626 # Remove any control edges from outer control flow contexts. These may cause 627 # mismatched frame errors. 628 (internal_control_inputs, 629 external_control_inputs) = self._RemoveExternalControlEdges(op) 630 631 if not op.inputs: 632 # Add a control edge from the control pivot to this op. 633 if not internal_control_inputs: 634 # pylint: disable=protected-access 635 op._add_control_input(self.GetControlPivot()) 636 # pylint: enable=protected-access 637 else: 638 for index in range(len(op.inputs)): 639 x = op.inputs[index] 640 real_x = self.AddValue(x) 641 if real_x is not x: 642 op._update_input(index, real_x) # pylint: disable=protected-access 643 644 if external_control_inputs: 645 # Use an identity to pull control inputs as data inputs. Note that we 646 # ignore ops which don't have outputs. TODO(phawkins): fix that. 647 with ops.control_dependencies(None): 648 self.Enter() 649 external_control_inputs = [ 650 array_ops.identity(x.outputs[0]).op 651 for x in external_control_inputs 652 if x.outputs 653 ] 654 self.Exit() 655 # pylint: disable=protected-access 656 op._add_control_inputs(external_control_inputs) 657 # pylint: enable=protected-access 658 659 # Mark op's outputs as seen by this context and any outer contexts. 660 output_names = [x.name for x in op.outputs] 661 context = self 662 while context is not None: 663 # pylint: disable=protected-access 664 context._values.update(output_names) 665 context = context._outer_context 666 # pylint: enable=protected-access 667 668 if self._outer_context: 669 self._outer_context.AddInnerOp(op) 670 671 def AddValue(self, val: core_types.Tensor) -> core_types.Tensor: 672 """Add `val` to the current context and its outer context recursively.""" 673 if not self._outer_context: 674 return val 675 676 if val.name in self._values: 677 # Use the real value if it comes from outer context. 678 result = self._external_values.get(val.name) 679 return val if result is None else result 680 681 result = val 682 self._values.add(val.name) 683 if self._outer_context: 684 result = self._outer_context.AddValue(val) 685 self._values.add(result.name) 686 687 self._external_values[val.name] = result 688 689 return result 690 691 def AddInnerOp(self, op: ops.Operation): 692 self.AddOp(op) 693 if self._outer_context: 694 self._outer_context.AddInnerOp(op) 695 696 @property 697 def grad_state(self): 698 # Define the gradient loop state associated with the TPUReplicateContext to 699 # be None as the TPUReplicateContext does not get nested nor does the 700 # grad_state outside the TPUReplicateContext affect the graph inside so the 701 # grad_state should be as if this is the top-level gradient state. 702 return None 703 704 @property 705 def back_prop(self): 706 """Forwards to the enclosing while context, if any.""" 707 if self.GetWhileContext(): 708 return self.GetWhileContext().back_prop 709 return False 710 711 def GetControlPivot(self) -> ops.Operation: 712 return self._pivot 713 714 def RequiresUniqueFunctionRetracing(self): 715 # More context: b/158152827. TPU stack uses the TPUReplicateContext to 716 # create replicated variable handles and cluster TPU computations, thus we 717 # always retrace a tf.function when the wrapped TPUReplicateContext changes. 718 return True 719 720 721class OutsideCompilationV2Context(control_flow_ops.ControlFlowContext): 722 """The context for outside compilation in Tensorflow 2.0. 723 724 Every op added in this context will be assigned an _xla_outside_compilation 725 attribute. 726 """ 727 728 def __init__(self, name: Text): 729 control_flow_ops.ControlFlowContext.__init__(self) 730 self._name = name 731 732 def AddOp(self, op: ops.Operation) -> None: 733 if self._outer_context: 734 self._outer_context.AddOp(op) 735 # pylint: disable=protected-access 736 op._set_attr("_xla_outside_compilation", 737 attr_value_pb2.AttrValue(s=compat.as_bytes(self._name))) 738 # pylint: enable=protected-access 739 740 def AddInnerOp(self, op: ops.Operation) -> None: 741 if self._outer_context: 742 self._outer_context.AddInnerOp(op) 743 # pylint: disable=protected-access 744 op._set_attr("_xla_outside_compilation", 745 attr_value_pb2.AttrValue(s=compat.as_bytes(self._name))) 746 # pylint: enable=protected-access 747 748 def to_control_flow_context_def(self, context_def, export_scope=None): 749 raise NotImplementedError 750 751 752@tf_export(v1=["tpu.outside_compilation"]) 753def outside_compilation( 754 computation: Callable[..., Any], *args, **kwargs 755 ) -> Any: 756 """Builds part of a computation outside any current TPU replicate scope. 757 758 `tf.tpu.outside_compilation()` is used to run ops in `computation` on CPU 759 instead of running on TPU. For example, users can run ops that are not 760 supported on TPU's (e.g. tf.summary.write()) by explicitly placing those 761 ops on CPU's. Below usage of outside compilation will place ops in 762 `computation_with_string_ops` on CPU. 763 764 Example usage: 765 766 ```python 767 def computation_with_string_ops(x): 768 # strings types are not supported on TPU's and below ops must 769 # run on CPU instead. 770 output = tf.strings.format('1{}', x) 771 return tf.strings.to_number(output) 772 773 def tpu_computation(): 774 # Expected output is 11. 775 output = tf.tpu.outside_compilation(computation_with_string_ops, 1) 776 ``` 777 778 Outside compilation should be called inside TPUReplicateContext. That is, 779 `tf.tpu.outside_compilation()` should be called inside a function that is 780 passed to `tpu.split_compile_and_replicate()` -- this is implied when 781 outside compilation is invoked inside a function passed to TPUStrategy 782 `run()`. If invoked outside of TPUReplicateContext, 783 then this simply returns the result of `computation`, and therefore, 784 would be a no-op. Note that outside compilation is different from 785 `tf.distribute.experimental.TPUStrategy.merge_call()` as logic in 786 outside compilation is replicated and executed separately for each 787 replica. On the other hand, `merge_call()` requires a `merge_fn` 788 to aggregate the inputs from different replicas and is executed only 789 once. 790 791 For variables placed in TPU device, which includes variables created inside 792 TPUStrategy scope, outside compilation logic must not include variable 793 read/write. For variables placed on host, which is the case when variables 794 created via TPUEstimator, variable read/write is only allowed if the variable 795 is not accessed by any other ops in the TPU computation. Variable read/write 796 from outside compilation cluster is not visible from TPU computation and 797 vice versa. Therefore, if outside compilation logic contains such host 798 variables read/write ops and if the variables are accessed by TPU 799 computation as well, then this may lead to deadlock. 800 801 Internally, `tf.tpu.outside_compilation()` adds outside compilation 802 attributes to all ops in `computation`. During later graph pass, these 803 ops with outside compilation attribute is extracted out and replicated 804 into a host-side graph. Inputs to this extract host-side graph is sent 805 from TPU computation graph to host graph via a pair of XlaSendToHost and 806 XlaRecvFromHost ops. Note that using `tf.tpu.outside_compilation()` 807 may result in tensor transfer between TPU and CPU, leading to non-trivial 808 performance impact. 809 810 Args: 811 computation: A Python function that builds the computation to 812 place on the host. 813 *args: the positional arguments for the computation. 814 **kwargs: the keyword arguments for the computation. 815 816 Returns: 817 The Tensors returned by computation. 818 """ 819 args = [] if args is None else args 820 graph = ops.get_default_graph() 821 822 # If we are in TF 2 functions (control flow V2 functions, or tf.function()), 823 # we need to attach _xla_outside_compilation attribute directly because we are 824 # not in TPUReplicateContext. 825 if isinstance(graph, func_graph.FuncGraph): 826 try: 827 tpu_context, _ = _enclosing_tpu_context_and_graph() 828 except ValueError: 829 logging.warning( 830 "Outside compilation attempted outside TPUReplicateContext " 831 "scope. As no enclosing TPUReplicateContext can be found, " 832 "returning the result of `computation` as is.") 833 return computation(*args, **kwargs) 834 835 # pylint: disable=protected-access 836 outside_compilation_name = str(tpu_context._outside_compilation_counter) 837 tpu_context._outside_compilation_counter = ( 838 tpu_context._outside_compilation_counter + 1) 839 # pylint: enable=protected-access 840 841 outside_compilation_context = OutsideCompilationV2Context( 842 outside_compilation_name) 843 outside_compilation_context.Enter() 844 args = [] if args is None else args 845 retval = computation(*args, **kwargs) 846 outside_compilation_context.Exit() 847 return retval 848 849 # If we are in a TPUReplicateContext, signal that we are now 850 # outside_compilation 851 initial_context = graph._get_control_flow_context() # pylint: disable=protected-access 852 context = initial_context 853 while context: 854 if isinstance(context, TPUReplicateContext): 855 context._EnterOutsideCompilationScope() # pylint: disable=protected-access 856 context = context.outer_context 857 858 retval = computation(*args, **kwargs) 859 860 # If we are in a TPUReplicateContext, signal that we are no longer 861 # outside_compilation 862 final_context = graph._get_control_flow_context() # pylint: disable=protected-access 863 if initial_context is not final_context: 864 raise NotImplementedError( 865 "Control-flow context cannot be different at start and end of an " 866 "outside_compilation scope") 867 context = initial_context 868 while context: 869 if isinstance(context, TPUReplicateContext): 870 context._ExitOutsideCompilationScope() # pylint: disable=protected-access 871 context = context.outer_context 872 873 return retval 874 875 876@tf_export(v1=["tpu.PaddingSpec"]) 877class PaddingSpec(enum.IntEnum): 878 """Represents the type of padding policies for tpu.replicate.""" 879 # By default the policy is set to AUTO, the dynamic input shape dimension will 880 # be pad to maximum of all the replicas. 881 AUTO = 0 882 # Bucketize the dynamic input shape dimension into a power of 2. 883 POWER_OF_TWO = 1 884 885 886@tf_export("tpu.XLAOptions") 887class XLAOptions( 888 collections.namedtuple("XLAOptions", [ 889 "use_spmd_for_xla_partitioning", 890 "enable_xla_dynamic_padder", 891 ])): 892 """XLA compilation options. 893 894 Attributes: 895 use_spmd_for_xla_partitioning: Boolean. Whether to use XLA's SPMD 896 partitioner instead of MPMD partitioner when compiler partitioning is 897 requested. 898 enable_xla_dynamic_padder: Boolean. Whether to enable XLA dynamic padder 899 infrastructure to handle dynamic shapes inputs inside XLA. True by 900 default. Disabling this may cause correctness issues with dynamic shapes 901 inputs, as XLA will just assume the inputs are with padded shapes. However 902 users can optionally set it to False to improve device time if masking is 903 already handled in the user side. 904 """ 905 906 def __new__(cls, 907 use_spmd_for_xla_partitioning=True, 908 enable_xla_dynamic_padder=True): 909 return super(XLAOptions, cls).__new__(cls, use_spmd_for_xla_partitioning, 910 enable_xla_dynamic_padder) 911 912 913@tf_export(v1=["tpu.replicate"]) 914@traceback_utils.filter_traceback 915def replicate( 916 computation: Callable[..., Any], 917 inputs: Optional[List[List[core_types.Tensor]]] = None, 918 infeed_queue: Optional[tpu_feed.InfeedQueue] = None, 919 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None, 920 name: Optional[Text] = None, 921 maximum_shapes: Optional[Any] = None, 922 padding_spec: Optional[PaddingSpec] = None, 923 xla_options: Optional[XLAOptions] = None) -> List[Any]: 924 """Builds a graph operator that runs a replicated TPU computation. 925 926 Example for the basic usage that `inputs` has static shape: 927 928 ```python 929 930 def computation(x): 931 x = x + 1 932 return tf.math.reduce_mean(x) 933 934 x = tf.convert_to_tensor([1., 2., 3.]) 935 y = tf.convert_to_tensor([4., 5., 6.]) 936 tf.compat.v1.tpu.replicate(computation, inputs=[[x], [y]]) 937 ``` 938 939 If the `inputs` has dynamic shapes and you would like to automatically 940 bucketize the inputs to avoid XLA recompilation. See the advanced example 941 below: 942 943 ```python 944 945 def computation(x): 946 x = x + 1 947 return tf.math.reduce_mean(x) 948 949 # Assume input tensors in two replicas `x` and `y` both have dynamic shape 950 # ([None, 2]). 951 tf.compat.v1.tpu.replicate( 952 computation, 953 inputs=[x, y], 954 maximum_shapes=[tf.TensorShape([None, None])], 955 padding_spec=tf.compat.v1.tpu.PaddingSpec.POWER_OF_TWO) 956 ``` 957 958 Args: 959 computation: A Python function that builds the computation to replicate. 960 inputs: A list of lists of input tensors or `None` (equivalent to 961 `[[]]`), indexed by `[replica_num][input_num]`. All replicas must 962 have the same number of inputs. Each input can be a nested structure 963 containing values that are convertible to tensors. Note that passing an 964 N-dimension list of compatible values will result in a N-dimension list of 965 scalar tensors rather than a single Rank-N tensors. If you need different 966 behavior, convert part of inputs to tensors with `tf.convert_to_tensor`. 967 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple 968 of arguments as inputs to computation. 969 device_assignment: If not `None`, a `DeviceAssignment` describing the 970 mapping between logical cores in the computation with physical cores in 971 the TPU topology. Uses a default device assignment if `None`. The 972 `DeviceAssignment` may be omitted if each replica of the computation uses 973 only one core, and there is either only one replica, or the number of 974 replicas is equal to the number of cores in the TPU system. 975 name: (Deprecated) Does nothing. 976 maximum_shapes: A nested structure of tf.TensorShape representing the shape 977 to which the respective component of each input element in each replica 978 should be padded. Any unknown dimensions (e.g. 979 tf.compat.v1.Dimension(None) in a tf.TensorShape or -1 in a tensor-like 980 object) will be padded to the maximum size of that dimension over all 981 replicas. The structure of `maximum_shapes` needs to be the same as 982 `inputs[0]`. 983 padding_spec: An enum specified by `tpu.PaddingSpec`. This describes the 984 padding policy when the `inputs` to `tpu.replicate` is dynamic. 985 One usage is to enable automatic bucketizing on the inputs by setting the 986 value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce the 987 recompilation in the XLA side. 988 xla_options: An instance of `tpu.XLAOptions` which indicates the options 989 passed to XLA compiler. Use `None` for default options. 990 Returns: 991 A list of outputs, indexed by `[replica_num]` each output can be a nested 992 structure same as what computation() returns with a few exceptions. 993 994 Exceptions include: 995 1) None output: a NoOp would be returned which control-depends on 996 computation. 997 2) Single value output: A tuple containing the value would be returned. 998 3) Operation-only outputs: a NoOp would be returned which 999 control-depends on computation. 1000 TODO(b/121383831): Investigate into removing these special cases. 1001 1002 Raises: 1003 ValueError: If all replicas do not have equal numbers of input tensors. 1004 ValueError: If the number of inputs per replica does not match 1005 the number of formal parameters to `computation`. 1006 ValueError: If the static `inputs` dimensions don't match with the values 1007 given in `maximum_shapes`. 1008 ValueError: If the structure of inputs per replica does not match 1009 the structure of `maximum_shapes`. 1010 """ 1011 return split_compile_and_replicate( 1012 computation, 1013 inputs, 1014 infeed_queue, 1015 device_assignment, 1016 name, 1017 maximum_shapes=maximum_shapes, 1018 padding_spec=padding_spec, 1019 xla_options=xla_options)[1] 1020 1021 1022def _ceil_to_pow_of_n(x, n): 1023 """Ceil input `x` to power of `n`.""" 1024 x = math_ops.cast(x, dtypes.float32) 1025 lognx = math_ops.log(x) / math_ops.log(n * 1.0) 1026 lognx = math_ops.ceil(lognx) 1027 result = math_ops.pow(n * 1.0, lognx) 1028 result = math_ops.cast(result, dtypes.int32) 1029 return result 1030 1031 1032def _pad_all_input( 1033 inputs: Iterable[core_types.Tensor], 1034 padded_shapes: List[Optional[tensor_shape.TensorShape]], 1035 padding_spec: PaddingSpec 1036) -> Tuple[List[List[Any]], List[dynamic_padding.PaddingMap]]: 1037 """Pad all input tensors given padded_shapes. 1038 1039 The real shape tensors will be concatenated with the padded original inputs. 1040 1041 Args: 1042 inputs: The original inputs. 1043 padded_shapes: A list of padded shapes for each input. If an entry is None, 1044 no padding is performed. 1045 padding_spec: An enum specified by `tpu.PaddingSpec`. This describes the 1046 padding policy when the `inputs` to `tf.tpu.replicate` is dynamic. 1047 One usage is to enable automatic bucketizing on the inputs by setting the 1048 value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce the 1049 recompilation in the XLA side. 1050 1051 Returns: 1052 The padded inputs and a PaddingMap list which maps the padded input 1053 dimension to the real shape argument index. 1054 """ 1055 # maximum_static_shapes[idx][i] indicates the maximum static size of ith 1056 # dimension of the idx input among all the replicas. 1057 maximum_static_shapes = [] 1058 # need_padding[idx][i] indicates whether the ith dimension of the idx input 1059 # needs padding. 1060 need_padding = [] 1061 input_shape_tensors = [] 1062 for core_idx, inputs_per_core in enumerate(inputs): 1063 for idx, input_tensor in enumerate(inputs_per_core): 1064 input_shape = input_tensor.get_shape().as_list() 1065 if core_idx == 0: 1066 input_shape_tensors.append([]) 1067 maximum_static_shapes.append(input_shape) 1068 need_padding.append(np.full_like(input_shape, False, dtype=bool)) 1069 else: 1070 for i, s in enumerate(input_shape): 1071 if s is None or s != maximum_static_shapes[idx][i]: 1072 need_padding[idx][i] = True 1073 maximum_static_shapes[idx] = max(input_shape, 1074 maximum_static_shapes[idx]) 1075 1076 # Append _POST_DEVICE_REWRITE_ATTR attributes to the real shape ops. 1077 real_input_shape = array_ops.shape(input_tensor) 1078 real_input_shape.op._set_attr( # pylint: disable=protected-access 1079 _POST_DEVICE_REWRITE_ATTR, 1080 attr_value_pb2.AttrValue(b=True)) 1081 input_shape_tensors[idx].append(real_input_shape) 1082 1083 maximum_shapes = [] 1084 for shapes_per_input in input_shape_tensors: 1085 maximum_shapes.append( 1086 math_ops.reduce_max(array_ops.stack(shapes_per_input), axis=0)) 1087 1088 padded_inputs = [] 1089 real_shapes = [] 1090 padding_maps = [] 1091 for core_idx, inputs_per_core in enumerate(inputs): 1092 padded_inputs.append([]) 1093 real_shapes.append([]) 1094 real_shape_idx = len(inputs_per_core) - 1 1095 for idx, input_tensor in enumerate(inputs_per_core): 1096 input_shape_tensor = input_shape_tensors[idx][core_idx] 1097 input_shape = input_tensor.get_shape().as_list() 1098 padded_shape = padded_shapes[idx] 1099 1100 # If we have no padded_shape, then skip padding. 1101 if any(need_padding[idx]) and padded_shape is not None: 1102 for i, s in enumerate(input_shape): 1103 if need_padding[idx][i]: 1104 if core_idx == 0: 1105 real_shape_idx += 1 1106 padding_map = dynamic_padding.PaddingMap() 1107 padding_map.arg_index = idx 1108 padding_map.shape_index = i 1109 padding_map.padding_arg_index = real_shape_idx 1110 padding_maps.append(padding_map) 1111 real_shapes[core_idx].append( 1112 math_ops.cast(input_shape_tensor[i], dtypes.int32)) 1113 1114 paddings = [] 1115 for i, s in enumerate(padded_shape.dims): 1116 if need_padding[idx][i]: 1117 # The minimum padded dimension size is 2 as XLA doesn't support size 1118 # 1 dynamic size. 1119 minimum_dynamic_dim_size = 2 1120 if s.value is not None: 1121 # Pad to the given maximum value. 1122 max_dim_size = max(s.value, minimum_dynamic_dim_size) 1123 else: 1124 # If maximum value is not given, then pad to the maximum dimension 1125 # among all the cores. 1126 max_dim_size = math_ops.maximum(maximum_shapes[idx][i], 1127 minimum_dynamic_dim_size) 1128 if padding_spec == PaddingSpec.POWER_OF_TWO: 1129 max_dim_size = _ceil_to_pow_of_n(max_dim_size, 2) 1130 # Pad to the given maximum value. 1131 padding = [0, max_dim_size - input_shape_tensor[i]] 1132 else: 1133 padding = [0, 0] 1134 paddings.append(padding) 1135 1136 if input_tensor.get_shape().is_fully_defined(): 1137 # TODO(rxsang): This is a hack to make sure padded_input has dynamic 1138 # shapes, so any tf.size/tf.shape op performed on it won't be constant 1139 # folded. Do we have better ways to do it? 1140 padded_input = control_flow_ops.cond( 1141 array_ops.constant(True), 1142 lambda: array_ops.pad(input_tensor, paddings), # pylint: disable=cell-var-from-loop 1143 lambda: input_tensor) 1144 else: 1145 padded_input = array_ops.pad(input_tensor, paddings) 1146 1147 # Append _POST_DEVICE_REWRITE_ATTR attributes to all padded inputs. 1148 padded_input.op._set_attr( # pylint: disable=protected-access 1149 _POST_DEVICE_REWRITE_ATTR, 1150 attr_value_pb2.AttrValue(b=True)) 1151 1152 padded_inputs[core_idx].append(padded_input) 1153 else: 1154 padded_inputs[core_idx].append(input_tensor) 1155 1156 num_replicas = len(padded_inputs) 1157 for i in range(num_replicas): 1158 padded_inputs[i].extend(real_shapes[i]) 1159 1160 return padded_inputs, padding_maps 1161 1162 1163def _flatten_and_filter_composite(maybe_composite, non_composite_output, 1164 composite_output=None): 1165 """For an input, replaced the input by a tuple if the input is composite. 1166 1167 If `maybe_composite` is not composite, return the parameter 1168 `non_composite_output` otherwise return a tuple which consists of the value of 1169 the parameter `composite_output` the same number of times as there are 1170 components of the composite tensor. 1171 1172 This is useful for computing a mask when flattening nested data with 1173 `expand_composites=True`. For example 1174 1175 ```python 1176 nest.flatten(data, expand_composites=True) 1177 ``` 1178 1179 and 1180 1181 ```python 1182 nest.flatten(nest.map( 1183 data, lambda x: _flatten_and_filter_composite(x, False, True))) 1184 ``` 1185 1186 will have the same length and second will be True if the tensor in the first 1187 is derived from a expanding a composite tensor. 1188 1189 Args: 1190 maybe_composite: A value to test for being a composite tensor. 1191 non_composite_output: The value to return when `maybe_composite` is not a 1192 composite. 1193 composite_output: the value to fill the output tuple with if 1194 `maybe_composite` is a composite. 1195 1196 Returns: 1197 `non_composite_output` or a tuple with multiple copies of 1198 `composite_output`. 1199 """ 1200 1201 if isinstance(maybe_composite, composite_tensor.CompositeTensor): 1202 num_components = len(nest.flatten(maybe_composite, expand_composites=True)) 1203 return (composite_output,) * num_components 1204 return non_composite_output 1205 1206 1207def split_compile_and_replicate( 1208 computation: Callable[..., Any], 1209 inputs: Optional[List[List[core_types.Tensor]]] = None, 1210 infeed_queue: Optional[tpu_feed.InfeedQueue] = None, 1211 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None, 1212 name: Optional[Text] = None, 1213 use_tpu: bool = True, 1214 maximum_shapes: Optional[Any] = None, 1215 padding_spec: Optional[PaddingSpec] = None, 1216 xla_options: Optional[XLAOptions] = None, 1217) -> List[List[core_types.Tensor]]: 1218 """Builds graph operators that runs compilation and replicated computation. 1219 1220 This is a lower level interface than replicate that returns a separate compile 1221 and execute output tensor. In the generated graph the compile op feeds into 1222 the execute op and no additional compilation is incurred when running the 1223 compile op before the execute op. The compile op returns additional 1224 information about the compilation but does not return the compiled program. 1225 1226 Args: 1227 computation: A Python function that builds the computation to replicate. 1228 inputs: A list of lists of input tensors or `None` (equivalent to 1229 `[[]]`), indexed by `[replica_num][input_num]`. All replicas must 1230 have the same number of inputs. Each input can be a nested structure 1231 containing values that are convertible to tensors. Note that passing an 1232 N-dimension list of compatible values will result in a N-dimension list of 1233 scalar tensors rather than a single Rank-N tensors. If you need different 1234 behavior, convert part of inputs to tensors with `tf.convert_to_tensor`. 1235 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple 1236 of arguments as inputs to computation. 1237 device_assignment: If not `None`, a `DeviceAssignment` describing the 1238 mapping between logical cores in the computation with physical cores in 1239 the TPU topology. Uses a default device assignment if `None`. The 1240 `DeviceAssignment` may be omitted if each replica of the computation uses 1241 only one core, and there is either only one replica, or the number of 1242 replicas is equal to the number of cores in the TPU system. 1243 name: (Deprecated) Does nothing. 1244 use_tpu: When false, the input `computation` is executed on the XLA CPU/GPU 1245 backends. Currently, only supports a default placement (computation is 1246 placed on GPU if one is available, and on CPU if not). 1247 maximum_shapes: A nested structure of tf.TensorShape representing the shape 1248 to which the respective component of each input element in each replica 1249 should be padded. Any unknown dimensions (e.g. 1250 tf.compat.v1.Dimension(None) in a tf.TensorShape or -1 in a tensor-like 1251 object) will be padded to the maximum size of that dimension over all 1252 replicas. The structure of `maximum_shapes` needs to be the same as 1253 `inputs[0]`. 1254 padding_spec: An enum specified by `tf.tpu.PaddingSpec`. This describes the 1255 padding policy when the `inputs` to `tf.tpu.replicate` is dynamic. 1256 One usage is to enable automatic bucketizing on the inputs by setting the 1257 value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce the 1258 recompilation in the XLA side. 1259 xla_options: An instance of `tpu.XLAOptions` which indicates the options 1260 passed to XLA compiler. Use `None` for default options. 1261 1262 Returns: 1263 A list of lists with the first list corresponding to the compile op and the 1264 second a list of output tensors, indexed by `[replica_num][output_num]`. 1265 Raises: 1266 ValueError: If all replicas do not have equal numbers of input tensors. 1267 ValueError: If the number of inputs per replica does not match 1268 the number of formal parameters to `computation`. 1269 ValueError: If the static `inputs` dimensions don't match with the values 1270 given in `maximum_shapes`. 1271 ValueError: If the structure of inputs per replica does not match 1272 the structure of `maximum_shapes`. 1273 """ 1274 del name 1275 inputs = [[]] if inputs is None else inputs 1276 xla_options = xla_options or XLAOptions() 1277 1278 metadata_kwargs = {} 1279 if device_assignment is not None: 1280 # Turn the Numpy array into a flattened list so we can pass it as an 1281 # operator attribute. 1282 metadata_kwargs = { 1283 "topology": 1284 device_assignment.topology.serialized(), 1285 "device_assignment": 1286 device_assignment.core_assignment.flatten().tolist() 1287 } 1288 metadata_kwargs["num_cores_per_replica"] = ( 1289 device_assignment.num_cores_per_replica) 1290 1291 # This entry is used for enabling automatic outside compilation. 1292 metadata_kwargs["allow_soft_placement"] = config.get_soft_device_placement() 1293 if config.get_soft_device_placement(): 1294 logging.info("Automatic outside compilation is enabled. " 1295 "Ops without XLA kernels will be automatically " 1296 "placed on CPU.") 1297 1298 if not isinstance(inputs, list): 1299 raise TypeError("tpu.replicate() inputs must be a list of lists/tuples, " 1300 f"received {type(inputs)}") 1301 if any(not isinstance(inp, (list, tuple)) for inp in inputs): 1302 raise TypeError( 1303 "tpu.replicate() inputs must be a list of lists/tuples, " 1304 f"received types: {[type(inp) for inp in inputs]}") 1305 1306 num_replicas = len(inputs) 1307 1308 # No replicas? Nothing to do. 1309 if num_replicas == 0: 1310 return [] 1311 1312 # Checks all replicas have the same structure. 1313 for i in range(1, num_replicas): 1314 nest.assert_same_structure(inputs[0], inputs[i]) 1315 1316 # Explicitly read variables. 1317 inputs = variable_utils.convert_variables_to_tensors(inputs) 1318 # Flatten inputs. This structure may contain None values, which will be 1319 # handled later. 1320 flat_inputs_with_nones = [ 1321 nest.flatten(per_replica_input, expand_composites=True) 1322 for per_replica_input in inputs 1323 ] 1324 # Mask parallel to one replica's inputs with True for tensors coming from 1325 # composites. 1326 is_composite = nest.flatten(nest.map_structure( 1327 lambda x: _flatten_and_filter_composite(x, False, True), inputs[0])) 1328 1329 # Converts inputs to Tensors, replacing Nones with a placeholder 0 since 1330 # tpu_ops.tpu_replicated_input() can't handle non-Tensor values. 1331 flat_inputs = [] 1332 for inp in flat_inputs_with_nones: 1333 flat_inputs.append([ 1334 constant_op.constant(0) if x is None else ops.convert_to_tensor(x) 1335 for x in inp 1336 ]) 1337 1338 # Verifies that all replicas have matching numbers and types of inputs 1339 flat_input_types = [x.dtype for x in flat_inputs[0]] 1340 input_arity = len(inputs[0]) 1341 flat_input_arity = len(flat_input_types) 1342 for i in range(num_replicas): 1343 if len(inputs[i]) != input_arity: 1344 raise ValueError("Replicas must have the same number of inputs. " 1345 "Replica 0 had {} inputs, replica {} had {} " 1346 "inputs.".format(input_arity, i, len(inputs[i]))) 1347 1348 types = [x.dtype for x in flat_inputs[i]] 1349 if types != flat_input_types: 1350 raise ValueError("Replicas must have matching input types. Replica 0 had " 1351 "input types {}, replica {} had input types {}".format( 1352 flat_input_types, i, types)) 1353 1354 arg_error = xla.check_function_argument_count( 1355 computation, input_arity, infeed_queue) 1356 if arg_error is not None: 1357 if infeed_queue is None: 1358 raise TypeError( 1359 "Supplied computation cannot be called with the specified inputs. " 1360 f"You specified {input_arity} inputs: {[i.name for i in inputs[0]]}, " 1361 f"but the computation needs {arg_error}") 1362 else: 1363 raise TypeError( 1364 "Supplied computation cannot be called with the specified inputs. " 1365 f"You specified {input_arity} inputs: {[i.name for i in inputs[0]]} ", 1366 f"and {infeed_queue.number_of_tuple_elements} additional inputs " 1367 f"from infeed, but the computation needs {arg_error}") 1368 1369 dynamic_shape_inputs = False 1370 if maximum_shapes: 1371 if infeed_queue: 1372 raise ValueError( 1373 "Dynamic input shapes are not supported with infeed queues") 1374 1375 # Make sure maximum_shapes has the same structure as inputs. 1376 nest.assert_same_structure(inputs[0], maximum_shapes, check_types=False) 1377 1378 # Flatten padded shapes: 1379 # For composite tensor components, we don't want to pad them. For each 1380 # entry of maximum_shapes that corresponds to a composite tensor, replace it 1381 # by a tuple of Nones of the same length as the number of components of the 1382 # composite tensor. When we flatten a second time, this makes 1383 # flat_maximum_shapes have the same length as flat_inputs[i]. We can then 1384 # avoid padding these tensors. The assumption is that they will be used by 1385 # outside compilation or that the components are statically shaped and will 1386 # be used by tpu compatible ops. 1387 flat_maximum_shapes = nest.flatten( 1388 [_flatten_and_filter_composite(x, y) 1389 for x, y in zip(nest.flatten(inputs[0]), 1390 nest.flatten(maximum_shapes))]) 1391 flat_maximum_shapes = [ 1392 tensor_shape.TensorShape(s) if s is not None else None 1393 for s in flat_maximum_shapes 1394 ] 1395 nest.assert_same_structure(flat_inputs[0], flat_maximum_shapes, 1396 check_types=False) 1397 1398 unpadded_inputs = flat_inputs 1399 flat_inputs, padding_maps = _pad_all_input(unpadded_inputs, 1400 flat_maximum_shapes, 1401 padding_spec) 1402 if padding_maps: 1403 dynamic_shape_inputs = True 1404 logging.info("TPU has inputs with dynamic shapes: %s", unpadded_inputs[0]) 1405 1406 metadata_kwargs["step_marker_location"] = getattr( 1407 computation, "step_marker_location", "STEP_MARK_AT_ENTRY") 1408 metadata_kwargs["use_spmd_for_xla_partitioning"] = \ 1409 xla_options.use_spmd_for_xla_partitioning 1410 1411 graph = ops.get_default_graph() 1412 1413 # Fan-in: Builds a TPUReplicatedInput node for each input. 1414 flat_replicated_inputs = [] 1415 for i in range(0, len(flat_inputs[0])): 1416 replicas = [flat_inputs[replica][i] for replica in range(num_replicas)] 1417 flat_replicated_inputs.append( 1418 tpu_ops.tpu_replicated_input( 1419 replicas, name="input{}".format(i))) 1420 if isinstance(graph, func_graph.FuncGraph): 1421 # When we are in Tensorflow 2.0 function, 'graph' will be a FuncGraph 1422 # object. If both outside graph and this function have a TPU cluster, 1423 # they will have the same cluster name and it will cause problems (because 1424 # we lower functional ops in Tensorflow 2.0). Append function name to 1425 # 'cluster_name' to avoid cluster name collision. 1426 cluster_name = graph.unique_name("cluster_" + graph.name) 1427 else: 1428 cluster_name = graph.unique_name("cluster") 1429 pivot = control_flow_ops.no_op(name=cluster_name + "/pivot") 1430 pivot._set_attr(_PIVOT_FOR_CLUSTER, # pylint: disable=protected-access 1431 attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name))) 1432 context = TPUReplicateContext( 1433 name=cluster_name, num_replicas=num_replicas, pivot=pivot) 1434 try: 1435 context.Enter() 1436 1437 metadata = tpu_ops.tpu_replicate_metadata( 1438 num_replicas=num_replicas, use_tpu=use_tpu, **metadata_kwargs) 1439 1440 with tpu_function.tpu_shard_context( 1441 num_replicas), ops.control_dependencies([metadata]): 1442 1443 if dynamic_shape_inputs and xla_options.enable_xla_dynamic_padder: 1444 for padding_map in padding_maps: 1445 input_shape = flat_replicated_inputs[padding_map.arg_index].shape 1446 flat_replicated_inputs[ 1447 padding_map.arg_index] = tf2xla.set_dynamic_dimension_size( 1448 flat_replicated_inputs[padding_map.arg_index], 1449 padding_map.shape_index, 1450 flat_replicated_inputs[padding_map.padding_arg_index]) 1451 flat_replicated_inputs[padding_map.arg_index].set_shape(input_shape) 1452 1453 # Add identity ops so even unused inputs are "consumed" by the 1454 # computation. This is to avoid orphaned TPUReplicatedInput nodes. 1455 # TODO(phawkins): consider instead pruning unused TPUReplicatedInput 1456 # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs. 1457 flat_replicated_inputs = [ 1458 array_ops.identity(x, name="replicated_input_{}".format(i)) 1459 for i, x in enumerate(flat_replicated_inputs) 1460 ] 1461 for i, composite in zip(flat_replicated_inputs, is_composite): 1462 # pylint: disable=protected-access 1463 # Add an attribute to the identity node so that they could be removed in 1464 # encapsulate TPU computation pass if unused. However we don't remove 1465 # inputs when dynamic padding is enabled. 1466 # TODO(rxsang): Use other ways except argument index in padding_map so 1467 # outside compilation can work with dynamic padding correctly. 1468 if not dynamic_shape_inputs or composite: 1469 i.op._set_attr("_tpu_input_identity", 1470 attr_value_pb2.AttrValue(b=True)) 1471 # pylint: enable=protected-access 1472 1473 # Clobber replicated placeholders with Nones. 1474 computation_inputs = [ 1475 None if inp is None else replicated for replicated, inp in zip( 1476 flat_replicated_inputs, flat_inputs_with_nones[0]) 1477 ] 1478 1479 # Unflatten the computation inputs to match original input structure. 1480 computation_inputs = nest.pack_sequence_as( 1481 structure=inputs[0], 1482 flat_sequence=computation_inputs[:flat_input_arity], 1483 expand_composites=True) 1484 1485 # If there is an infeed queue, adds the dequeued values to the 1486 # computation's inputs. 1487 if infeed_queue is not None: 1488 infeed_queue.set_number_of_shards(num_replicas) 1489 for t in infeed_queue.generate_dequeue_op(): 1490 computation_inputs.append(t) 1491 1492 # Only resource variables work inside a TPU computation, so turn on 1493 # resource variables for the computation. 1494 # TODO(phawkins): consider removing this code. It will 1495 # be less confusing to clients if they knowingly choose to use resource 1496 # variables. 1497 # Partitioned variables is not supported (b/112311320). 1498 vscope = variable_scope.get_variable_scope() 1499 saved_use_resource = vscope.use_resource 1500 saved_custom_getter = vscope.custom_getter 1501 1502 def custom_getter(getter, name, *args, **kwargs): 1503 """Variables on TPU have a few restrictions.""" 1504 partitioner = kwargs.get("partitioner", None) 1505 if partitioner is not None: 1506 kwargs["partitioner"] = None 1507 logging.warning( 1508 "Partitioned variables are not supported on TPU. Got " 1509 "`partitioner` that is %s for variable %s. " 1510 "Setting `partitioner` to `None`.", partitioner, name) 1511 if saved_custom_getter is None: 1512 return getter(name, *args, **kwargs) 1513 else: 1514 return saved_custom_getter(getter, name, *args, **kwargs) 1515 1516 vscope.set_use_resource(True) 1517 vscope.set_custom_getter(custom_getter) 1518 1519 outputs = computation(*computation_inputs) 1520 1521 vscope.set_use_resource(saved_use_resource) 1522 vscope.set_custom_getter(saved_custom_getter) 1523 1524 outputs = variable_utils.convert_variables_to_tensors(outputs) 1525 1526 need_spmd_partitioning = ( 1527 xla_options.use_spmd_for_xla_partitioning and 1528 device_assignment is not None and 1529 device_assignment.num_cores_per_replica > 1) 1530 outputs_is_flat = xla.is_flat(outputs) 1531 if outputs_is_flat: 1532 output_tensors, control_deps, pack_template = _postprocess_flat_outputs( 1533 outputs, need_spmd_partitioning) 1534 else: 1535 output_tensors, control_deps, pack_template = ( 1536 _postprocess_non_flat_outputs(outputs, need_spmd_partitioning)) 1537 1538 # tensor_tracer imports tpu.py. Local import to tensor_tracer to avoid 1539 # import-cycle 1540 if typing.TYPE_CHECKING: 1541 tensor_tracer = Any 1542 else: 1543 # pylint: disable=g-import-not-at-top 1544 from tensorflow.python.tpu import tensor_tracer 1545 # pylint: enable=g-import-not-at-top 1546 if tensor_tracer.TensorTracer.is_enabled(): 1547 if tf2.enabled(): 1548 logging.warn("TF API ver >= 2.0 detected. " 1549 "Tensor Tracer v1 is not enabled.") 1550 else: 1551 tt = tensor_tracer.TensorTracer() 1552 output_tensors = tt.trace_tpu(ops.get_default_graph(), 1553 output_tensors, control_deps, 1554 num_replicas) 1555 1556 context.ExitResult(output_tensors) 1557 finally: 1558 context.report_unsupported_operations() 1559 context.Exit() 1560 host_compute_core = context.HostComputeCore() 1561 1562 if host_compute_core: 1563 attr_value = attr_value_pb2.AttrValue() 1564 attr_value.list.s.extend(compat.as_bytes(x) for x in host_compute_core) 1565 metadata._set_attr("host_compute_core", attr_value) # pylint: disable=protected-access 1566 1567 with ops.control_dependencies([metadata]): 1568 if use_tpu: 1569 compile_status = tpu_ops.tpu_compilation_result() 1570 op = compile_status.op 1571 attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name)) 1572 op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value) # pylint: disable=protected-access 1573 else: 1574 compile_status = control_flow_ops.no_op(name="compilation_status") 1575 1576 if not output_tensors: 1577 # Returns a list of NoOps dependent on the replication Op, indexed by 1578 # [replica_num]. 1579 return [ 1580 compile_status, 1581 [ 1582 control_flow_ops.group(control_deps, name="shard_%d" % i) 1583 for i in range(num_replicas) 1584 ] 1585 ] 1586 1587 # Fan-out: Builds a TPUReplicatedOutput node for each output. 1588 replicated_outputs = [[] for i in range(num_replicas)] 1589 for i, t in enumerate(output_tensors): 1590 1591 # None values returned by the computation can't be sent to 1592 # tpu_ops.tpu_replicated_output(), we handle them specially here. We can 1593 # avoid the placeholder 0 routine required on the inputs since outputs are 1594 # replicated per-tensor, not per-replica, so we can skip replication. 1595 if t is None: 1596 for replica in range(num_replicas): 1597 replicated_outputs[replica].append(None) 1598 continue 1599 1600 # Fan-out: Builds a TPUReplicatedOutput node for each output. 1601 ys = tpu_ops.tpu_replicated_output( 1602 t, num_replicas, name="output{}".format(i)) 1603 1604 # Wraps the outputs in identity operators so the names of any possible 1605 # `fetch` nodes are preserved by the replication rewrite. 1606 with ops.control_dependencies(control_deps): 1607 for replica in range(num_replicas): 1608 replicated_outputs[replica].append( 1609 array_ops.identity( 1610 ys[replica], name="output_%d_shard_%d" % (i, replica))) 1611 1612 replicated_outputs = [ 1613 nest.pack_sequence_as(pack_template, replica_outs, expand_composites=True) 1614 for replica_outs in replicated_outputs 1615 ] 1616 1617 return [compile_status, replicated_outputs] 1618 1619 1620def _postprocess_flat_outputs( 1621 outputs: Any, 1622 need_spmd_partitioning: bool 1623) -> Tuple[List[Optional[core_types.Tensor]], List[ops.Operation], List[Any]]: 1624 """Validates non-flat outputs, add backs device assignments and other attrs. 1625 1626 Args: 1627 outputs: Output from `computation` inside `tpu.rewrite`. 1628 need_spmd_partitioning: Whether XLA SPMD partitioning is needed. 1629 1630 Returns: 1631 - Tensors extracted from outputs. 1632 - Operations extracted from outputs. 1633 - A pack template for use with nest.pack_sequence_as to pack the tensors. 1634 """ 1635 # Following code segment is to preserve legacy behavior. Previously we only 1636 # supported flat outputs and thus for consistency it was nice to convert even 1637 # single element into a tuple. But now that we support arbitrary output 1638 # structure, this is no longer necessary. 1639 # TODO(b/121383831): Migrate all legacy use cases and delete this special 1640 # case. 1641 # If the computation returns `None`, make it an empty tuple. 1642 if outputs is None: 1643 outputs = tuple() 1644 1645 # For legacy / backwards compatibility reasons we return a list for "flat" 1646 # output values (even if the user's flat return value was a different type or 1647 # even just a scalar value) so use nest.flatten to compute a flat list pack 1648 # template. 1649 pack_template = nest.flatten(outputs, expand_composites=False) 1650 1651 # Even though outputs is already "flat", we flatten any composites so their 1652 # component tensors can be tagged and replicated. The pack_template will be 1653 # used by the caller to repack the composite tensors. 1654 outputs = nest.flatten(outputs, expand_composites=True) 1655 1656 # Append `no_op` here so that fetching any return value of this function 1657 # will trigger TPUExecute node. 1658 outputs += (control_flow_ops.no_op(),) 1659 1660 maybe_convert = lambda x: None if x is None else ops.convert_to_tensor(x) 1661 try: 1662 if need_spmd_partitioning: 1663 outputs = [ 1664 o if isinstance(o, ops.Operation) else maybe_convert(o) 1665 for o in outputs 1666 ] 1667 else: 1668 with ops.device(core(0)): 1669 outputs = [ 1670 o if isinstance(o, ops.Operation) else maybe_convert(o) 1671 for o in outputs 1672 ] 1673 except Exception as e: 1674 raise ValueError( 1675 "TPU function return values must all either be Operations or " 1676 f"convertible to Tensors. Got error: {e}") 1677 1678 # Separates the returned Operations and Tensors. 1679 output_operations = [o for o in outputs if isinstance(o, ops.Operation)] 1680 output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)] 1681 1682 if outputs != output_tensors + output_operations: 1683 raise ValueError( 1684 "TPU functions must return zero-or more Tensor values followed by " 1685 "zero or more Operations.") 1686 1687 # Trim operations off the end of the pack template. output_operations has 1 1688 # extra element due to the no-op that is added. 1689 if len(output_operations) > 1: 1690 pack_template = pack_template[:1 - len(output_operations)] 1691 1692 # Wraps outputs in Identity ops. Otherwise a replicated input copied 1693 # straight to an output would bypass the replicate(). This would be bad 1694 # because the TPUReplicatedInput/TPUReplicatedOutput operator would not 1695 # be rewritten away, leading to a runtime error. 1696 # TODO(phawkins): extend the rewrite to elide these nodes instead. 1697 new_output_tensors = [] 1698 for t in output_tensors: 1699 if t is None: 1700 new_output_tensors.append(None) 1701 elif need_spmd_partitioning: 1702 o = array_ops.identity(t) 1703 # pylint: disable=protected-access 1704 o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True)) 1705 # pylint: enable=protected-access 1706 new_output_tensors.append(o) 1707 else: 1708 with ops.device(t.device if t.device else core(0)): 1709 o = array_ops.identity(t) 1710 # pylint: disable=protected-access 1711 o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True)) 1712 # pylint: enable=protected-access 1713 new_output_tensors.append(o) 1714 return new_output_tensors, output_operations, pack_template 1715 1716 1717def _postprocess_non_flat_outputs( 1718 outputs: Any, 1719 need_spmd_partitioning: bool 1720) -> Tuple[List[Optional[core_types.Tensor]], List[ops.Operation], List[Any]]: 1721 """Validates non-flat outputs, add backs device assignments and other attrs. 1722 1723 Args: 1724 outputs: Output from `computation` inside `tpu.rewrite`. 1725 need_spmd_partitioning: Whether XLA SPMD partitioning is needed. 1726 1727 Returns: 1728 - Tensors extracted from outputs. 1729 - An empty Operations list because Operations are not allowed in non-flat 1730 outputs. 1731 - A pack template for use with nest.pack_sequence_as to pack the tensors. 1732 """ 1733 1734 # Flatten output items. 1735 flat_outputs = nest.flatten(outputs, expand_composites=True) 1736 1737 # Convert all non-None non-Operation outputs to Tensors. 1738 for i, o in enumerate(flat_outputs): 1739 if o is None: 1740 flat_outputs[i] = None 1741 continue 1742 1743 if isinstance(o, ops.Operation): 1744 raise ValueError( 1745 "tpu.rewrite does not support Operation as return value in non-flat " 1746 "output structure. You can set returned Operations as control " 1747 "dependencies of returned Tensors so Operations are triggered when " 1748 f'Tensors are evaluated. Operation found: "{o.name}"') 1749 1750 try: 1751 o = ops.convert_to_tensor(o) 1752 except Exception as e: 1753 raise ValueError( 1754 "TPU function return values must all either be Operations or " 1755 f'convertible to Tensors. Got error: "{e}"') 1756 1757 # Wraps outputs in Identity ops. Otherwise a replicated input copied 1758 # straight to an output would bypass the replicate(). This would be bad 1759 # because the TPUReplicatedInput/TPUReplicatedOutput operator would not 1760 # be rewritten away, leading to a runtime error. 1761 # TODO(phawkins): extend the rewrite to elide these nodes instead. 1762 if need_spmd_partitioning: 1763 o = array_ops.identity(o) 1764 # pylint: disable=protected-access 1765 o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True)) 1766 # pylint: enable=protected-access 1767 flat_outputs[i] = array_ops.identity(o) 1768 else: 1769 with ops.device(o.device if o.device else core(0)): 1770 o = array_ops.identity(o) 1771 # pylint: disable=protected-access 1772 o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True)) 1773 # pylint: enable=protected-access 1774 flat_outputs[i] = array_ops.identity(o) 1775 1776 # All flat_outputs are Tensors, and no Operations. 1777 return flat_outputs, [], outputs 1778 1779 1780def split_compile_and_shard( 1781 computation: Callable[..., Any], 1782 inputs: Optional[List[List[Optional[core_types.Tensor]]]] = None, 1783 num_shards: int = 1, 1784 input_shard_axes: Optional[List[int]] = None, 1785 outputs_from_all_shards: Union[bool, List[bool]] = True, 1786 output_shard_axes: Optional[List[int]] = None, 1787 infeed_queue: Optional[tpu_feed.InfeedQueue] = None, 1788 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None, 1789 name: Optional[Text] = None, 1790 xla_options: Optional[XLAOptions] = None, 1791 ) -> Tuple[ops.Operation, List[core_types.Tensor]]: 1792 """Shards `computation` for parallel execution. 1793 1794 `inputs` must be a list of Tensors or None (equivalent to an empty list), each 1795 of which has a corresponding split axis (from `input_shard_axes`). Each input 1796 is split into `num_shards` pieces along the corresponding axis, and 1797 computation is applied to each shard in parallel. 1798 1799 Tensors are broadcast to all shards if they are lexically captured by 1800 `computation`. e.g., 1801 1802 x = tf.constant(7) 1803 def computation(): 1804 return x + 3 1805 ... = shard(computation, ...) 1806 1807 If `outputs_from_all_shards` is true, the outputs from all shards of 1808 `computation` are concatenated back together along their `output_shard_axes`. 1809 Otherwise, each output is taken from an arbitrary shard. 1810 1811 Inputs and outputs of the computation must be at least rank-1 Tensors. 1812 1813 Args: 1814 computation: A Python function that builds a computation to apply to each 1815 shard of the input. 1816 inputs: A list of input tensors or None (equivalent to an empty list). Each 1817 input tensor has a corresponding shard axes, given by `input_shard_axes`, 1818 which must have size divisible by `num_shards`. 1819 num_shards: The number of shards. 1820 input_shard_axes: A list of dimensions along which to shard `inputs`, or 1821 `None`. `None` means "shard all inputs along dimension 0". If not `None`, 1822 there must be one dimension per input. 1823 outputs_from_all_shards: Boolean or list of boolean. For each output, if 1824 `True`, outputs from all shards are concatenated along the corresponding 1825 `output_shard_axes` entry. Otherwise, each output is taken 1826 from an arbitrary shard. If the argument is a boolean, the argument's 1827 value is used for each output. 1828 output_shard_axes: A list of dimensions along which to concatenate the 1829 outputs of `computation`, or `None`. `None` means "concatenate all outputs 1830 along dimension 0". If not `None`, there must be one dimension per output. 1831 Ignored if `outputs_from_all_shards` is False. 1832 infeed_queue: If not `None`, the `InfeedQueue` to use to augment the inputs 1833 of `computation`. 1834 device_assignment: If not `None`, a `DeviceAssignment` describing the 1835 mapping between logical cores in the computation with physical cores in 1836 the TPU topology. Uses a default device assignment if `None`. The 1837 `DeviceAssignment` may be omitted if each shard of the computation uses 1838 only one core, and there is either only one shard, or the number of shards 1839 is equal to the number of cores in the TPU system. 1840 name: (Deprecated) Does nothing. 1841 xla_options: An instance of `tpu.XLAOptions` which indicates the options 1842 passed to XLA compiler. Use `None` for default options. 1843 Returns: 1844 A tuple of (compile op, [output tensors]). 1845 Raises: 1846 ValueError: If num_shards <= 0 1847 ValueError: If len(input_shard_axes) != len(inputs) 1848 ValueError: If len(output_shard_axes) != len(outputs from `computation`) 1849 """ 1850 # TODO(phawkins): consider adding support for broadcasting Tensors passed as 1851 # inputs. 1852 1853 if num_shards <= 0: 1854 raise ValueError( 1855 f"num_shards must be a positive integer. Received {num_shards}") 1856 1857 inputs = [] if inputs is None else inputs 1858 if not isinstance(inputs, list): 1859 raise TypeError("tpu.shard()'s inputs must be a list of Tensors or None. " 1860 f"Received {type(inputs)}") 1861 1862 # Converts inputs to Tensors. 1863 inputs = [ops.convert_to_tensor(x) for x in inputs] 1864 1865 if input_shard_axes is None: 1866 input_shard_axes = [0] * len(inputs) 1867 if len(inputs) != len(input_shard_axes): 1868 raise ValueError("Length of input_shard_axes must be equal to the number " 1869 f"of inputs. Received {len(inputs)} inputs and " 1870 f"{len(input_shard_axes)} input_shard_axes.") 1871 1872 if inputs: 1873 # Splits the `inputs` along the corresponding `input_shard_axes`, giving 1874 # lists with layout [input][shard] 1875 split_inputs = [ 1876 array_ops.split(x, num_shards, axis=axis) 1877 for (axis, x) in zip(input_shard_axes, inputs)] 1878 1879 # Transposes the input lists to have layout [shard][input] 1880 transposed_inputs = [list(i) for i in zip(*split_inputs)] 1881 else: 1882 transposed_inputs = [[]] * num_shards 1883 1884 compile_op, outputs = split_compile_and_replicate( 1885 computation, 1886 transposed_inputs, 1887 infeed_queue=infeed_queue, 1888 device_assignment=device_assignment, 1889 name=name, 1890 xla_options=xla_options) 1891 1892 # There must be at least one shard since num_shards > 0. 1893 # TODO(b/36647078) remove disable when pylint bug is fixed. 1894 # pylint: disable=indexing-exception 1895 if isinstance(outputs[0], ops.Operation): 1896 # pylint: enable=indexing-exception 1897 # There were no outputs from the computation and replicate returned a list 1898 # of NoOps with control dependencies on the computation. Return the first 1899 # one so it can be used as a control dependency or fetch node. 1900 # TODO(b/36647078) remove disable when pylint bug is fixed. 1901 # pylint: disable=indexing-exception 1902 return compile_op, [outputs[0]] 1903 # pylint: enable=indexing-exception 1904 1905 # TODO(b/36647078) remove disable when pylint bug is fixed. 1906 # pylint: disable=indexing-exception 1907 num_outputs = len(outputs[0]) 1908 # pylint: enable=indexing-exception 1909 1910 if output_shard_axes is None: 1911 output_shard_axes = [0] * num_outputs 1912 if num_outputs != len(output_shard_axes): 1913 raise ValueError("Length of output_shard_axes must be equal to the number " 1914 f"of outputs. Received {num_outputs} outputs " 1915 f"and {len(output_shard_axes)} output_shard_axes.") 1916 1917 if isinstance(outputs_from_all_shards, bool): 1918 outputs_from_all_shards = [outputs_from_all_shards] * num_outputs 1919 1920 if num_outputs != len(outputs_from_all_shards): 1921 raise ValueError( 1922 "Length of outputs_from_all_shards must be equal to the number of " 1923 f"outputs. Received {num_outputs} outputs and " 1924 f"{len(outputs_from_all_shards)} outputs_from_all_shards.") 1925 1926 results = [] 1927 for (axis, all_shards, x) in zip(output_shard_axes, outputs_from_all_shards, 1928 zip(*outputs)): 1929 if all_shards: 1930 # Concatenate all of the outputs together (use stack for scalars). 1931 shape = x[0].shape 1932 is_scalar = shape is not None and (shape.ndims == 0) 1933 results.append((array_ops.stack(list(x)) if is_scalar 1934 else array_ops.concat(list(x), axis=axis))) 1935 else: 1936 # TODO(phawkins): use a smarter policy, e.g., round-robin across shards. 1937 results.append(x[0]) 1938 1939 return compile_op, results 1940 1941 1942@tf_export(v1=["tpu.shard"]) 1943@traceback_utils.filter_traceback 1944def shard( 1945 computation: Callable[..., Any], 1946 inputs: Optional[List[core_types.Tensor]] = None, 1947 num_shards: int = 1, 1948 input_shard_axes: Optional[List[int]] = None, 1949 outputs_from_all_shards: Union[bool, List[bool]] = True, 1950 output_shard_axes: Optional[List[int]] = None, 1951 infeed_queue: Optional[tpu_feed.InfeedQueue] = None, 1952 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None, 1953 name: Optional[Text] = None, 1954 xla_options: Optional[XLAOptions] = None) -> List[core_types.Tensor]: 1955 """Shards `computation` for parallel execution. 1956 1957 `inputs` must be a list of Tensors or None (equivalent to an empty list), each 1958 of which has a corresponding split axis (from `input_shard_axes`). Each input 1959 is split into `num_shards` pieces along the corresponding axis, and 1960 computation is applied to each shard in parallel. 1961 1962 Tensors are broadcast to all shards if they are lexically captured by 1963 `computation`. e.g., 1964 1965 x = tf.constant(7) 1966 def computation(): 1967 return x + 3 1968 ... = shard(computation, ...) 1969 1970 TODO(phawkins): consider adding support for broadcasting Tensors passed 1971 as inputs. 1972 1973 If `outputs_from_all_shards` is true, the outputs from all shards of 1974 `computation` are concatenated back together along their `output_shard_axes`. 1975 Otherwise, each output is taken from an arbitrary shard. 1976 1977 Inputs and outputs of the computation must be at least rank-1 Tensors. 1978 1979 Args: 1980 computation: A Python function that builds a computation to apply to each 1981 shard of the input. 1982 inputs: A list of input tensors or None (equivalent to an empty list). Each 1983 input tensor has a corresponding shard axes, given by `input_shard_axes`, 1984 which must have size divisible by `num_shards`. 1985 num_shards: The number of shards. 1986 input_shard_axes: A list of dimensions along which to shard `inputs`, or 1987 `None`. `None` means "shard all inputs along dimension 0". If not `None`, 1988 there must be one dimension per input. 1989 outputs_from_all_shards: Boolean or list of boolean. For each output, if 1990 `True`, outputs from all shards are concatenated along the corresponding 1991 `output_shard_axes` entry. Otherwise, each output is taken 1992 from an arbitrary shard. If the argument is a boolean, the argument's 1993 value is used for each output. 1994 output_shard_axes: A list of dimensions along which to concatenate the 1995 outputs of `computation`, or `None`. `None` means "concatenate all outputs 1996 along dimension 0". If not `None`, there must be one dimension per output. 1997 Ignored if `outputs_from_all_shards` is False. 1998 infeed_queue: If not `None`, the `InfeedQueue` to use to augment the inputs 1999 of `computation`. 2000 device_assignment: If not `None`, a `DeviceAssignment` describing the 2001 mapping between logical cores in the computation with physical cores in 2002 the TPU topology. Uses a default device assignment if `None`. The 2003 `DeviceAssignment` may be omitted if each shard of the computation uses 2004 only one core, and there is either only one shard, or the number of shards 2005 is equal to the number of cores in the TPU system. 2006 name: (Deprecated) Does nothing. 2007 xla_options: An instance of `tpu.XLAOptions` which indicates the options 2008 passed to XLA compiler. Use `None` for default options. 2009 Returns: 2010 A list of output tensors. 2011 Raises: 2012 ValueError: If num_shards <= 0 2013 ValueError: If len(input_shard_axes) != len(inputs) 2014 ValueError: If len(output_shard_axes) != len(outputs from `computation`) 2015 """ 2016 return split_compile_and_shard( 2017 computation, 2018 inputs=inputs, 2019 num_shards=num_shards, 2020 input_shard_axes=input_shard_axes, 2021 outputs_from_all_shards=outputs_from_all_shards, 2022 output_shard_axes=output_shard_axes, 2023 infeed_queue=infeed_queue, 2024 device_assignment=device_assignment, 2025 name=name, 2026 xla_options=xla_options)[1] 2027 2028 2029@tf_export(v1=["tpu.batch_parallel"]) 2030@traceback_utils.filter_traceback 2031def batch_parallel( 2032 computation: Callable[..., Any], 2033 inputs: Optional[List[List[Optional[core_types.Tensor]]]] = None, 2034 num_shards: int = 1, 2035 infeed_queue: Optional[tpu_feed.InfeedQueue] = None, 2036 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None, 2037 name: Optional[Text] = None, 2038 xla_options: Optional[XLAOptions] = None): 2039 """Shards `computation` along the batch dimension for parallel execution. 2040 2041 Convenience wrapper around shard(). 2042 2043 `inputs` must be a list of Tensors or None (equivalent to an empty list). 2044 Each input is split into `num_shards` pieces along the 0-th dimension, and 2045 computation is applied to each shard in parallel. 2046 2047 Tensors are broadcast to all shards if they are lexically captured by 2048 `computation`. e.g., 2049 2050 x = tf.constant(7) 2051 def computation(): 2052 return x + 3 2053 ... = shard(computation, ...) 2054 2055 The outputs from all shards are concatenated back together along their 0-th 2056 dimension. 2057 2058 Inputs and outputs of the computation must be at least rank-1 Tensors. 2059 2060 Args: 2061 computation: A Python function that builds a computation to apply to each 2062 shard of the input. 2063 inputs: A list of input tensors or None (equivalent to an empty list). The 2064 0-th dimension of each Tensor must have size divisible by `num_shards`. 2065 num_shards: The number of shards. 2066 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple 2067 of arguments as inputs to `computation`. 2068 device_assignment: If not `None`, a `DeviceAssignment` describing the 2069 mapping between logical cores in the computation with physical cores in 2070 the TPU topology. Uses a default device assignment if `None`. The 2071 `DeviceAssignment` may be omitted if each shard of the computation uses 2072 only one core, and there is either only one shard, or the number of shards 2073 is equal to the number of cores in the TPU system. 2074 name: (Deprecated) Does nothing. 2075 xla_options: An instance of `tpu.XLAOptions` which indicates the options 2076 passed to XLA compiler. Use `None` for default options. 2077 Returns: 2078 A list of output tensors. 2079 Raises: 2080 ValueError: If `num_shards <= 0` 2081 """ 2082 return shard( 2083 computation, 2084 inputs, 2085 num_shards=num_shards, 2086 infeed_queue=infeed_queue, 2087 device_assignment=device_assignment, 2088 name=name, 2089 xla_options=xla_options) 2090 2091 2092@tf_export(v1=["tpu.rewrite"]) 2093@traceback_utils.filter_traceback 2094def rewrite( 2095 computation: Callable[..., Any], 2096 inputs: Optional[List[List[Optional[core_types.Tensor]]]] = None, 2097 infeed_queue: Optional[tpu_feed.InfeedQueue] = None, 2098 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None, 2099 name: Optional[Text] = None, 2100 xla_options: Optional[XLAOptions] = None) -> Any: 2101 """Rewrites `computation` for execution on a TPU system. 2102 2103 Args: 2104 computation: A Python function that builds a computation to apply to the 2105 input. If the function takes n inputs, 'inputs' should be a list of n 2106 tensors. 2107 2108 `computation` may return a list of operations and tensors. Tensors must 2109 come before operations in the returned list. The return value of 2110 `rewrite` is a list of tensors corresponding to the tensors from the 2111 output of `computation`. 2112 2113 All `Operation`s constructed during `computation` will be executed when 2114 evaluating any of the returned output tensors, not just the ones returned. 2115 inputs: A list of input tensors or `None` (equivalent to an empty list). 2116 Each input can be a nested structure containing values that are 2117 convertible to tensors. Note that passing an N-dimension list of 2118 compatible values will result in a N-dimension list of scalar tensors 2119 rather than a single Rank-N tensors. If you need different behavior, 2120 convert part of inputs to tensors with `tf.convert_to_tensor`. 2121 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple 2122 of arguments as inputs to `computation`. 2123 device_assignment: if not `None`, a `DeviceAssignment` describing the 2124 mapping between logical cores in the computation with physical cores in 2125 the TPU topology. May be omitted for a single-core computation, in which 2126 case the core attached to task 0, TPU device 0 is used. 2127 name: (Deprecated) Does nothing. 2128 xla_options: An instance of `tpu.XLAOptions` which indicates the options 2129 passed to XLA compiler. Use `None` for default options. 2130 Returns: 2131 Same data structure as if computation(*inputs) is called directly with some 2132 exceptions for correctness. Exceptions include: 2133 1) None output: a NoOp would be returned which control-depends on 2134 computation. 2135 2) Single value output: A tuple containing the value would be returned. 2136 3) Operation-only outputs: a NoOp would be returned which 2137 control-depends on computation. 2138 TODO(b/121383831): Investigate into removing these special cases. 2139 """ 2140 # TODO(b/36647078) remove disable when pylint bug is fixed. 2141 # pylint: disable=indexing-exception 2142 return replicate( 2143 computation, 2144 None if inputs is None else [inputs], 2145 infeed_queue=infeed_queue, 2146 device_assignment=device_assignment, 2147 name=name, 2148 xla_options=xla_options)[0] 2149 # pylint: enable=indexing-exception 2150 2151 # Operations that indicate some error in the user's inference graph. 2152 2153 2154_DENYLISTED_INFERENCE_OPS = set([ 2155 "ReadVariableOp", 2156 "AssignVariableOp", 2157 "AssignAddVariableOp", 2158 "AssignSubVariableOp", 2159 "VarHandleOp", 2160 "Variable", 2161 "VariableV2", 2162]) 2163 2164 2165def under_tpu_inference_context() -> bool: 2166 """Check if it is currently under `_TPUInferenceContext`.""" 2167 graph = ops.get_default_graph() 2168 while graph: 2169 context = graph._get_control_flow_context() # pylint: disable=protected-access 2170 while context: 2171 if isinstance(context, _TPUInferenceContext): 2172 return True 2173 context = context.outer_context 2174 if isinstance(graph, function._FuncGraph): # pylint: disable=protected-access 2175 graph = graph._outer_graph # pylint: disable=protected-access 2176 elif isinstance(graph, func_graph.FuncGraph): 2177 graph = graph.outer_graph 2178 else: 2179 return False 2180 2181 2182class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext): 2183 """A `ControlFlowContext` for nodes inside a TPU inference computation. 2184 2185 The primary role of `_TPUInferenceContext` is to indicate the mode of 2186 operation and possibly sanity check operators inside a 2187 tpu.rewrite_for_inference() computation. 2188 """ 2189 2190 def __init__(self, name: Text, check_ops: bool = True): 2191 super(_TPUInferenceContext, self).__init__() 2192 self._name = name 2193 self._check_ops = check_ops 2194 2195 def AddOp(self, op): 2196 self._AddOpInternal(op) 2197 2198 def _AddOpInternal(self, op): 2199 # pylint: disable=protected-access 2200 if self._check_ops and op.type in _DENYLISTED_INFERENCE_OPS: 2201 raise NotImplementedError( 2202 f"Operation of type {op.type} ({op.name}) is not supported on the " 2203 "TPU for inference. Execution will fail if this op is used in the " 2204 "graph. Make sure your variables are using variable_scope.") 2205 if self._outer_context: 2206 self._outer_context.AddInnerOp(op) 2207 2208 def AddValue(self, val): 2209 result = val 2210 if self._outer_context: 2211 result = self._outer_context.AddValue(val) 2212 return result 2213 2214 def AddInnerOp(self, op): 2215 self._AddOpInternal(op) 2216 2217 @property 2218 def grad_state(self): 2219 return None 2220 2221 2222def validate_inference_rewrite_for_variables(graph: ops.Graph): 2223 """Validates whether rewrite_for_inference() 'worked' for variables. 2224 2225 The rewrite_for_inference() method is supposed to append GuaranteeConstOps 2226 after ReadVariableOps, but this mechanism works only if you are using 2227 tf.compat.v1.get_variable() to create and access variables in your tpu 2228 computation. This validation method can be called immediately after calling 2229 tpu.rewrite_for_inference() to check whether GuaranteeConstOps where added 2230 to the graph. 2231 2232 Typical usages: 2233 tpu.validate_inference_rewrite_for_variables( 2234 tf.compat.v1.get_default_graph()) 2235 2236 tpu.validate_inference_rewrite_for_variables(sess.graph) 2237 2238 Args: 2239 graph: The graph which needs to be validated. 2240 Raises: 2241 RuntimeError: if validation failed. 2242 """ 2243 if not any(x.type == "GuaranteeConst" for x in graph.get_operations()): 2244 raise RuntimeError( 2245 "No GuaranteeConst ops found in the graph after running " 2246 "tpu.rewrite_for_inference(...). Please check that you are using " 2247 "tf.get_variable() to create and access variables in your tpu " 2248 "computation.") 2249 2250 2251def rewrite_for_inference( 2252 computation: Callable[..., Any], 2253 inputs: Optional[List[core_types.Tensor]] = None, 2254 infeed_queue: Optional[tpu_feed.InfeedQueue] = None, 2255 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None, 2256 name: Optional[Text] = None) -> List[core_types.Tensor]: 2257 """Rewrites `computation` for inference on a TPU system. 2258 2259 Other than 'rewriting' the computation to run on a TPU, if using variables 2260 in your computation, it moves the ReadVariableOps outside the TPU 2261 computation, and adds GuaranteeConst ops just after the ReadVariableOps. 2262 This mechanism works only if you are using tf.compat.v1.get_variable() to 2263 create and access variables in your tpu computation. You can validate 2264 whether this worked, by calling validate_inference_rewrite_for_variables() 2265 method immediately after this method to check whether GuaranteeConstOps 2266 where added to the graph. 2267 2268 Args: 2269 computation: A Python function that builds a computation to apply to the 2270 input. If the function takes n inputs, 'inputs' should be a list of n 2271 tensors. If the function returns m outputs, rewrite will return a list of 2272 m tensors. 2273 inputs: A list of input tensors or `None` (equivalent to an empty list). 2274 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple 2275 of arguments as inputs to `computation`. 2276 device_assignment: if not `None`, a `DeviceAssignment` describing the 2277 mapping between logical cores in the computation with physical cores in 2278 the TPU topology. May be omitted for a single-core computation, in which 2279 case the core attached to task 0, TPU device 0 is used. 2280 name: The name of the operator. 2281 Returns: 2282 A list of output tensors. 2283 """ 2284 2285 def guarantee_const_getter(getter, name, *args, **kwargs): 2286 with ops.control_dependencies(None): 2287 return array_ops.guarantee_const( 2288 getter(name, *args, **kwargs), name=name + "/GuaranteeConst") 2289 2290 def wrapped_computation(*args, **kwargs): 2291 """Execute computation under `_TPUInferenceContext`.""" 2292 context = _TPUInferenceContext( 2293 name=ops.get_default_graph().unique_name("rewrite_for_inference")) 2294 try: 2295 context.Enter() 2296 2297 vscope = variable_scope.get_variable_scope() 2298 prev_custom_getter = vscope.custom_getter 2299 prev_caching_device = vscope.caching_device 2300 vscope.set_custom_getter(guarantee_const_getter) 2301 vscope.set_caching_device(lambda op: op.device) 2302 2303 result = computation(*args, **kwargs) 2304 2305 vscope.set_custom_getter(prev_custom_getter) 2306 vscope.set_caching_device(prev_caching_device) 2307 finally: 2308 context.Exit() 2309 return result 2310 2311 # pylint: disable=undefined-variable 2312 return rewrite( 2313 wrapped_computation, 2314 inputs=inputs, 2315 infeed_queue=infeed_queue, 2316 device_assignment=device_assignment, 2317 name=name) 2318 # pylint: enable=undefined-variable 2319 2320 2321def prune_unconnected_ops_from_xla(prune_graph: ops.Graph): 2322 """Prunes unconnected ops as listed in _UNCONNECTED_OPS_TO_PRUNE. 2323 2324 Args: 2325 prune_graph: A tensorflow graph from which we wish to prune unconnected ops 2326 as listed in _UNCONNECTED_OPS_TO_PRUNE. In general, these ops should have 2327 no inputs and no consumers. These can often be left behind due to graph 2328 construction rewiring (for instance TF-Hub). While they never execute, 2329 they will cause XLA compile to fail so we strip them from XLA compile by 2330 removing the tpu_replicate attribute. 2331 """ 2332 # Scan over the top level graph and all function graphs. 2333 for graph in [prune_graph] + [ 2334 f for f in prune_graph._functions.values() # pylint: disable=protected-access 2335 ]: 2336 if not isinstance(graph, ops.Graph): 2337 continue 2338 for op in graph.get_operations(): 2339 if op.type not in _UNCONNECTED_OPS_TO_PRUNE: 2340 continue 2341 outputs_consumed = False 2342 for output in op.outputs: 2343 if output.consumers(): 2344 outputs_consumed = True 2345 break 2346 if not outputs_consumed: 2347 logging.info( 2348 "Pruning OP %s of type %s from XLA Compile due to " 2349 "it being disconnected.", op.name, op.type) 2350 op._clear_attr(_TPU_REPLICATE_ATTR) # pylint: disable=protected-access 2351