# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Definition of XLA test case."""

import contextlib
import os
import random
import re

import numpy as np

from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.compiler.xla import jit
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import flags
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging

FLAGS = flags.FLAGS

flags.DEFINE_string('test_device', None,
                    'TensorFlow device on which to place operators under test')
flags.DEFINE_string('types', None, 'Types to test. Comma-separated list.')
flags.DEFINE_string('disabled_manifest', None,
                    'Path to a file with a list of tests that should not run.')
flags.DEFINE_string('tf_xla_flags', None,
                    'Value to set the TF_XLA_FLAGS environment variable to.')


def parse_disabled_manifest(manifest_content):
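  """Parses the disabled-tests manifest into a test regex and a type filter.

  Each non-empty, non-comment line of the manifest is either a single
  test-name regex, which disables the matching tests entirely, or a test name
  followed by a comma-separated list of DataType names, which disables only
  those types for that test.

  Args:
    manifest_content: Contents of the manifest file, as a string.

  Returns:
    A pair (disabled_regex, method_types_filter): a '|'-joined regex matching
    fully disabled tests, and a dict mapping test names to the set of numpy
    dtypes disabled for them.

  Raises:
    ValueError: If a manifest line has more than two space-separated fields.
  """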
  comments_re = re.compile('#.*$')
  disabled_tests = []
  disabled_method_types = []
  for line in manifest_content.splitlines():
    stripped = comments_re.sub('', line).strip()
    if not stripped:
      continue
    entry = stripped.split(' ')
    if len(entry) == 1:
      disabled_tests.append(entry[0])
    elif len(entry) == 2:
      disabled_method_types.append((entry[0], entry[1].strip().split(',')))
    else:
      raise ValueError('Bad entry in manifest file: {}'.format(stripped))

  disabled_regex = '|'.join(disabled_tests)
  method_types_filter = {}
  for method, types in disabled_method_types:
    method_types_filter[method] = set([
        dtypes.as_dtype(types_pb2.DataType.Value(name)).as_numpy_dtype
        for name in types
    ])
  return disabled_regex, method_types_filter
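
# For illustration, a hypothetical manifest might look like this (the format
# follows the comments in XLATestCase.__init__ below; test names are made up):
#
#   CumprodTest.*                              # disables these tests entirely
#   AdamOptimizerTest.testSharing DT_BFLOAT16  # disables only bfloat16 here
#
# Trailing '#' comments are stripped from each line before parsing.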


class TPURewriteSession(session.Session):
  """TensorFlow session that wraps fetched ops in tpu.rewrite() on run()."""

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.topology = None

  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    from tensorflow.python.tpu import tpu  # pylint: disable=g-import-not-at-top
    if self.topology is None:
      self.topology = super().run(tpu.initialize_system())
      assert self.topology is not None
    fetch_mapper = session._FetchMapper.for_fetch(fetches)
    new_fetches = []
    for fetch in fetch_mapper.unique_fetches():
      if isinstance(fetch, ops.Operation):
        fetch = tpu.rewrite(lambda fetch=fetch: fetch)
      new_fetches.append(fetch)
    rewritten_fetches = fetch_mapper.build_results(new_fetches)
    return super().run(rewritten_fetches, feed_dict, options, run_metadata)
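
# A minimal sketch of how a test might opt into TPURewriteSession: setting the
# rewrite_ops_for_tpu attribute below makes XLATestCase.session() select it
# (hypothetical test class, for illustration only):
#
#   class FtrlOpsTest(XLATestCase):
#
#     def __init__(self, method_name='runTest'):
#       super().__init__(method_name)
#       self.rewrite_ops_for_tpu = True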


class XLATestCase(test.TestCase):
  """XLA test cases are parameterized test cases."""

  def __init__(self, method_name='runTest'):
    super(XLATestCase, self).__init__(method_name)
    if 'XLA' in FLAGS.test_device:
      context.context().enable_xla_devices()

    # Check if the MLIR bridge has been explicitly enabled or disabled. If
    # is_mlir_bridge_enabled() returns None, the user did not explicitly
    # enable or disable the bridge, so do not update enable_mlir_bridge.
    if test_util.is_mlir_bridge_enabled():
      context.context().enable_mlir_bridge = True
    elif test_util.is_mlir_bridge_enabled() is not None:
      context.context().enable_mlir_bridge = False

    self.device = FLAGS.test_device
    self.has_custom_call = (self.device == 'XLA_CPU')

    # Some tests (e.g. ftrl_ops) only work if the program goes through the
    # _TPUCompileMLIR op. They will set this flag to True.
    # TODO(kramm): Flip to True (and enable the MLIR bridge) for more tests.
    self.rewrite_ops_for_tpu = False

    self._all_tf_types = set([
        dtypes.as_dtype(types_pb2.DataType.Value(name))
        for name in FLAGS.types.split(',')
    ])
    self.int_tf_types = set([
        dtype for dtype in self._all_tf_types if dtype.is_integer
    ])
    self._float_tf_types = set([
        dtype for dtype in self._all_tf_types if dtype.is_floating
    ])
    self.complex_tf_types = set([
        dtype for dtype in self._all_tf_types if dtype.is_complex
    ])
    self._numeric_tf_types = set(
        self.int_tf_types | self._float_tf_types | self.complex_tf_types)
    self.quantized_tf_types = set(
        dtype for dtype in self._all_tf_types if dtype.is_quantized)

    # Quantized types don't have a numpy equivalent, so include them in
    # all_tf_types but not in all_types.
    # TODO(b/115960798): Parametrize tests on TF types instead of numpy types
    # and remove all_types.
    self._all_types = set(dtype.as_numpy_dtype
                          for dtype in self._all_tf_types
                          if not dtype.is_quantized)
    self._int_types = set([dtype.as_numpy_dtype for dtype in self.int_tf_types])
    self.signed_int_types = set(dtype.as_numpy_dtype
                                for dtype in self.int_tf_types
                                if not dtype.is_unsigned)
    self.unsigned_int_types = set(dtype.as_numpy_dtype
                                  for dtype in self.int_tf_types
                                  if dtype.is_unsigned)
    self._float_types = set(
        [dtype.as_numpy_dtype for dtype in self._float_tf_types])
    self.complex_types = set([
        dtype.as_numpy_dtype for dtype in self.complex_tf_types
    ])
    self._numeric_types = set(self._int_types | self._float_types
                              | self.complex_types)

    # Parse the manifest file, if any, into a regex identifying tests to
    # disable.
    # TODO(xpan): Make it a text proto if it doesn't scale.
    # Each line of the manifest file specifies an entry. The entry can be
    # 1) TestNameRegex  // E.g. CumprodTest.*  Or
    # 2) TestName TypeName  // E.g. AdamOptimizerTest.testSharing DT_BFLOAT16
    # Form 1) disables the entire test, while form 2) only filters out some
    # numeric types so that they are not used in those tests.
    self.disabled_regex = None
    self._method_types_filter = {}

    if FLAGS.disabled_manifest is not None:
      with open(FLAGS.disabled_manifest, 'r') as manifest_file:
        disabled_regex, self._method_types_filter = (
            parse_disabled_manifest(manifest_file.read()))
        if disabled_regex:
          self.disabled_regex = re.compile(disabled_regex)

    if FLAGS.tf_xla_flags is not None:
      os.environ['TF_XLA_FLAGS'] = FLAGS.tf_xla_flags

  @property
  def all_tf_types(self):
    name = '{}.{}'.format(type(self).__name__, self._testMethodName)
    tf_types = set([dtypes.as_dtype(t)
                    for t in self._method_types_filter.get(name, set())])
    return self._all_tf_types - tf_types

  @property
  def float_types(self):
    name = '{}.{}'.format(type(self).__name__, self._testMethodName)
    return self._float_types - self._method_types_filter.get(name, set())

  @property
  def float_tf_types(self):
    name = '{}.{}'.format(type(self).__name__, self._testMethodName)
    return self._float_tf_types - self._method_types_filter.get(name, set())

  @property
  def int_types(self):
    name = '{}.{}'.format(type(self).__name__, self._testMethodName)
    return self._int_types - self._method_types_filter.get(name, set())

  @property
  def numeric_tf_types(self):
    name = '{}.{}'.format(type(self).__name__, self._testMethodName)
    tf_types = set([dtypes.as_dtype(t)
                    for t in self._method_types_filter.get(name, set())])
    return self._numeric_tf_types - tf_types

  @property
  def numeric_types(self):
    name = '{}.{}'.format(type(self).__name__, self._testMethodName)
    return self._numeric_types - self._method_types_filter.get(name, set())

  @property
  def all_types(self):
    name = '{}.{}'.format(type(self).__name__, self._testMethodName)
    return self._all_types - self._method_types_filter.get(name, set())

  def setUp(self):
    super(XLATestCase, self).setUp()
    name = '{}.{}'.format(type(self).__name__, self._testMethodName)
    if self.disabled_regex is not None and self.disabled_regex.match(name):
      logging.info('Disabled test case: %s', name)
      self.skipTest('{} is disabled by manifest.'.format(name))
      return
    logging.info('Start test case: %s', name)

    random.seed(random_seed.DEFAULT_GRAPH_SEED)
    np.random.seed(random_seed.DEFAULT_GRAPH_SEED)

  def tearDown(self):
    super(XLATestCase, self).tearDown()
    logging.info('End test case: %s', self._testMethodName)

  @contextlib.contextmanager
  def session(self):
    """Custom implementation of session() for XLA tests.

    We override the standard TensorFlow session() since it is too
    specific to CPU and GPU tests. In particular, we want to disable soft
    placement and explicitly assign ops to the devices under test.

    Yields:
      A session to use when running a test case.
    """
    graph = ops.Graph()
    config = context.context().config

    # Grappler can constant fold TensorListFromTensor ops into DT_VARIANT
    # constants, which XLA does not understand, so disable constant folding in
    # these tests.
    config.graph_options.rewrite_options.constant_folding = (
        rewriter_config_pb2.RewriterConfig.OFF)

    if self.rewrite_ops_for_tpu:
      session_type = TPURewriteSession
    else:
      session_type = session.Session

    with session_type(graph=graph, config=config) as sess, graph.as_default():
      yield sess

  def cached_session(self):
    raise NotImplementedError(
        'cached_session not supported on XLATestCase, please use session')

  def test_session(self):
    raise NotImplementedError(
        'test_session not supported on XLATestCase, please use session')

  @contextlib.contextmanager
  def device_scope(self):
    """Scope that runs tests on `self.device`.

    Yields:
      A scope to apply to the operators under test.
    """
    with ops.device('device:{}:0'.format(self.device)):
      yield

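  # A minimal sketch of a test body combining session() and device_scope()
  # (hypothetical, for illustration only; assumes math_ops is also imported):
  #
  #   with self.session() as sess, self.device_scope():
  #     x = array_ops.placeholder(dtypes.float32)
  #     y = math_ops.square(x)
  #     self.assertAllClose(sess.run(y, {x: [3.0]}), [9.0])
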
  def test_scope(self):
    """Deprecated alias of `device_scope`.

    This should be avoided, since the name starts with `test` and test runners
    therefore treat it as a test. This interferes with class decorators that
    operate on each test method.
    """
    return self.device_scope()


def Benchmark(tf_bench,
              builder_fn,
              use_xla_jit,
              device,
              separate_compiled_gradients=False):
  """Builds a graph and runs benchmarks against it, with or without XLA.

  Args:
    tf_bench: An instance of tf.test.Benchmark, used to run the benchmark.
    builder_fn: A function that builds a graph when invoked and returns
        (name, fetches), where name is the name of the test and fetches
        is a list of tensors to fetch as output.
    use_xla_jit: If true, compile with the XLA JIT; otherwise use regular TF.
    device: The TensorFlow device to run on, e.g. "cpu", "gpu".
    separate_compiled_gradients: If true, put each gradient subgraph into a
      separate compilation scope. This gives fine-grained control over which
      portions of the graph will be compiled as a single unit. Compiling
      gradients separately may yield better performance for some graphs.
      The scope is named based on the scope of the forward computation as
      well as the name of the gradients. As a result, the gradients will be
      compiled in a scope that is separate from both the forward computation
      and from other gradients.
  """

  with ops.Graph().as_default():
    name = None
    targets = []
    with ops.device(device):
      fetches = []
      jit_scope = jit.experimental_jit_scope
      with jit_scope(
          compile_ops=use_xla_jit,
          separate_compiled_gradients=separate_compiled_gradients):
        name, fetches = builder_fn()

      # We only want to benchmark the operations themselves, and not the data
      # transfer of the result(s). Non-compiled identity ops ensure XLA
      # doesn't know we're dropping the results; otherwise it might compile
      # away the entire computation.
      for fetch in fetches:
        targets.append(array_ops.identity(fetch).op)

    # TODO(b/132430685): Should we allow soft placement here?
    config = config_pb2.ConfigProto(allow_soft_placement=True)
    with session.Session(config=config) as sess:
      sess.run(variables.global_variables_initializer())
      xla = 'xla_' if use_xla_jit else ''
      tf_bench.run_op_benchmark(
          sess, targets, name='%s_%s%s' % (name, xla, device))
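

# A minimal sketch of how a benchmark might drive this helper (hypothetical
# names; _matmul_builder and MatMulBenchmark are for illustration only, and
# math_ops is assumed to be imported):
#
#   def _matmul_builder():
#     x = variables.Variable(np.random.randn(1024, 1024).astype(np.float32))
#     product = math_ops.matmul(x, x)
#     return 'matmul_1024', [product]
#
#   class MatMulBenchmark(test.Benchmark):
#
#     def benchmarkMatMulXla(self):
#       Benchmark(self, _matmul_builder, use_xla_jit=True, device='gpu')
#
#     def benchmarkMatMulTf(self):
#       Benchmark(self, _matmul_builder, use_xla_jit=False, device='gpu')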