# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Definition of XLA test case."""

import contextlib
import os
import random
import re

import numpy as np

from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.compiler.xla import jit
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import flags
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging

FLAGS = flags.FLAGS

flags.DEFINE_string('test_device', None,
                    'TensorFlow device on which to place operators under test')
flags.DEFINE_string('types', None, 'Types to test. Comma-separated list.')
flags.DEFINE_string('disabled_manifest', None,
                    'Path to a file with a list of tests that should not run.')
flags.DEFINE_string('tf_xla_flags', None,
                    'Value to set the TF_XLA_FLAGS environment variable to')


def parse_disabled_manifest(manifest_content):
  comments_re = re.compile('#.*$')
  disabled_tests = []
  disabled_method_types = []
  for l in manifest_content.splitlines():
    stripped = comments_re.sub('', l).strip()
    if not stripped:
      continue
    entry = stripped.split(' ')
    if len(entry) == 1:
      disabled_tests.append(entry[0])
    elif len(entry) == 2:
      disabled_method_types.append((entry[0], entry[1].strip().split(',')))
    else:
      raise ValueError('Bad entry in manifest file: %s' % stripped)

  disabled_regex = '|'.join(disabled_tests)
  method_types_filter = {}
  for method, types in disabled_method_types:
    method_types_filter[method] = set([
        dtypes.as_dtype(types_pb2.DataType.Value(name)).as_numpy_dtype
        for name in types
    ])
  return disabled_regex, method_types_filter
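

# A minimal sketch of what parse_disabled_manifest consumes and produces
# (hypothetical manifest contents, shown only for exposition):
#
#   manifest = '\n'.join([
#       'CumprodTest.*  # form 1): disable every test matching the regex',
#       'AdamOptimizerTest.testSharing DT_HALF,DT_FLOAT  # form 2)',
#   ])
#   disabled_regex, method_types_filter = parse_disabled_manifest(manifest)
#   # disabled_regex == 'CumprodTest.*'
#   # method_types_filter == {'AdamOptimizerTest.testSharing':
#   #                         {np.float16, np.float32}}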


class TPURewriteSession(session.Session):
  """TensorFlow session that wraps fetched ops in tpu.rewrite() on run()."""

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.topology = None

  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    from tensorflow.python.tpu import tpu  # pylint: disable=g-import-not-at-top
    # Initialize the TPU system exactly once, on first use.
    if self.topology is None:
      self.topology = super().run(tpu.initialize_system())
      assert self.topology is not None
    fetch_mapper = session._FetchMapper.for_fetch(fetches)
    new_fetches = []
    for fetch in fetch_mapper.unique_fetches():
      if isinstance(fetch, ops.Operation):
        # Bind the current fetch via a default argument so that each lambda
        # captures its own op rather than the loop variable.
        fetch = tpu.rewrite(lambda fetch=fetch: fetch)
      new_fetches.append(fetch)
    rewritten_fetches = fetch_mapper.build_results(new_fetches)
    return super().run(rewritten_fetches, feed_dict, options, run_metadata)
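

# Illustrative sketch of how the rewrite behaves (hypothetical fetch name,
# for exposition only): when XLATestCase.session() below yields a
# TPURewriteSession, a plain Operation fetch is routed through tpu.rewrite()
# before execution:
#
#   with TPURewriteSession(graph=graph, config=config) as sess:
#     sess.run(update_op)  # roughly sess.run(tpu.rewrite(lambda: update_op))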


class XLATestCase(test.TestCase):
  """XLA test cases are parameterized test cases."""

  def __init__(self, method_name='runTest'):
    super().__init__(method_name)
    if 'XLA' in FLAGS.test_device:
      context.context().enable_xla_devices()

    # Check if the mlir bridge has been explicitly enabled or disabled. If
    # is_mlir_bridge_enabled() returns None, the user did not explicitly
    # enable or disable the bridge, so do not update enable_mlir_bridge.
    if test_util.is_mlir_bridge_enabled():
      context.context().enable_mlir_bridge = True
    elif test_util.is_mlir_bridge_enabled() is not None:
      context.context().enable_mlir_bridge = False

    self.device = FLAGS.test_device
    self.has_custom_call = (self.device == 'XLA_CPU')

    # Some tests (e.g. ftrl_ops) only work if the program goes through the
    # _TPUCompileMLIR op. They will set this flag to True.
    # TODO(kramm): Flip to true (and enable MLIR bridge) for more tests.
    self.rewrite_ops_for_tpu = False

    self._all_tf_types = set([
        dtypes.as_dtype(types_pb2.DataType.Value(name))
        for name in FLAGS.types.split(',')
    ])
    self.int_tf_types = set([
        dtype for dtype in self._all_tf_types if dtype.is_integer
    ])
    self._float_tf_types = set([
        dtype for dtype in self._all_tf_types if dtype.is_floating
    ])
    self.complex_tf_types = set([
        dtype for dtype in self._all_tf_types if dtype.is_complex
    ])
    self._numeric_tf_types = set(
        self.int_tf_types | self._float_tf_types | self.complex_tf_types)
    self.quantized_tf_types = set(
        dtype for dtype in self._all_tf_types if dtype.is_quantized)

    # Quantized types don't have a numpy equivalent, so include them in
    # all_tf_types but not in all_types.
    # TODO(b/115960798): Parametrize tests on TF types instead of numpy types
    # and remove all_types.
    self._all_types = set(dtype.as_numpy_dtype
                          for dtype in self._all_tf_types
                          if not dtype.is_quantized)
    self._int_types = set(
        [dtype.as_numpy_dtype for dtype in self.int_tf_types])
    self.signed_int_types = set(dtype.as_numpy_dtype
                                for dtype in self.int_tf_types
                                if not dtype.is_unsigned)
    self.unsigned_int_types = set(dtype.as_numpy_dtype
                                  for dtype in self.int_tf_types
                                  if dtype.is_unsigned)
    self._float_types = set(
        [dtype.as_numpy_dtype for dtype in self._float_tf_types])
    self.complex_types = set([
        dtype.as_numpy_dtype for dtype in self.complex_tf_types
    ])
    self._numeric_types = set(self._int_types | self._float_types
                              | self.complex_types)

    # Parse the manifest file, if any, into a regex identifying tests to
    # disable.
    # TODO(xpan): Make it a text proto if it doesn't scale.
    # Each line of the manifest file specifies an entry. The entry can be
    # 1) TestNameRegex, e.g. CumprodTest.*, or
    # 2) TestName TypeName, e.g. AdamOptimizerTest.testSharing DT_BFLOAT16.
    # Form 1) disables the entire test, while form 2) only filters out the
    # listed types so that they are not used in that test.
    self.disabled_regex = None
    self._method_types_filter = {}

    if FLAGS.disabled_manifest is not None:
      with open(FLAGS.disabled_manifest, 'r') as manifest_file:
        disabled_regex, self._method_types_filter = (
            parse_disabled_manifest(manifest_file.read()))
        if disabled_regex:
          self.disabled_regex = re.compile(disabled_regex)

    if FLAGS.tf_xla_flags is not None:
      os.environ['TF_XLA_FLAGS'] = FLAGS.tf_xla_flags

  @property
  def all_tf_types(self):
    name = '{}.{}'.format(type(self).__name__, self._testMethodName)
    tf_types = set([dtypes.as_dtype(t)
                    for t in self._method_types_filter.get(name, set())])
    return self._all_tf_types - tf_types

  @property
  def float_types(self):
    name = '{}.{}'.format(type(self).__name__, self._testMethodName)
    return self._float_types - self._method_types_filter.get(name, set())

  @property
  def float_tf_types(self):
    name = '{}.{}'.format(type(self).__name__, self._testMethodName)
    # The filter holds numpy dtypes, so convert before subtracting from the
    # set of TF dtypes (mirrors all_tf_types and numeric_tf_types).
    tf_types = set([dtypes.as_dtype(t)
                    for t in self._method_types_filter.get(name, set())])
    return self._float_tf_types - tf_types

  @property
  def int_types(self):
    name = '{}.{}'.format(type(self).__name__, self._testMethodName)
    return self._int_types - self._method_types_filter.get(name, set())

  @property
  def numeric_tf_types(self):
    name = '{}.{}'.format(type(self).__name__, self._testMethodName)
    tf_types = set([dtypes.as_dtype(t)
                    for t in self._method_types_filter.get(name, set())])
    return self._numeric_tf_types - tf_types

  @property
  def numeric_types(self):
    name = '{}.{}'.format(type(self).__name__, self._testMethodName)
    return self._numeric_types - self._method_types_filter.get(name, set())

  @property
  def all_types(self):
    name = '{}.{}'.format(type(self).__name__, self._testMethodName)
    return self._all_types - self._method_types_filter.get(name, set())

  def setUp(self):
    super().setUp()
    name = '{}.{}'.format(type(self).__name__, self._testMethodName)
    if self.disabled_regex is not None and self.disabled_regex.match(name):
      logging.info('Disabled test case: %s', name)
      self.skipTest('{} is disabled by manifest.'.format(name))
      return
    logging.info('Start test case: %s', name)

    random.seed(random_seed.DEFAULT_GRAPH_SEED)
    np.random.seed(random_seed.DEFAULT_GRAPH_SEED)

  def tearDown(self):
    super().tearDown()
    logging.info('End test case: %s', self._testMethodName)
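
  # Illustrative sketch of a subclass (hypothetical test, for exposition
  # only): tests typically iterate over one of the dtype properties above,
  # which already exclude any types the manifest disabled for the current
  # test method:
  #
  #   class UnaryOpsTest(XLATestCase):
  #
  #     def testFloor(self):
  #       for dtype in self.float_types:
  #         with self.session(), self.device_scope():
  #           ...  # build the op for this dtype and check its output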

  @contextlib.contextmanager
  def session(self):
    """Custom implementation of session() for XLA tests.

    We override the standard TensorFlow session() since it is too specific
    to CPU and GPU tests. In particular, we want to disable soft placement
    and explicitly assign ops to devices under test.

    Yields:
      A session to use when running a test case.
    """
    graph = ops.Graph()
    config = context.context().config

    # Grappler can constant-fold TensorListFromTensor ops into DT_VARIANT
    # constants, which XLA does not understand, so disable constant folding
    # in these tests.
    config.graph_options.rewrite_options.constant_folding = (
        rewriter_config_pb2.RewriterConfig.OFF)

    if self.rewrite_ops_for_tpu:
      session_type = TPURewriteSession
    else:
      session_type = session.Session

    with session_type(graph=graph, config=config) as sess, graph.as_default():
      yield sess

  def cached_session(self):
    raise NotImplementedError(
        'cached_session not supported on XLATestCase, please use session')

  def test_session(self):
    raise NotImplementedError(
        'test_session not supported on XLATestCase, please use session')

  @contextlib.contextmanager
  def device_scope(self):
    """Scope that runs tests on `self.device`.

    Yields:
      A scope to apply to the operators under test.
    """
    with ops.device('device:{}:0'.format(self.device)):
      yield

  def test_scope(self):
    """Deprecated alias of `device_scope`.

    This alias should be avoided: because its name starts with `test`, test
    runners treat it as a test, which interferes with class decorators that
    operate on each test method.
    """
    return self.device_scope()


def Benchmark(tf_bench,
              builder_fn,
              use_xla_jit,
              device,
              separate_compiled_gradients=False):
  """Build a graph and run benchmarks against it, with or without XLA.

  Args:
    tf_bench: An instance of tf.test.Benchmark, used to run the benchmark.
    builder_fn: A function that builds a graph when invoked, and returns
      (name, fetches), where name is the name of the test and fetches is a
      list of tensors to fetch as output.
    use_xla_jit: If true, compile with the XLA JIT; otherwise use regular TF.
    device: The TensorFlow device to run on, e.g. "cpu", "gpu".
    separate_compiled_gradients: If true, put each gradient subgraph into a
      separate compilation scope. This gives fine-grained control over which
      portions of the graph will be compiled as a single unit. Compiling
      gradients separately may yield better performance for some graphs.
      The scope is named based on the scope of the forward computation as
      well as the name of the gradients, so the gradients are compiled in a
      scope that is separate from both the forward computation and from
      other gradients.
  """

  with ops.Graph().as_default():
    name = None
    targets = []
    with ops.device(device):
      fetches = []
      jit_scope = jit.experimental_jit_scope
      with jit_scope(
          compile_ops=use_xla_jit,
          separate_compiled_gradients=separate_compiled_gradients):
        name, fetches = builder_fn()

      # We only want to benchmark the operations themselves, not the data
      # transfer of the result(s). Non-compiled identity ops ensure XLA
      # doesn't know we're dropping the results; otherwise it might compile
      # away the entire computation.
      for fetch in fetches:
        targets.append(array_ops.identity(fetch).op)

    # TODO(b/132430685): Should we allow soft placement here?
    config = config_pb2.ConfigProto(allow_soft_placement=True)
    with session.Session(config=config) as sess:
      sess.run(variables.global_variables_initializer())
      xla = 'xla_' if use_xla_jit else ''
      tf_bench.run_op_benchmark(
          sess, targets, name='%s_%s%s' % (name, xla, device))
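

# Illustrative sketch of the builder_fn contract expected by Benchmark
# (hypothetical builder and benchmark instance, for exposition only):
#
#   def _matmul_builder():
#     x = array_ops.ones([1024, 1024])
#     y = math_ops.matmul(x, x)  # assumes math_ops is imported
#     return 'matmul_1024', [y]
#
#   Benchmark(tf_bench, _matmul_builder, use_xla_jit=True, device='gpu')
#   Benchmark(tf_bench, _matmul_builder, use_xla_jit=False, device='gpu')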