1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15"""xla is an experimental library that provides XLA support APIs.""" 16 17import contextlib 18 19 20from tensorflow.compiler.jit.ops import xla_ops 21from tensorflow.compiler.jit.ops import xla_ops_grad # pylint: disable=unused-import 22from tensorflow.core.framework import attr_value_pb2 23from tensorflow.python.distribute import summary_op_util 24from tensorflow.python.eager import context 25from tensorflow.python.eager import def_function 26from tensorflow.python.framework import ops 27from tensorflow.python.ops import array_ops 28from tensorflow.python.ops import control_flow_ops 29from tensorflow.python.ops import variable_scope 30from tensorflow.python.platform import tf_logging as logging 31from tensorflow.python.util import compat 32from tensorflow.python.util import nest 33from tensorflow.python.util import tf_inspect 34from tensorflow.python.util.compat import collections_abc 35from tensorflow.python.util.deprecation import deprecated 36from tensorflow.python.util.tf_export import tf_export 37 38_XLA_COMPILE_ATTR = '_xla_compile_id' 39_MAX_WARNING_LINES = 5 40 41# Operations that indicate some error in the users graph. For example, XLA 42# computation should not have any Placeholder op. 43_DENYLISTED_OPS = set([ 44 'Placeholder', 45]) 46 47# XLA doesn't currently support reading of intermediate tensors, thus some ops 48# are not supported. 49_UNSUPPORTED_OPS = set([ 50 'AudioSummary', 51 'AudioSummaryV2', 52 'HistogramSummary', 53 'ImageSummary', 54 'MergeSummary', 55 'Print', 56 'ScalarSummary', 57 'TensorSummary', 58 'TensorSummaryV2', 59]) 60 61 62@tf_export('xla.experimental.compile') 63@deprecated( 64 None, 'xla.experimental.compile is deprecated. Consider using ' 65 'tf.function(jit_compile=True)', 66 warn_once=True) 67def compile(computation, inputs=None): # pylint: disable=redefined-builtin 68 """Builds an operator that compiles and runs `computation` with XLA. 69 70 NOTE: In eager mode, `computation` will have `@tf.function` semantics. 71 72 Args: 73 computation: A Python function that builds a computation to apply to the 74 input. If the function takes n inputs, 'inputs' should be a list of n 75 tensors. 76 77 `computation` may return a list of operations and tensors. Tensors must 78 come before operations in the returned list. The return value of 79 `compile` is a list of tensors corresponding to the tensors from the 80 output of `computation`. 81 82 All `Operation`s returned from `computation` will be executed when 83 evaluating any of the returned output tensors. 84 inputs: A list of inputs or `None` (equivalent to an empty list). Each input 85 can be a nested structure containing values that are convertible to 86 tensors. Note that passing an N-dimension list of compatible values will 87 result in a N-dimension list of scalar tensors rather than a single Rank-N 88 tensors. If you need different behavior, convert part of inputs to tensors 89 with `tf.convert_to_tensor`. 90 91 Returns: 92 Same data structure as if computation(*inputs) is called directly with some 93 exceptions for correctness. Exceptions include: 94 1) None output: a NoOp would be returned which control-depends on 95 computation. 96 2) Single value output: A tuple containing the value would be returned. 97 3) Operation-only outputs: a NoOp would be returned which 98 control-depends on computation. 99 TODO(b/121383831): Investigate into removing these special cases. 100 101 Raises: 102 RuntimeError: if called when eager execution is enabled. 103 104 Known issues: 105 When a tf.random operation is built with XLA, the implementation doesn't 106 pass the user provided seed to the XLA compiler. As such, the XLA compiler 107 generates a random number and uses it as a seed when compiling the 108 operation. This implementation causes a violation of the Tensorflow 109 defined semantics in two aspects. First, changing the value of the user 110 defined seed doesn't change the numbers generated by the operation. 111 Second, when a seed is not specified, running the program multiple times 112 will generate the same numbers. 113 114 """ 115 if context.executing_eagerly(): 116 @def_function.function 117 def xla_compile_wrapper(): 118 return _compile_internal(computation, inputs) 119 120 return xla_compile_wrapper() 121 122 return _compile_internal(computation, inputs) 123 124 125class XLACompileContext(control_flow_ops.XLAControlFlowContext): 126 """A `ControlFlowContext` for nodes inside an XLA computation cluster. 127 128 THIS IS ONLY FOR TENSORFLOW INTERNAL IMPLEMENTATION, DO NO USE DIRECTLY. 129 130 The primary role of `XLACompileContext` is to mark operators inside a 131 xla.compile() computation with attribute "_xla_compile_id=XYZ", where XYZ is 132 a unique name. 133 134 `ControlFlowContext` is used to perform the annotation since it integrates 135 with Tensorflow constructs like ResourceVariables. For example, if a 136 `ResourceVariable` is constructed inside a xla.compile() block, the 137 `ResourceVariable` implementation can use 138 `with ops.control_dependencies(None)` to build the variable's definition 139 outside the compiled computation. 140 """ 141 142 def __init__(self, name, pivot): 143 """Builds a new XLACompileContext. 144 145 Args: 146 name: a unique name for the context, used to populate the 147 `_xla_compile_id` attribute. 148 pivot: a pivot node. Nodes in the XLACompileContext that do not have any 149 inputs will have a control dependency on the pivot node. This ensures 150 that nodes are correctly included in any enclosing control flow 151 contexts. 152 """ 153 super(XLACompileContext, self).__init__() 154 self._name = name 155 self._name_as_bytes = compat.as_bytes(name) 156 self._unsupported_ops = [] 157 self._pivot = pivot 158 159 def report_unsupported_operations(self): 160 if self._unsupported_ops: 161 op_str = '\n'.join([ 162 ' %s (%s)' % (op.type, op.name) 163 for op in self._unsupported_ops[:_MAX_WARNING_LINES] 164 ]) 165 logging.warning('%d unsupported operations found: \n%s', 166 len(self._unsupported_ops), op_str) 167 if len(self._unsupported_ops) > _MAX_WARNING_LINES: 168 logging.warning('... and %d more', 169 len(self._unsupported_ops) - _MAX_WARNING_LINES) 170 171 def _RemoveExternalControlEdges(self, op): 172 """Remove any external control dependency on this op.""" 173 internal_control_inputs = [] 174 external_control_inputs = [] 175 for x in op.control_inputs: 176 # pylint: disable=protected-access 177 is_internal_op = False 178 ctxt = x._get_control_flow_context() 179 while ctxt is not None: 180 if ctxt == self: 181 is_internal_op = True 182 break 183 ctxt = ctxt._outer_context 184 if is_internal_op: 185 internal_control_inputs.append(x) 186 else: 187 external_control_inputs.append(x) 188 # pylint: enable=protected-access 189 # pylint: disable=protected-access 190 op._remove_all_control_inputs() 191 op._add_control_inputs(internal_control_inputs) 192 # pylint: enable=protected-access 193 return internal_control_inputs, external_control_inputs 194 195 def AddOp(self, op): 196 """Create op in XLACompileContext and notifies outer context recursively.""" 197 # pylint: disable=protected-access 198 if op.type in _DENYLISTED_OPS: 199 logging.error( 200 'Operation of type %s (%s) is not supported in XLA. Execution will ' 201 'fail if this op is used in the graph. ', op.type, op.name) 202 203 # TODO(ycao): Automatically disable summaries instead of reporting them. 204 if op.type in _UNSUPPORTED_OPS: 205 self._unsupported_ops.append(op) 206 207 if any(x.dtype._is_ref_dtype for x in op.inputs): 208 raise NotImplementedError( 209 'Non-resource Variables are not supported inside XLA computations ' 210 '(operator name: %s)' % op.name) 211 212 if _XLA_COMPILE_ATTR in op.node_def.attr: 213 raise ValueError('XLA compiled computations cannot be nested, (operator ' 214 'name: %s)' % op.name) 215 216 op._set_attr( 217 _XLA_COMPILE_ATTR, attr_value_pb2.AttrValue(s=self._name_as_bytes)) 218 219 op.graph.prevent_feeding(op) 220 op.graph.prevent_fetching(op) 221 222 # Remove any control edges from outer control flow contexts. These may cause 223 # mismatched frame errors. An example is when one of op's inputs is 224 # generated in a different While control flow context. 225 (internal_control_inputs, 226 external_control_inputs) = self._RemoveExternalControlEdges(op) 227 228 if not op.inputs: 229 # Add a control edge from the control pivot to this op. 230 if not internal_control_inputs: 231 # pylint: disable=protected-access 232 op._add_control_input(self._pivot) 233 # pylint: enable=protected-access 234 else: 235 for index in range(len(op.inputs)): 236 x = op.inputs[index] 237 real_x = self.AddValue(x) 238 if real_x is not x: 239 op._update_input(index, real_x) # pylint: disable=protected-access 240 241 if external_control_inputs: 242 # Use an identity to pull control inputs as data inputs. Note that we 243 # ignore ops which don't have outputs. TODO(phawkins): fix that. 244 with ops.control_dependencies(None): 245 self.Enter() 246 external_control_inputs = [ 247 array_ops.identity(x.outputs[0]).op 248 for x in external_control_inputs 249 if x.outputs 250 ] 251 self.Exit() 252 # pylint: disable=protected-access 253 op._add_control_inputs(external_control_inputs) 254 # pylint: enable=protected-access 255 256 # Mark op's outputs as seen by this context and any outer contexts. 257 output_names = [x.name for x in op.outputs] 258 context = self 259 while context is not None: 260 # pylint: disable=protected-access 261 context._values.update(output_names) 262 context = context._outer_context 263 # pylint: enable=protected-access 264 265 if self._outer_context: 266 self._outer_context.AddInnerOp(op) 267 268 def AddValue(self, val): 269 """Add `val` to the current context and its outer context recursively.""" 270 if val.name in self._values: 271 # Use the real value if it comes from outer context. 272 result = self._external_values.get(val.name) 273 return val if result is None else result 274 275 result = val 276 self._values.add(val.name) 277 if self._outer_context: 278 result = self._outer_context.AddValue(val) 279 self._values.add(result.name) 280 281 self._external_values[val.name] = result 282 283 return result 284 285 def AddInnerOp(self, op): 286 self.AddOp(op) 287 if self._outer_context: 288 self._outer_context.AddInnerOp(op) 289 290 @property 291 def grad_state(self): 292 # Define the gradient loop state associated with the XLACompileContext to 293 # be None as the XLACompileContext does not get nested nor does the 294 # grad_state outside the XLACompileContext affect the graph inside so the 295 # grad_state should be as if this is the top-level gradient state. 296 return None 297 298 @property 299 def back_prop(self): 300 """Forwards to the enclosing while context, if any.""" 301 if self.GetWhileContext(): 302 return self.GetWhileContext().back_prop 303 return False 304 305 306def _compile_internal(computation, inputs=None): 307 """Builds graph operators that compiles and symbolically executes computation. 308 309 Args: 310 computation: A Python function that builds the computation to compile and 311 execute. 312 inputs: A list of inputs or `None` (equivalent to an empty list). Each input 313 can be a nested structure containing values that are convertible to 314 tensors. Note that passing an N-dimension list of compatible values will 315 result in a N-dimension list of scalar tensors rather than a single Rank-N 316 tensors. If you need different behavior, convert part of inputs to tensors 317 with `tf.convert_to_tensor`. 318 319 Returns: 320 Same data structure as if computation(*inputs) is called directly with some 321 exceptions for correctness. Exceptions include: 1) None output 2) Single 322 value output 3) Operation-only outputs 323 Raises: 324 ValueError: If any element in computation outputs is neither an operations 325 or a value that can be converted to tensor. 326 ValueError: If computation outputs is non-flat and contains any Operations. 327 TypeError: If `inputs` is not a list or tuple. 328 """ 329 if inputs is None: 330 inputs = [] 331 332 if not isinstance(inputs, collections_abc.Sequence): 333 raise TypeError('inputs must be a list') 334 335 # Flatten inputs. 336 flat_inputs = nest.flatten(inputs) 337 # Converts inputs to Tensors. 338 flat_inputs = [ops.convert_to_tensor(x) for x in flat_inputs] 339 340 cluster_name = ops.get_default_graph().unique_name('cluster') 341 pivot = control_flow_ops.no_op(name=cluster_name + '/pivot') 342 context = XLACompileContext(name=cluster_name, pivot=pivot) 343 try: 344 context.Enter() 345 346 # Add identity ops so even unused inputs are 'consumed' by the 347 # computation. 348 flat_inputs = [ 349 array_ops.identity(x, name='input_{}'.format(i)) 350 for i, x in enumerate(flat_inputs) 351 ] 352 353 # Re-pack flat_inputs in same structure as 'inputs'. 354 computation_inputs = nest.pack_sequence_as( 355 structure=inputs, flat_sequence=flat_inputs) 356 357 # Only resource variables work inside an XLA computation, so turn on 358 # resource variables for the computation. 359 vscope = variable_scope.get_variable_scope() 360 saved_use_resource = vscope.use_resource 361 vscope.set_use_resource(True) 362 363 with _disable_summary_context(): 364 outputs = computation(*computation_inputs) 365 366 # Restore variable scope after computation. 367 vscope.set_use_resource(saved_use_resource) 368 369 outputs_is_flat = is_flat(outputs) 370 if outputs_is_flat: 371 output_tensors, control_deps = _postprocess_flat_outputs(outputs) 372 else: 373 output_tensors, control_deps = _postprocess_non_flat_outputs(outputs) 374 375 context.ExitResult(output_tensors) 376 finally: 377 context.report_unsupported_operations() 378 context.Exit() 379 380 # When XLA computation returns only operations and no tensors, a NoOp 381 # dependent on the operations in outputs is returned. Otherwise final 382 # outputs would be empty and there is no way to trigger returned 383 # operations. 384 if not output_tensors: 385 return control_flow_ops.group(control_deps, name='output_0') 386 387 output_tensors = [ 388 xla_ops.xla_cluster_output(o, name='output{}'.format(i)) 389 for i, o in enumerate(output_tensors) 390 ] 391 392 with ops.control_dependencies(control_deps): 393 # Wraps the outputs in identity operators that carries control 394 # dependencies. 395 output_tensors = [ 396 array_ops.identity(o, name='output_%d' % i) 397 for i, o in enumerate(output_tensors) 398 ] 399 400 # If `computation` returned non-flat output structure, pack output tensors 401 # back into same structure. 402 if not outputs_is_flat: 403 output_tensors = nest.pack_sequence_as( 404 structure=outputs, flat_sequence=output_tensors) 405 406 return output_tensors 407 408 409def is_flat(outputs): 410 """Checks if outputs is a flat structure. 411 412 Following structures and values are considered flat: 413 1) None 414 2) A single object 415 3) A list or tuple of Tensors/Operations 416 417 The only structures that this function understands are sequences, 418 dictionaries and types defined using the attrs library. E.g. this means 419 that if outputs contains a single user-defined Object, it is considered to 420 be flat. Errors are raised later on if that Object cannot be converted to a 421 Tensor. 422 423 Args: 424 outputs: Output from `computation` inside `xla.compile`. 425 426 Returns: 427 A boolean indicates whether outputs is flat. 428 """ 429 # If outputs is a list or tuple, check if it has any nested structure. If 430 # there is, then outputs is non-flat. 431 if isinstance(outputs, collections_abc.Sequence): 432 for o in outputs: 433 if (isinstance(o, collections_abc.Sequence) or 434 isinstance(o, collections_abc.Mapping) or 435 hasattr(o.__class__, '__attrs_attrs__')): 436 return False 437 438 # If outputs is a dict, it is non-flat. 439 if isinstance(outputs, collections_abc.Mapping): 440 return False 441 442 # If outputs is from the attrs library, it is non-flat. 443 if hasattr(outputs.__class__, '__attrs_attrs__'): 444 return False 445 446 # Getting here means either outputs itself is a single non-structured value 447 # or it is a flat list of single non-structured values. 448 return True 449 450 451def _postprocess_flat_outputs(outputs): 452 """Validates flat outputs and adds back device assignments. 453 454 Args: 455 outputs: Output from `computation` inside `xla.compile`. 456 457 Returns: 458 Tensors and Operations extracted from outputs. 459 """ 460 # Following code segment is to preserve legacy behavior. Previously we only 461 # supported flat outputs and thus for consistency it was nice to convert even 462 # single element into a tuple. But now that we support arbitrary output 463 # structure, this is no longer necessary. 464 # TODO(b/121383831): Migrate all legacy use cases and delete this special 465 # case. 466 # If the computation returns `None`, make it an empty tuple. 467 if outputs is None: 468 outputs = tuple() 469 # If the computation only returned one value, make it a tuple. 470 if not isinstance(outputs, collections_abc.Sequence): 471 outputs = (outputs,) 472 473 # Append `no_op` here so that return value of this function always contains 474 # at least one op that can trigger XlaLaunch node. 475 outputs += (control_flow_ops.no_op(),) 476 try: 477 outputs = [ 478 o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o) 479 for o in outputs 480 ] 481 except Exception as e: 482 raise ValueError( 483 'XLA computation function return values must all either be Operations' 484 ' or convertible to Tensors. Got error: "%s"' % str(e)) 485 486 # Separates the returned Operations and Tensors. 487 output_operations = [o for o in outputs if isinstance(o, ops.Operation)] 488 output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)] 489 490 if outputs != output_tensors + output_operations: 491 raise ValueError( 492 'XLA computation function must return zero or more Tensor values ' 493 'followed by zero or more Operations.') 494 495 new_output_tensors = [] 496 for t in output_tensors: 497 with ops.device(t.device if t.device else ''): 498 new_output_tensors.append(array_ops.identity(t)) 499 500 return new_output_tensors, output_operations 501 502 503def _postprocess_non_flat_outputs(outputs): 504 """Validates non-flat outputs and adds back device assignments. 505 506 Args: 507 outputs: Output from `computation` inside `xla.compile`. 508 509 Returns: 510 Tensors extracted from outputs and an empty list because Operations are not 511 allowed in non-flat outputs.. 512 """ 513 # Convert all non-Operation outputs to Tensors. 514 new_output_tensors = [] 515 for o in nest.flatten(outputs): 516 if isinstance(o, ops.Operation): 517 raise ValueError( 518 'xla.compile does not support Operation as return value in non-flat ' 519 'output structure. You can set returned Operations as control ' 520 'dependencies of returned Tensors so Operations are triggered when ' 521 'Tensors are evaluated. Operation found: "%s"' % o.name) 522 523 try: 524 o = ops.convert_to_tensor(o) 525 except Exception as e: 526 raise ValueError( 527 'XLA computation function return values must all either be ' 528 'Operations or convertible to Tensors. Got error: "%s"' % str(e)) 529 530 # Makes sure even pass-through inputs/outputs are touched in compile 531 # context by creating an Identity node inside compile context. 532 with ops.device(o.device if o.device else ''): 533 new_output_tensors.append(array_ops.identity(o)) 534 535 return new_output_tensors, [] 536 537 538@contextlib.contextmanager 539def _disable_summary_context(): 540 """Enters a context where all summary ops are skipped. 541 542 Summaries are not yet supported in xla.compile(). So we provide this context 543 manager that can skip creating summary ops. This is a temporary workaround due 544 to XLA not supporting summary ops. 545 546 Yields: 547 None. 548 """ 549 original_skip_summary_func = summary_op_util.skip_summary 550 summary_op_util.skip_summary = lambda: True 551 552 try: 553 yield 554 finally: 555 summary_op_util.skip_summary = original_skip_summary_func 556 557 558class _CapturedObject(object): 559 """A placeholder to capture an object.""" 560 561 def __init__(self): 562 self._object = None 563 564 def capture(self, o): 565 if self._object: 566 raise RuntimeError( 567 'InternalError: _CapturedObject can capture only once. Please file ' 568 'bug.') 569 570 self._object = o 571 572 def get(self): 573 return self._object 574 575 576def _get_scaffold(captured_scaffold_fn): 577 """Retrieves the Scaffold from `captured_scaffold_fn`.""" 578 scaffold_fn = captured_scaffold_fn.get() 579 580 if not scaffold_fn: 581 return None 582 583 scaffold = scaffold_fn() 584 if scaffold is None: 585 raise ValueError( 586 'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed') 587 588 return scaffold 589 590 591def check_function_argument_count(func, input_arity, infeed_queue): 592 """Validate the number of input arguments to an XLA function. 593 594 Args: 595 func: the Python function that will be called to generate the body of an XLA 596 computation graph. 597 input_arity: the number of explicit arguments supplied by the caller. 598 infeed_queue: if not None, the infeed queue that will supply 599 additional arguments to the function. 600 601 Returns: 602 None if function can be called with the supplied number of 603 arguments, or an error string if it cannot. 604 """ 605 def format_error(complaint, quantity): 606 return '%s %d argument%s' % (complaint, quantity, '' 607 if quantity == 1 else 's') 608 609 num_args_supplied = input_arity 610 if infeed_queue is not None: 611 num_args_supplied += infeed_queue.number_of_tuple_elements 612 arg_spec = tf_inspect.getargspec(func) 613 num_func_args = len(arg_spec.args) 614 if arg_spec.defaults is None: 615 num_func_defaults = 0 616 else: 617 num_func_defaults = len(arg_spec.defaults) 618 min_func_args = num_func_args - num_func_defaults 619 if num_args_supplied < min_func_args: 620 # The required number of arguments is not enough to call the function. 621 if num_func_defaults == 0 and arg_spec.varargs is None: 622 return format_error('exactly', num_func_args) 623 else: 624 return format_error('at least', min_func_args) 625 if arg_spec.varargs is None and num_args_supplied > num_func_args: 626 # The required number of arguments is too many to call the function. 627 if num_func_defaults == 0: 628 return format_error('exactly', num_func_args) 629 else: 630 return format_error('at most', num_func_args) 631 # Reaching here means either 632 # 1) There are varargs, func can accept any number of arguments greater than 633 # the minimum. 634 # 2) Number of supplied arguments falls in range of acceptable argument count 635 # of func. 636 return None 637