# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""xla is an experimental library that provides XLA support APIs."""

import contextlib


from tensorflow.compiler.jit.ops import xla_ops
from tensorflow.compiler.jit.ops import xla_ops_grad  # pylint: disable=unused-import
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.distribute import summary_op_util
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export

_XLA_COMPILE_ATTR = '_xla_compile_id'
_MAX_WARNING_LINES = 5

# Operations that indicate some error in the user's graph. For example, an XLA
# computation should not contain any Placeholder op.
_DENYLISTED_OPS = set([
    'Placeholder',
])

# XLA doesn't currently support reading intermediate tensors, so summary and
# printing ops are not supported.
_UNSUPPORTED_OPS = set([
    'AudioSummary',
    'AudioSummaryV2',
    'HistogramSummary',
    'ImageSummary',
    'MergeSummary',
    'Print',
    'ScalarSummary',
    'TensorSummary',
    'TensorSummaryV2',
])


@tf_export('xla.experimental.compile')
@deprecated(
    None, 'xla.experimental.compile is deprecated. Consider using '
    'tf.function(jit_compile=True)',
    warn_once=True)
def compile(computation, inputs=None):  # pylint: disable=redefined-builtin
  """Builds an operator that compiles and runs `computation` with XLA.

  NOTE: In eager mode, `computation` will have `@tf.function` semantics.

  Args:
    computation: A Python function that builds a computation to apply to the
      input. If the function takes n inputs, `inputs` should be a list of n
      tensors.

      `computation` may return a list of operations and tensors. Tensors must
      come before operations in the returned list. The return value of
      `compile` is a list of tensors corresponding to the tensors from the
      output of `computation`.

      All `Operation`s returned from `computation` will be executed when
      evaluating any of the returned output tensors.
    inputs: A list of inputs or `None` (equivalent to an empty list). Each
      input can be a nested structure containing values that are convertible
      to tensors. Note that passing an N-dimensional list of compatible values
      will result in an N-dimensional list of scalar tensors rather than a
      single rank-N tensor. If you need different behavior, convert parts of
      `inputs` to tensors with `tf.convert_to_tensor`.

  Returns:
    The same data structure as if computation(*inputs) were called directly,
    with some exceptions for correctness:
      1) None output: a NoOp that control-depends on computation is returned.
      2) Single value output: a tuple containing the value is returned.
      3) Operation-only outputs: a NoOp that control-depends on computation is
         returned.
      TODO(b/121383831): Investigate removing these special cases.

  Raises:
    RuntimeError: if called when eager execution is enabled.

  Known issues:
    When a tf.random operation is built with XLA, the implementation doesn't
      pass the user-provided seed to the XLA compiler. As such, the XLA
      compiler generates a random number and uses it as a seed when compiling
      the operation. This implementation violates the TensorFlow-defined
      semantics in two ways. First, changing the value of the user-defined
      seed doesn't change the numbers generated by the operation. Second, when
      a seed is not specified, running the program multiple times will
      generate the same numbers.

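  Example:

  A minimal sketch; the model function and constant input below are
  illustrative, not part of the API.

  ```python
  import tensorflow as tf

  def model_fn(x):
    return tf.math.square(x) + 1.0

  # A single-value output comes back wrapped in a tuple (see Returns above).
  (y,) = tf.xla.experimental.compile(
      model_fn, inputs=[tf.constant([1.0, 2.0, 3.0])])
  ```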
114  """
115  if context.executing_eagerly():
116    @def_function.function
117    def xla_compile_wrapper():
118      return _compile_internal(computation, inputs)
119
120    return xla_compile_wrapper()
121
122  return _compile_internal(computation, inputs)
123
124
class XLACompileContext(control_flow_ops.XLAControlFlowContext):
  """A `ControlFlowContext` for nodes inside an XLA computation cluster.

  THIS IS ONLY FOR TENSORFLOW INTERNAL IMPLEMENTATION, DO NOT USE DIRECTLY.

  The primary role of `XLACompileContext` is to mark operators inside an
  xla.compile() computation with the attribute "_xla_compile_id=XYZ", where
  XYZ is a unique name.

  `ControlFlowContext` is used to perform the annotation since it integrates
  with TensorFlow constructs like ResourceVariables. For example, if a
  `ResourceVariable` is constructed inside an xla.compile() block, the
  `ResourceVariable` implementation can use
  `with ops.control_dependencies(None)` to build the variable's definition
  outside the compiled computation.
  """

  def __init__(self, name, pivot):
    """Builds a new XLACompileContext.

    Args:
      name: a unique name for the context, used to populate the
        `_xla_compile_id` attribute.
      pivot: a pivot node. Nodes in the XLACompileContext that do not have any
        inputs will have a control dependency on the pivot node. This ensures
        that nodes are correctly included in any enclosing control flow
        contexts.
    """
    super(XLACompileContext, self).__init__()
    self._name = name
    self._name_as_bytes = compat.as_bytes(name)
    self._unsupported_ops = []
    self._pivot = pivot

  def report_unsupported_operations(self):
    if self._unsupported_ops:
      op_str = '\n'.join([
          '  %s (%s)' % (op.type, op.name)
          for op in self._unsupported_ops[:_MAX_WARNING_LINES]
      ])
      logging.warning('%d unsupported operations found: \n%s',
                      len(self._unsupported_ops), op_str)
      if len(self._unsupported_ops) > _MAX_WARNING_LINES:
        logging.warning('... and %d more',
                        len(self._unsupported_ops) - _MAX_WARNING_LINES)

  def _RemoveExternalControlEdges(self, op):
    """Remove any external control dependency on this op."""
    internal_control_inputs = []
    external_control_inputs = []
    for x in op.control_inputs:
      # pylint: disable=protected-access
      is_internal_op = False
      ctxt = x._get_control_flow_context()
      while ctxt is not None:
        if ctxt == self:
          is_internal_op = True
          break
        ctxt = ctxt._outer_context
      if is_internal_op:
        internal_control_inputs.append(x)
      else:
        external_control_inputs.append(x)
      # pylint: enable=protected-access
    # pylint: disable=protected-access
    op._remove_all_control_inputs()
    op._add_control_inputs(internal_control_inputs)
    # pylint: enable=protected-access
    return internal_control_inputs, external_control_inputs

  def AddOp(self, op):
    """Creates an op in this context and notifies outer contexts recursively."""
    # pylint: disable=protected-access
    if op.type in _DENYLISTED_OPS:
      logging.error(
          'Operation of type %s (%s) is not supported in XLA. Execution will '
          'fail if this op is used in the graph. ', op.type, op.name)

    # TODO(ycao): Automatically disable summaries instead of reporting them.
    if op.type in _UNSUPPORTED_OPS:
      self._unsupported_ops.append(op)

    if any(x.dtype._is_ref_dtype for x in op.inputs):
      raise NotImplementedError(
          'Non-resource Variables are not supported inside XLA computations '
          '(operator name: %s)' % op.name)

    if _XLA_COMPILE_ATTR in op.node_def.attr:
      raise ValueError('XLA compiled computations cannot be nested (operator '
                       'name: %s)' % op.name)

    op._set_attr(
        _XLA_COMPILE_ATTR, attr_value_pb2.AttrValue(s=self._name_as_bytes))

    op.graph.prevent_feeding(op)
    op.graph.prevent_fetching(op)

    # Remove any control edges from outer control flow contexts. These may
    # cause mismatched frame errors. An example is when one of op's inputs is
    # generated in a different While control flow context.
    (internal_control_inputs,
     external_control_inputs) = self._RemoveExternalControlEdges(op)

    if not op.inputs:
      # Add a control edge from the control pivot to this op.
      if not internal_control_inputs:
        # pylint: disable=protected-access
        op._add_control_input(self._pivot)
        # pylint: enable=protected-access
    else:
      for index in range(len(op.inputs)):
        x = op.inputs[index]
        real_x = self.AddValue(x)
        if real_x is not x:
          op._update_input(index, real_x)  # pylint: disable=protected-access

    if external_control_inputs:
      # Use an identity to pull control inputs as data inputs. Note that we
      # ignore ops which don't have outputs. TODO(phawkins): fix that.
      with ops.control_dependencies(None):
        self.Enter()
        external_control_inputs = [
            array_ops.identity(x.outputs[0]).op
            for x in external_control_inputs
            if x.outputs
        ]
        self.Exit()
      # pylint: disable=protected-access
      op._add_control_inputs(external_control_inputs)
      # pylint: enable=protected-access

    # Mark op's outputs as seen by this context and any outer contexts.
    output_names = [x.name for x in op.outputs]
    context = self
    while context is not None:
      # pylint: disable=protected-access
      context._values.update(output_names)
      context = context._outer_context
      # pylint: enable=protected-access

    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  def AddValue(self, val):
    """Add `val` to the current context and its outer context recursively."""
    if val.name in self._values:
      # Use the real value if it comes from an outer context.
      result = self._external_values.get(val.name)
      return val if result is None else result

    result = val
    self._values.add(val.name)
    if self._outer_context:
      result = self._outer_context.AddValue(val)
      self._values.add(result.name)

    self._external_values[val.name] = result

    return result

  def AddInnerOp(self, op):
    self.AddOp(op)
    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  @property
  def grad_state(self):
    # Define the gradient loop state associated with the XLACompileContext to
    # be None, because the XLACompileContext does not get nested and the
    # grad_state outside the XLACompileContext does not affect the graph
    # inside it. The grad_state should therefore behave as if this were the
    # top-level gradient state.
    return None

  @property
  def back_prop(self):
    """Forwards to the enclosing while context, if any."""
    if self.GetWhileContext():
      return self.GetWhileContext().back_prop
    return False


def _compile_internal(computation, inputs=None):
  """Builds graph operators that compile and symbolically execute computation.

  Args:
    computation: A Python function that builds the computation to compile and
      execute.
    inputs: A list of inputs or `None` (equivalent to an empty list). Each
      input can be a nested structure containing values that are convertible
      to tensors. Note that passing an N-dimensional list of compatible values
      will result in an N-dimensional list of scalar tensors rather than a
      single rank-N tensor. If you need different behavior, convert parts of
      `inputs` to tensors with `tf.convert_to_tensor`.

  Returns:
    The same data structure as if computation(*inputs) were called directly,
    with some exceptions for correctness: 1) None output, 2) single-value
    output, 3) operation-only outputs.
  Raises:
    ValueError: If any element in computation outputs is neither an Operation
      nor a value that can be converted to a tensor.
    ValueError: If computation outputs is non-flat and contains any Operations.
    TypeError: If `inputs` is not a list or tuple.
  """
  if inputs is None:
    inputs = []

  if not isinstance(inputs, collections_abc.Sequence):
    raise TypeError('inputs must be a list')

  # Flatten inputs.
  flat_inputs = nest.flatten(inputs)
  # Converts inputs to Tensors.
  flat_inputs = [ops.convert_to_tensor(x) for x in flat_inputs]

  cluster_name = ops.get_default_graph().unique_name('cluster')
  pivot = control_flow_ops.no_op(name=cluster_name + '/pivot')
  context = XLACompileContext(name=cluster_name, pivot=pivot)
  try:
    context.Enter()

    # Add identity ops so even unused inputs are 'consumed' by the
    # computation.
    flat_inputs = [
        array_ops.identity(x, name='input_{}'.format(i))
        for i, x in enumerate(flat_inputs)
    ]

    # Re-pack flat_inputs in the same structure as 'inputs'.
    computation_inputs = nest.pack_sequence_as(
        structure=inputs, flat_sequence=flat_inputs)

    # Only resource variables work inside an XLA computation, so turn on
    # resource variables for the computation.
    vscope = variable_scope.get_variable_scope()
    saved_use_resource = vscope.use_resource
    vscope.set_use_resource(True)

    with _disable_summary_context():
      outputs = computation(*computation_inputs)

    # Restore variable scope after computation.
    vscope.set_use_resource(saved_use_resource)

    outputs_is_flat = is_flat(outputs)
    if outputs_is_flat:
      output_tensors, control_deps = _postprocess_flat_outputs(outputs)
    else:
      output_tensors, control_deps = _postprocess_non_flat_outputs(outputs)

    context.ExitResult(output_tensors)
  finally:
    context.report_unsupported_operations()
    context.Exit()

  # When the XLA computation returns only operations and no tensors, a NoOp
  # dependent on the operations in outputs is returned. Otherwise the final
  # outputs would be empty and there would be no way to trigger the returned
  # operations.
  if not output_tensors:
    return control_flow_ops.group(control_deps, name='output_0')

  output_tensors = [
      xla_ops.xla_cluster_output(o, name='output{}'.format(i))
      for i, o in enumerate(output_tensors)
  ]

  with ops.control_dependencies(control_deps):
    # Wrap the outputs in identity operators that carry the control
    # dependencies.
    output_tensors = [
        array_ops.identity(o, name='output_%d' % i)
        for i, o in enumerate(output_tensors)
    ]

  # If `computation` returned a non-flat output structure, pack the output
  # tensors back into the same structure.
  if not outputs_is_flat:
    output_tensors = nest.pack_sequence_as(
        structure=outputs, flat_sequence=output_tensors)

  return output_tensors


def is_flat(outputs):
  """Checks if outputs is a flat structure.

  The following structures and values are considered flat:
  1) None
  2) A single object
  3) A list or tuple of Tensors/Operations

  The only structures that this function understands are sequences,
  dictionaries and types defined using the attrs library. E.g. this means
  that if outputs contains a single user-defined Object, it is considered to
  be flat. Errors are raised later on if that Object cannot be converted to a
  Tensor.

  Args:
    outputs: Output from `computation` inside `xla.compile`.

  Returns:
    A boolean indicating whether outputs is flat.
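
  For example (illustrative; `t1`, `t2` are Tensors and `op` is an Operation):

    is_flat(None)          # True
    is_flat(t1)            # True
    is_flat([t1, t2, op])  # True: a flat sequence
    is_flat([t1, [t2]])    # False: nested sequence
    is_flat({'out': t1})   # False: mapping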
428  """
429  # If outputs is a list or tuple, check if it has any nested structure. If
430  # there is, then outputs is non-flat.
431  if isinstance(outputs, collections_abc.Sequence):
432    for o in outputs:
433      if (isinstance(o, collections_abc.Sequence) or
434          isinstance(o, collections_abc.Mapping) or
435          hasattr(o.__class__, '__attrs_attrs__')):
436        return False
437
438  # If outputs is a dict, it is non-flat.
439  if isinstance(outputs, collections_abc.Mapping):
440    return False
441
442  # If outputs is from the attrs library, it is non-flat.
443  if hasattr(outputs.__class__, '__attrs_attrs__'):
444    return False
445
446  # Getting here means either outputs itself is a single non-structured value
447  # or it is a flat list of single non-structured values.
448  return True
449
450
def _postprocess_flat_outputs(outputs):
  """Validates flat outputs and adds back device assignments.

  Args:
    outputs: Output from `computation` inside `xla.compile`.

  Returns:
    Tensors and Operations extracted from outputs.
  """
  # The following code segment preserves legacy behavior. Previously we only
  # supported flat outputs, and for consistency it was nice to convert even a
  # single element into a tuple. But now that we support arbitrary output
  # structure, this is no longer necessary.
  # TODO(b/121383831): Migrate all legacy use cases and delete this special
  # case.
  # If the computation returns `None`, make it an empty tuple.
  if outputs is None:
    outputs = tuple()
  # If the computation only returned one value, make it a tuple.
  if not isinstance(outputs, collections_abc.Sequence):
    outputs = (outputs,)

  # Append `no_op` here so that the return value of this function always
  # contains at least one op that can trigger the XlaLaunch node.
  outputs += (control_flow_ops.no_op(),)
  try:
    outputs = [
        o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
        for o in outputs
    ]
  except Exception as e:
    raise ValueError(
        'XLA computation function return values must all either be Operations'
        ' or convertible to Tensors. Got error: "%s"' % str(e))

  # Separate the returned Operations and Tensors.
  output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
  output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]

  if outputs != output_tensors + output_operations:
    raise ValueError(
        'XLA computation function must return zero or more Tensor values '
        'followed by zero or more Operations.')

  new_output_tensors = []
  for t in output_tensors:
    with ops.device(t.device if t.device else ''):
      new_output_tensors.append(array_ops.identity(t))

  return new_output_tensors, output_operations
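
# For example (sketch): flat outputs `(loss, train_op)` from a computation
# yield identity-wrapped `[loss]` as Tensors plus `[train_op, <appended
# no_op>]` as Operations, whereas `(train_op, loss)` raises ValueError because
# Tensors must precede Operations.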


def _postprocess_non_flat_outputs(outputs):
  """Validates non-flat outputs and adds back device assignments.

  Args:
    outputs: Output from `computation` inside `xla.compile`.

  Returns:
    Tensors extracted from outputs, and an empty list, because Operations are
    not allowed in non-flat outputs.
  """
  # Convert all non-Operation outputs to Tensors.
  new_output_tensors = []
  for o in nest.flatten(outputs):
    if isinstance(o, ops.Operation):
      raise ValueError(
          'xla.compile does not support Operation as return value in non-flat '
          'output structure. You can set returned Operations as control '
          'dependencies of returned Tensors so Operations are triggered when '
          'Tensors are evaluated. Operation found: "%s"' % o.name)

    try:
      o = ops.convert_to_tensor(o)
    except Exception as e:
      raise ValueError(
          'XLA computation function return values must all either be '
          'Operations or convertible to Tensors. Got error: "%s"' % str(e))

    # Make sure even pass-through inputs/outputs are touched in the compile
    # context by creating an Identity node inside the compile context.
    with ops.device(o.device if o.device else ''):
      new_output_tensors.append(array_ops.identity(o))

  return new_output_tensors, []
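
# For example (sketch): a computation returning `{'logits': logits}` takes
# this path, and xla.compile returns a dict with the same structure; a
# computation returning `{'train_op': op}` raises ValueError instead.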


@contextlib.contextmanager
def _disable_summary_context():
  """Enters a context where all summary ops are skipped.

  Summaries are not yet supported in xla.compile(), so this context manager
  skips creating summary ops as a temporary workaround.

  Yields:
    None.
  """
  original_skip_summary_func = summary_op_util.skip_summary
  summary_op_util.skip_summary = lambda: True

  try:
    yield
  finally:
    summary_op_util.skip_summary = original_skip_summary_func
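
# Usage sketch (as in _compile_internal above): summary ops requested inside
# the block are skipped rather than added to the graph.
#
#   with _disable_summary_context():
#     outputs = computation(*computation_inputs)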


class _CapturedObject(object):
  """A placeholder to capture an object."""

  def __init__(self):
    self._object = None

  def capture(self, o):
    if self._object:
      raise RuntimeError(
          'InternalError: _CapturedObject can capture only once. Please file '
          'a bug.')

    self._object = o

  def get(self):
    return self._object
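
# Usage sketch (hypothetical `my_scaffold_fn`):
#
#   captured = _CapturedObject()
#   captured.capture(my_scaffold_fn)  # a second capture() raises RuntimeError
#   fn = captured.get()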


def _get_scaffold(captured_scaffold_fn):
  """Retrieves the Scaffold from `captured_scaffold_fn`."""
  scaffold_fn = captured_scaffold_fn.get()

  if not scaffold_fn:
    return None

  scaffold = scaffold_fn()
  if scaffold is None:
    raise ValueError(
        'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed')

  return scaffold


def check_function_argument_count(func, input_arity, infeed_queue):
  """Validates the number of input arguments to an XLA function.

  Args:
    func: the Python function that will be called to generate the body of an
      XLA computation graph.
    input_arity: the number of explicit arguments supplied by the caller.
    infeed_queue: if not None, the infeed queue that will supply
      additional arguments to the function.

  Returns:
    None if the function can be called with the supplied number of arguments,
    or an error string if it cannot.
  """
  def format_error(complaint, quantity):
    return '%s %d argument%s' % (complaint, quantity, ''
                                 if quantity == 1 else 's')

  num_args_supplied = input_arity
  if infeed_queue is not None:
    num_args_supplied += infeed_queue.number_of_tuple_elements
  arg_spec = tf_inspect.getargspec(func)
  num_func_args = len(arg_spec.args)
  if arg_spec.defaults is None:
    num_func_defaults = 0
  else:
    num_func_defaults = len(arg_spec.defaults)
  min_func_args = num_func_args - num_func_defaults
  if num_args_supplied < min_func_args:
    # The number of supplied arguments is not enough to call the function.
    if num_func_defaults == 0 and arg_spec.varargs is None:
      return format_error('exactly', num_func_args)
    else:
      return format_error('at least', min_func_args)
  if arg_spec.varargs is None and num_args_supplied > num_func_args:
    # The number of supplied arguments is too many to call the function.
    if num_func_defaults == 0:
      return format_error('exactly', num_func_args)
    else:
      return format_error('at most', num_func_args)
  # Reaching here means either:
  # 1) There are varargs, and func can accept any number of arguments greater
  #    than the minimum.
  # 2) The number of supplied arguments falls within the range of acceptable
  #    argument counts for func.
  return None
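
# Illustrative results (hypothetical functions, no infeed queue):
#
#   def f(a, b): ...
#   check_function_argument_count(f, 1, None)  # -> 'exactly 2 arguments'
#   check_function_argument_count(f, 2, None)  # -> None (callable as-is)
#
#   def g(a, b=1): ...
#   check_function_argument_count(g, 3, None)  # -> 'at most 2 arguments'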
637