# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Opensource base_dir configuration for tensorflow doc-generator."""
import re
from os import path

import keras_preprocessing
import tensorboard
import tensorflow as tf
from tensorflow_docs.api_generator import public_api
import tensorflow_estimator


try:
  import keras  # pylint: disable=g-import-not-at-top
except ImportError:
  # keras is only needed for TF >= 2.6; `_require_keras` reports its absence
  # at the point of use instead of failing later with a NameError.
  keras = None


def _tf_version():
  """Returns the installed TensorFlow version as a `(major, minor)` int tuple.

  Replaces `distutils.version.LooseVersion`: `import distutils` alone does not
  load `distutils.version`, and distutils is removed in Python 3.12. Only the
  leading `major.minor` matters for the thresholds in this file, so suffixes
  such as ".0-dev20220101" or "rc0" are ignored.

  Raises:
    ValueError: if `tf.__version__` does not start with "<major>.<minor>".
  """
  match = re.match(r"(\d+)\.(\d+)", tf.__version__)
  if match is None:
    raise ValueError(f"Cannot parse TensorFlow version: {tf.__version__!r}")
  return int(match.group(1)), int(match.group(2))


def _require_keras():
  """Returns the imported `keras` module, or raises if it is not installed."""
  if keras is None:
    raise ImportError(
        f"The `keras` package is required to document TensorFlow "
        f"{tf.__version__} but could not be imported.")
  return keras


def _keras_url_prefix():
  """Returns the GitHub source URL prefix for the installed keras package."""
  # For "dev" TensorFlow versions the master branch is used rather than a
  # release tag of the installed keras.
  if "dev" in tf.__version__:
    return "https://github.com/keras-team/keras/tree/master/keras"
  return (f"https://github.com/keras-team/keras/tree/"
          f"v{_require_keras().__version__}/keras")


def get_base_dirs_and_prefixes(code_url_prefix):
  """Returns the base_dirs and code_prefixes for OSS TensorFlow api gen.

  Args:
    code_url_prefix: Source URL prefix paired with TensorFlow's own package
      directory.

  Returns:
    A `(base_dirs, code_url_prefixes)` tuple. `base_dirs` is a list of local
    package directories and `code_url_prefixes` is a tuple of the same length
    whose i-th entry is the source URL prefix for `base_dirs[i]`.

  Raises:
    ImportError: if TensorFlow >= 2.6 is installed but `keras` is not.
  """
  tf_version = _tf_version()
  base_dir = path.dirname(tf.__file__)

  if tf_version < (2, 2):
    # Before 2.2 the sibling `tensorflow_core` directory is documented instead
    # of the `tensorflow` package directory.
    tf_dir = path.normpath(path.join(base_dir, "../tensorflow_core"))
  else:
    tf_dir = base_dir

  # Build (directory, url) pairs so the two returned sequences cannot drift
  # out of alignment.
  pairs = [(tf_dir, code_url_prefix)]

  if tf_version >= (2, 6):
    pairs.append(
        (path.dirname(_require_keras().__file__), _keras_url_prefix()))

  if tf_version < (2, 9):
    pairs.append((
        path.dirname(keras_preprocessing.__file__),
        f"https://github.com/keras-team/keras-preprocessing/tree/{keras_preprocessing.__version__}/keras_preprocessing",
    ))

  pairs.append((
      path.dirname(tensorboard.__file__),
      f"https://github.com/tensorflow/tensorboard/tree/{tensorboard.__version__}/tensorboard",
  ))
  pairs.append((
      path.dirname(tensorflow_estimator.__file__),
      "https://github.com/tensorflow/estimator/tree/master/tensorflow_estimator",
  ))

  base_dirs = [directory for directory, _ in pairs]
  code_url_prefixes = tuple(url for _, url in pairs)
  return base_dirs, code_url_prefixes


def explicit_filter_keep_keras(parent_path, parent, children):
  """Like explicit_package_contents_filter, but keeps keras.

  Applies `public_api.explicit_package_contents_filter`. Then, for a parent
  whose last path component is "tf", "v1" or "v2", re-adds a `keras` child
  that the filter removed.

  Args:
    parent_path: Sequence of names leading to `parent`.
    parent: The object whose members are being filtered.
    children: `(name, object)` pairs for the members of `parent`.

  Returns:
    The filtered `(name, object)` pairs. For "tf"/"v1"/"v2" parents the result
    is sorted by name.
  """
  new_children = public_api.explicit_package_contents_filter(
      parent_path, parent, children)

  if not parent_path or parent_path[-1] not in ("tf", "v1", "v2"):
    return new_children

  had_keras = any(name == "keras" for name, _ in children)
  has_keras = any(name == "keras" for name, _ in new_children)

  if had_keras and not has_keras:
    # Copy rather than mutate the list returned by the upstream filter.
    new_children = list(new_children)
    new_children.append(("keras", parent.keras))

  return sorted(new_children, key=lambda x: x[0])


def get_callbacks():
  """Returns the extra filter callbacks for the doc generator.

  From TensorFlow 2.9 on, `explicit_filter_keep_keras` is returned so that
  `keras` stays visible under `tf`, `tf.compat.v1` and `tf.compat.v2`.
  Earlier versions get no extra callbacks.
  """
  if _tf_version() >= (2, 9):
    return [explicit_filter_keep_keras]
  return []