xref: /aosp_15_r20/external/tensorflow/tensorflow/tools/docs/base_dir.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Opensource base_dir configuration for tensorflow doc-generator."""
import distutils
from os import path
import re

import keras_preprocessing
import tensorboard
import tensorflow as tf
from tensorflow_docs.api_generator import public_api
import tensorflow_estimator
24
25
26try:
27  import keras  # pylint: disable=g-import-not-at-top
28except ImportError:
29  pass
30
31
def _version_tuple(version):
  """Returns the leading dotted run of integers in `version` as a tuple.

  For example "2.10.0-dev20220815" gives (2, 10, 0) and "2.9.0rc1" gives
  (2, 9, 0). Returns an empty tuple if `version` does not start with a digit.
  Comparing these tuples against e.g. (2, 9) replaces the deprecated
  `distutils.version.LooseVersion` (distutils is removed in Python 3.12).
  """
  match = re.match(r"\d+(?:\.\d+)*", version)
  if match is None:
    return ()
  return tuple(int(part) for part in match.group().split("."))


def get_base_dirs_and_prefixes(code_url_prefix):
  """Returns the base_dirs and code_prefixes for OSS TensorFlow api gen.

  Args:
    code_url_prefix: Source URL prefix paired with the TensorFlow package
      directory (always the first entry of the result).

  Returns:
    A `(base_dirs, code_url_prefixes)` pair: a list of local package
    directories and a tuple of the same length whose i-th entry is the source
    URL prefix for the i-th directory.

  Raises:
    NameError: If TensorFlow is 2.6 or newer and the optional module-level
      `keras` import failed.
  """
  tf_version = _version_tuple(tf.__version__)
  base_dir = path.dirname(tf.__file__)

  # Build (directory, url_prefix) pairs together so the two returned
  # sequences cannot drift out of alignment across version branches.
  if tf_version >= (2, 2):
    tf_dir = base_dir
  else:
    # Older TF releases keep their sources in a sibling `tensorflow_core`.
    tf_dir = path.normpath(path.join(base_dir, "../tensorflow_core"))
  dirs_and_prefixes = [(tf_dir, code_url_prefix)]

  if tf_version >= (2, 6):
    # `keras` is only used from TF 2.6 on. Touching it only here keeps older
    # TF versions working when the optional keras import at the top failed.
    if "dev" in tf.__version__:
      keras_url_prefix = "https://github.com/keras-team/keras/tree/master/keras"
    else:
      keras_url_prefix = (
          f"https://github.com/keras-team/keras/tree/v{keras.__version__}/keras")
    dirs_and_prefixes.append((path.dirname(keras.__file__), keras_url_prefix))

  if tf_version < (2, 9):
    # keras_preprocessing is only documented before TF 2.9.
    dirs_and_prefixes.append((
        path.dirname(keras_preprocessing.__file__),
        "https://github.com/keras-team/keras-preprocessing/tree/"
        f"{keras_preprocessing.__version__}/keras_preprocessing",
    ))

  dirs_and_prefixes.append((
      path.dirname(tensorboard.__file__),
      "https://github.com/tensorflow/tensorboard/tree/"
      f"{tensorboard.__version__}/tensorboard",
  ))
  dirs_and_prefixes.append((
      path.dirname(tensorflow_estimator.__file__),
      "https://github.com/tensorflow/estimator/tree/master/tensorflow_estimator",
  ))

  base_dirs = [directory for directory, _ in dirs_and_prefixes]
  code_url_prefixes = tuple(prefix for _, prefix in dirs_and_prefixes)
  return base_dirs, code_url_prefixes
96
97
def explicit_filter_keep_keras(parent_path, parent, children):
  """Applies `explicit_package_contents_filter`, but never drops `keras`.

  `public_api.explicit_package_contents_filter` is applied first. When the
  last element of `parent_path` is "tf", "v1" or "v2", a `keras` entry that
  was present in `children` but removed by that filter is re-added as
  `("keras", parent.keras)`, and the result is returned sorted by name.

  Args:
    parent_path: Sequence of names leading to `parent`; only its last element
      is inspected here.
    parent: The object whose children are being filtered; its `keras`
      attribute is read when the keras entry has to be restored.
    children: `(name, child)` pairs to filter.

  Returns:
    The filtered `(name, child)` pairs. At the "tf"/"v1"/"v2" levels this is
    a new list sorted by name; otherwise it is the base filter's result
    unchanged.
  """
  kept = public_api.explicit_package_contents_filter(
      parent_path, parent, children)

  # Only the top-level namespaces get the keras override.
  if parent_path[-1] in ("tf", "v1", "v2"):
    keras_in_input = any(name == "keras" for name, _ in children)
    keras_in_output = any(name == "keras" for name, _ in kept)
    if keras_in_input and not keras_in_output:
      kept.append(("keras", parent.keras))
    return sorted(kept, key=lambda pair: pair[0])

  return kept
113
114
def get_callbacks():
  """Returns the extra api-gen filter callbacks for the installed TensorFlow.

  Returns:
    `[explicit_filter_keep_keras]` for TensorFlow 2.9 and newer, otherwise an
    empty list.
  """
  # Compare the leading numeric release components (e.g. "2.10.0-dev..." ->
  # (2, 10, 0)) instead of `distutils.version.LooseVersion`: a bare
  # `import distutils` does not load the `version` submodule, and distutils
  # is removed in Python 3.12.
  match = re.match(r"\d+(?:\.\d+)*", tf.__version__)
  release = ()
  if match is not None:
    release = tuple(int(part) for part in match.group().split("."))
  if release >= (2, 9):
    return [explicit_filter_keep_keras]
  return []
120