# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Library for controlling the Tensorflow/XLA JIT compiler."""

import contextlib

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.util.tf_export import tf_export


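# Graph collection key under which the singleton _XlaScope counter for the
# current default graph is stored (see the get_collection calls below).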
_XLA_SCOPE_KEY = ("__xla_scope",)


class _XlaScope(object):
  """Keeps track of previous XLA scope calls, and depth of current call."""

  def __init__(self, count, depth):
    self.count = count
    self.depth = depth


@contextlib.contextmanager
@tf_export("xla.experimental.jit_scope")
def experimental_jit_scope(compile_ops=True, separate_compiled_gradients=False):
39  """Enable or disable JIT compilation of operators within the scope.
40
41  NOTE: This is an experimental feature.
42
43  The compilation is a hint and only supported on a best-effort basis.
44
45  Example usage:
46
47    ```python
48    with tf.xla.experimental.jit_scope():
49      c = tf.matmul(a, b)  # compiled
50    with tf.xla.experimental.jit_scope(compile_ops=False):
51      d = tf.matmul(a, c)  # not compiled
52    with tf.xla.experimental.jit_scope(
53        compile_ops=lambda node_def: 'matmul' in node_def.op.lower()):
54      e = tf.matmul(a, b) + d  # matmul is compiled, the addition is not.
55    ```
56
57  Example of `separate_compiled_gradients`:
58
59    ```python
60    # In the example below, the computations for f, g and h will all be compiled
61    # in separate scopes.
62    with tf.xla.experimental.jit_scope(
63        separate_compiled_gradients=True):
64      f = tf.matmul(a, b)
65    g = tf.gradients([f], [a, b], name='mygrads1')
66    h = tf.gradients([f], [a, b], name='mygrads2')
67    ```
68
  Ops that are not in any scope may still be clustered and compiled together
  with ops in a scope that has `compile_ops=True`, while ops in a scope with
  `compile_ops=False` will never be compiled.

  For example:

    ```python
    # In the example below, x and loss may be clustered and compiled together,
    # while y will not be compiled.
    with tf.xla.experimental.jit_scope():
      x = tf.matmul(a, b)
    with tf.xla.experimental.jit_scope(compile_ops=False):
      y = tf.matmul(c, d)
    loss = x + y
    ```

  If you want to compile only the ops in the scope with `compile_ops=True`,
  consider adding an outer `jit_scope(compile_ops=False)`:

    ```python
    # In the example below, only x will be compiled.
    with tf.xla.experimental.jit_scope(compile_ops=False):
      with tf.xla.experimental.jit_scope():
        x = tf.matmul(a, b)
      y = tf.matmul(c, d)
      loss = x + y
    ```
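
  Since the scope only takes effect while a graph is being built, a common
  pattern is to apply it inside a `tf.function` (a minimal sketch; `a` and `b`
  are assumed to be tensors defined elsewhere):

    ```python
    @tf.function
    def compiled_matmul(a, b):
      with tf.xla.experimental.jit_scope():
        return tf.matmul(a, b)  # compiled on a best-effort basis
    ```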

  Args:
    compile_ops: Whether to enable or disable compilation in the scope.
      Either a Python bool, or a callable that accepts the parameter
      `node_def` and returns a Python bool.
    separate_compiled_gradients: If true, put each gradient subgraph into a
      separate compilation scope. This gives fine-grained control over which
      portions of the graph will be compiled as a single unit. Compiling
      gradients separately may yield better performance for some graphs.
      The scope is named based on the scope of the forward computation as well
      as the name of the gradients. As a result, the gradients will be compiled
      in a scope that is separate from both the forward computation and from
      other gradients.
  Raises:
    RuntimeError: if called when eager execution is enabled.
  Yields:
    The current scope, enabling or disabling compilation.
  """
  if context.executing_eagerly():
    raise RuntimeError("xla.experimental.jit_scope is not supported when eager "
                       "execution is enabled. Try using it inside tf.function.")

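  # Normalize `compile_ops` into the form used by the attr scope below: a
  # callable is wrapped so it can be evaluated per op (it receives each op's
  # `node_def`), while a plain bool becomes a constant AttrValue.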
  if callable(compile_ops):
    def xla_compile(node_def):
      return attr_value_pb2.AttrValue(b=compile_ops(node_def))
  else:
    xla_compile = attr_value_pb2.AttrValue(b=compile_ops)

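  # These attributes are attached to every op created inside the scope and
  # serve as hints to the XLA JIT compiler.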
  attrs = {
      "_XlaCompile":
          xla_compile,
      "_XlaSeparateCompiledGradients":
          attr_value_pb2.AttrValue(b=bool(separate_compiled_gradients))
  }

  # Find the singleton counter for the current scoped graph.  If it
  # doesn't exist, create one.
  xla_scope_counter = ops.get_collection(_XLA_SCOPE_KEY)
  if not xla_scope_counter:
    xla_scope_counter = _XlaScope(0, 0)
    ops.add_to_collection(_XLA_SCOPE_KEY, xla_scope_counter)
  else:
    xla_scope_counter = xla_scope_counter[0]

  if xla_scope_counter.depth == 0:
    # If we're at the root xla scope, we can increase the counter so
    # future calls to jit_scope use a different scope value.
    # If we're already within a scope, we'll be fusing using the scope
    # controlled by the parent.
    attrs["_XlaScope"] = attr_value_pb2.AttrValue(
        s=("jit_scope_%d" % xla_scope_counter.count).encode())
    xla_scope_counter.count += 1

  xla_scope_counter.depth += 1

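  # Ops created while the context manager is active pick up the attributes
  # above. `Graph._attr_scope` is private API, hence the pylint pragmas.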
  # pylint: disable=protected-access
  with ops.get_default_graph()._attr_scope(attrs):
    yield
  # pylint: enable=protected-access

  xla_scope_counter.depth -= 1
157