xref: /aosp_15_r20/external/tensorflow/tensorflow/tools/docs/tf_doctest.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Run doctests for tensorflow."""
16
17import importlib
18import os
19import pkgutil
20import sys
21
22from absl import flags
23from absl.testing import absltest
24import numpy as np
25import tensorflow.compat.v2 as tf
26
27# Prevent Python exception from circular dependencies (b/117329403) looking very
28# similar to https://bugs.python.org/issue43546.
29from tensorflow.python.distribute import distribution_strategy_context  # pylint: disable=unused-import
30from tensorflow.python.ops import logging_ops
31
32from tensorflow.tools.docs import tf_doctest_lib
33
34# We put doctest after absltest so that it picks up the unittest monkeypatch.
35# Otherwise doctest tests aren't runnable at all.
36import doctest  # pylint: disable=g-bad-import-order
37
tf.compat.v1.enable_v2_behavior()

# `enable_interactive_logging` must come after `enable_v2_behavior`.
logging_ops.enable_interactive_logging()

FLAGS = flags.FLAGS

# Read by `load_tests`: when non-empty, only modules whose names contain
# `<package><module>` for some entry are tested (see `filter_on_submodules`).
flags.DEFINE_list('module', [], 'A list of specific module to run doctest on.')
# Read by `load_tests`: modules whose names start with `<package><prefix>` for
# any entry are skipped.
flags.DEFINE_list('module_prefix_skip', [],
                  'A list of modules to ignore when resolving modules.')
# When set, `load_tests` prints the selected module names and returns without
# adding any doctests.
flags.DEFINE_boolean('list', None,
                     'List all the modules in the core package imported.')
# Passed by `setUpModule` to `setup_gpu` to fix the number of visible GPUs.
flags.DEFINE_integer('required_gpus', 0,
                     'The number of GPUs required for the tests.')

# Both --module and --module_prefix_skip are relative to PACKAGES.
# Each entry is a dotted module-name prefix ending in '.'; the `__main__` block
# strips that trailing '.' to get the importable package name.
PACKAGES = [
    'tensorflow.python.',
    'tensorflow.lite.python.',
]
58
59
def recursive_import(root):
  """Recursively imports all the sub-modules under a root package.

  Importing is best-effort: sub-modules whose import raises `ImportError` or
  `AttributeError` (e.g. from import cycles) are skipped. Any other exception
  propagates to the caller.

  Args:
    root: A python package (a module with a `__path__`).
  """

  def _ignore_tolerated_errors(unused_name):
    # `pkgutil.walk_packages` imports every sub-package it descends into and
    # calls `onerror` from inside its `except` clause, so a bare `raise` here
    # re-raises the exception being handled. Without this hook the walker only
    # ignores ImportError, so an AttributeError from re-importing a broken
    # sub-package would abort the whole walk even though the loop below
    # deliberately tolerates it.
    if not isinstance(sys.exc_info()[1], (AttributeError, ImportError)):
      raise

  for _, name, _ in pkgutil.walk_packages(
      root.__path__, prefix=root.__name__ + '.',
      onerror=_ignore_tolerated_errors):
    try:
      importlib.import_module(name)
    except (AttributeError, ImportError):
      # Best-effort: skip modules that cannot be imported in this build.
      pass
72
73
def find_modules(packages=None):
  """Finds the already-imported modules that belong to the given packages.

  Only modules already present in `sys.modules` are considered, so the
  packages should be imported (e.g. with `recursive_import`) beforehand.

  Args:
    packages: Iterable of module-name prefixes to match, such as
      'tensorflow.python.'. Defaults to `PACKAGES`.

  Returns:
    A list of the imported module objects whose names start with one of
    `packages`, in `sys.modules` iteration order. Each module appears at most
    once even if several prefixes match it.
  """
  prefixes = tuple(PACKAGES if packages is None else packages)
  return [
      module for name, module in sys.modules.items()
      # `sys.modules` may map a name to None to block its import; such an
      # entry is not a module and has no `__name__` for callers to read.
      if module is not None and name.startswith(prefixes)
  ]
89
90
def filter_on_submodules(all_modules, submodules, packages=None):
  """Filters all the modules based on the modules flag.

  The module flag has to be relative to the core package imported.
  For example, if `module=keras.layers` then, this function will return
  all the modules in the submodule.

  Matching is by substring: a module is kept when its name contains
  `<package><submodule>` for any package in `packages` and any entry of
  `submodules`, so `ops.array` also matches `tensorflow.python.ops.array_ops`.

  Args:
    all_modules: All the modules in the core package.
    submodules: Submodules to filter from all the modules.
    packages: Package prefixes that `submodules` are relative to. Defaults to
      `PACKAGES`.

  Returns:
    The matching modules in `all_modules` order. A module is listed once even
    if it matches several submodules, so its doctests are not added twice.
  """
  if packages is None:
    packages = PACKAGES
  # Build the fully-qualified target names once rather than once per module.
  targets = [
      package + submodule for submodule in submodules for package in packages
  ]
  return [
      mod for mod in all_modules
      if any(target in mod.__name__ for target in targets)
  ]
116
117
def setup_gpu(required_gpus):
  """Sets up the GPU devices.

  If there are more available GPUs than needed, the additional ones are
  hidden. If there are fewer, logical devices are created. This makes sure the
  tests see a fixed number of GPUs regardless of the environment.

  Args:
    required_gpus: an integer. The number of GPUs required; 0 is a no-op.

  Raises:
    ValueError: if `required_gpus` is negative, or if it is larger than zero
      but no GPU is available.
  """
  if required_gpus < 0:
    # Without this check a negative count would slice off GPUs from the end
    # (`available_gpus[:-1]`) instead of failing.
    raise ValueError(
        'required_gpus must be non-negative, got %r' % (required_gpus,))
  if required_gpus == 0:
    return
  available_gpus = tf.config.list_physical_devices('GPU')
  if not available_gpus:
    raise ValueError('requires at least one physical GPU')
  if len(available_gpus) >= required_gpus:
    # Restrict only GPU visibility: with an explicit device type, devices of
    # other types are left unaltered.
    tf.config.set_visible_devices(available_gpus[:required_gpus], 'GPU')
  else:
    # Create logical GPUs out of one physical GPU for simplicity. Note that the
    # other physical GPUs are still available and correspond to one logical GPU
    # each.
    num_logical_gpus = required_gpus - len(available_gpus) + 1
    logical_gpus = [
        tf.config.LogicalDeviceConfiguration(memory_limit=256)
        for _ in range(num_logical_gpus)
    ]
    tf.config.set_logical_device_configuration(available_gpus[0], logical_gpus)
148
149
class TfTestCase(tf.test.TestCase):
  """Exposes `tf.test.TestCase` fixtures as doctest setUp/tearDown hooks.

  `doctest.DocTestSuite` calls its setUp/tearDown callables with the
  `doctest.DocTest` being run; these methods accept that argument and delegate
  to the inherited `setUp`/`tearDown`.
  """

  def set_up(self, test):
    """Runs before each doctest; `test` is the DocTest (unused)."""
    del test  # Unused.
    # Soft device placement is enabled so the distributed doctests can run.
    tf.config.set_soft_device_placement(True)
    self.setUp()

  def tear_down(self, test):
    """Runs after each doctest; `test` is the DocTest (unused)."""
    del test  # Unused.
    self.tearDown()
159
160
def load_tests(unused_loader, tests, unused_ignore):
  """Adds a doctest suite for every selected module to `tests`.

  This is the unittest `load_tests` protocol hook for this module.

  Args:
    unused_loader: The test loader (unused).
    tests: The test suite collected so far; doctest suites are added to it.
    unused_ignore: The protocol's pattern argument (unused).

  Returns:
    `tests`, extended with the doctests of the modules assigned to this shard.
    When `--list` is set, the selected module names are printed instead and
    `tests` is returned without additions.
  """
  candidates = find_modules()
  if FLAGS.module:
    candidates = filter_on_submodules(candidates, FLAGS.module)

  if FLAGS.list:
    print('**************************************************')
    for candidate in candidates:
      print(candidate.__name__)
    print('**************************************************')
    return tests

  shard_index = int(os.environ.get('TEST_SHARD_INDEX', '0'))
  shard_count = int(os.environ.get('TEST_TOTAL_SHARDS', '1'))

  # Fully-qualified prefixes to skip, built once instead of once per module.
  skipped_prefixes = tuple(
      package + prefix
      for prefix in FLAGS.module_prefix_skip
      for package in PACKAGES)

  ordered = sorted(candidates, key=lambda m: m.__name__)
  for position, module in enumerate(ordered):
    # Round-robin sharding over the sorted order; the shard is chosen before
    # the skip check, so skipping does not move modules between shards.
    if position % shard_count != shard_index:
      continue
    if module.__name__.startswith(skipped_prefixes):
      continue
    fixture = TfTestCase()
    tests.addTests(
        doctest.DocTestSuite(
            module,
            test_finder=doctest.DocTestFinder(exclude_empty=False),
            extraglobs={'tf': tf, 'np': np, 'os': os},
            setUp=fixture.set_up,
            tearDown=fixture.tear_down,
            checker=tf_doctest_lib.TfDoctestOutputChecker(),
            optionflags=(doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE |
                         doctest.IGNORE_EXCEPTION_DETAIL |
                         doctest.DONT_ACCEPT_BLANKLINE),
        ))
  return tests
208
209
def setUpModule():
  """unittest module fixture: configures GPUs before any test in this module.

  Logical devices can only be created before TensorFlow is initialized, and
  the unittest framework calls this hook before running any test. See
  https://docs.python.org/3/library/unittest.html#setupmodule-and-teardownmodule
  """
  gpu_count = FLAGS.required_gpus
  setup_gpu(gpu_count)
215
216
if __name__ == '__main__':
  # The `python` submodule is deleted in tensorflow's root `__init__.py`, so a
  # plain import of it may not work on some Python versions. Importing each
  # package through importlib by its dotted name avoids that.
  for package_prefix in PACKAGES:
    package_name = package_prefix[:-1]  # Drop the trailing '.'.
    recursive_import(importlib.import_module(package_name))
  absltest.main()
224