# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
import shutil
from typing import Callable, Iterable, Union, List, Optional
import os
import re
import sys

from synthtool import _tracked_paths
from synthtool.log import logger
from synthtool import metadata

PathOrStr = Union[str, Path]
ListOfPathsOrStrs = Iterable[Union[str, Path]]

# Directory used for staged files when the caller does not name one.
_DEFAULT_STAGING_PATH = "owl-bot-staging"


class MissingSourceError(Exception):
    """Raised by move() when ``required`` is set and nothing was copied."""

    pass


def _expand_paths(
    paths: ListOfPathsOrStrs, root: Optional[PathOrStr] = None
) -> Iterable[Path]:
    """Given a list of globs/paths, expands them into a flat sequence,
    expanding globs as necessary.

    Args:
        paths: A single path/glob or an iterable of them. ``None`` yields
            nothing.
        root: Directory that relative patterns are globbed against. Defaults
            to the current directory.

    Yields:
        Every path matched by the patterns. Absolute ``Path`` patterns are
        globbed from their own anchor; everything else is globbed from
        ``root``. For ``str`` patterns the running script (``sys.argv[0]``)
        is filtered out of the results.
    """
    if paths is None:
        return

    if isinstance(paths, (str, Path)):
        paths = [paths]

    # ensure root is a path
    root = Path(".") if root is None else Path(root)

    # record name of synth script so we don't try to do transforms on it
    synth_script_path = Path(sys.argv[0]).absolute()

    for path in paths:
        if isinstance(path, Path):
            if path.is_absolute():
                # Path.glob() only accepts relative patterns, so glob from
                # the anchor using the remainder of the path as the pattern.
                anchor = Path(path.anchor)
                remainder = str(path.relative_to(path.anchor))
                yield from anchor.glob(remainder)
            else:
                yield from root.glob(str(path))
        else:
            yield from (
                p for p in root.glob(path) if p.absolute() != synth_script_path
            )


def _filter_files(paths: Iterable[Path]) -> Iterable[Path]:
    """Returns only the paths that are writable files (no directories)."""

    return (path for path in paths if path.is_file() and os.access(path, os.W_OK))


def _merge_file(
    source_path: Path, dest_path: Path, merge: Callable[[str, str, Path], str]
):
    """
    Writes to the destination the result of merging the source with the
    existing destination contents, using the given merge function.

    The merge function must take three arguments: the source contents, the
    old destination contents, and a Path to the file to be written.
    """

    with source_path.open("r") as source_file:
        source_text = source_file.read()

    with dest_path.open("r+") as dest_file:
        dest_text = dest_file.read()

        final_text = merge(source_text, dest_text, dest_path)

        # use the source file's file permission mode
        os.chmod(dest_path, os.stat(source_path).st_mode)
        if final_text != dest_text:
            dest_file.seek(0)
            dest_file.write(final_text)
            dest_file.truncate()
        else:
            # Contents are unchanged; only refresh the modification time.
            dest_path.touch()


def _copy_or_merge_file(
    source_path: Path,
    dest_path: Path,
    merge: Optional[Callable[[str, str, Path], str]] = None,
) -> None:
    """Copies one file, merging into an existing destination when requested.

    If ``merge`` is given and ``dest_path`` already is a file, the merge is
    attempted first. A failing merge is logged and then replaced by a plain
    copy, so the destination is always written.
    """
    if merge is not None and dest_path.is_file():
        try:
            _merge_file(source_path, dest_path, merge)
        except Exception:
            # Deliberate best effort: any merge failure falls back to copy.
            logger.exception(
                "_merge_file failed for %s, fall back to copy", source_path
            )
        else:
            return
    shutil.copy2(str(source_path), str(dest_path))


def _copy_dir_to_existing_dir(
    source: Path,
    destination: Path,
    excludes: Optional[ListOfPathsOrStrs] = None,
    merge: Optional[Callable[[str, str, Path], str]] = None,
) -> bool:
    """
    copies files over existing files to an existing directory
    this function does not copy empty directories.

    A file is skipped when any entry of ``excludes`` equals the relativized
    form (per ``_tracked_paths.relativize``) of the file or of the directory
    that directly contains it.

    Returns: True if any files were copied, False otherwise.
    """
    copied = False

    # Materialize once: the entries are compared against every file, which
    # would silently exhaust a one-shot iterable after the first file.
    excluded_paths = [Path(e) for e in excludes] if excludes else []

    for root, _, files in os.walk(source):
        root_path = Path(root)
        dest_dir = destination / str(root_path.relative_to(source))
        for name in files:
            source_path = Path(os.path.join(root, name))
            is_excluded = any(
                e == _tracked_paths.relativize(root)
                or e == _tracked_paths.relativize(source_path)
                for e in excluded_paths
            )
            if is_excluded:
                continue

            # Created lazily so directories holding only excluded files (or
            # no files at all) are not created at the destination.
            os.makedirs(str(dest_dir), exist_ok=True)
            _copy_or_merge_file(source_path, dest_dir / name, merge)
            copied = True

    return copied


def dont_overwrite(
    patterns: ListOfPathsOrStrs,
) -> Callable[[str, str, Path], str]:
    """Returns a merge function that doesn't overwrite the specified files.

    Pass the return value to move() or copy() to avoid overwriting existing
    files.

    Args:
        patterns: One pattern or an iterable of patterns, each matched against
            the destination path with ``Path.match``.
    """
    # The merge function is called once per file, so keep our own list: a
    # one-shot iterable would be exhausted after the first call, and a bare
    # string would otherwise be iterated character by character.
    if isinstance(patterns, (str, Path)):
        pattern_list = [str(patterns)]
    else:
        pattern_list = [str(pattern) for pattern in patterns]

    def merge(source_text: str, destination_text: str, file_path: Path) -> str:
        for pattern in pattern_list:
            if file_path.match(pattern):
                logger.debug(f"Preserving existing contents of {file_path}.")
                return destination_text
        return source_text

    return merge


def move(
    sources: ListOfPathsOrStrs,
    destination: PathOrStr = None,
    excludes: ListOfPathsOrStrs = None,
    merge: Callable[[str, str, Path], str] = None,
    required: bool = False,
) -> bool:
    """
    copy file(s) at source to the destination, preserving file mode.

    Args:
        sources (ListOfPathsOrStrs): Glob pattern(s) to copy
        destination (PathOrStr): Destination folder for copied files. When
            omitted, each source is copied to the path returned by
            ``_tracked_paths.relativize(source)``.
        excludes (ListOfPathsOrStrs): Glob pattern(s) of files to skip
        merge (Callable[[str, str, Path], str]): Callback function for merging files
            if there is an existing file.
        required (bool): If required and no source files are copied, throws a MissingSourceError

    Returns:
        True if any files were copied, False otherwise.

    Raises:
        MissingSourceError: If ``required`` is True and nothing was copied.
    """
    copied = False

    # Keep the caller's patterns in their own list. They are re-expanded
    # against every source below, so they must neither be consumed (one-shot
    # iterables) nor overwritten with an earlier source's expansion.
    if excludes is None:
        exclude_patterns = []
    elif isinstance(excludes, (str, Path)):
        exclude_patterns = [excludes]
    else:
        exclude_patterns = list(excludes)

    for excluded_pattern in exclude_patterns:
        metadata.add_pattern_excluded_during_copy(str(excluded_pattern))

    for source in _expand_paths(sources):
        if destination is None:
            canonical_destination = _tracked_paths.relativize(source)
        else:
            canonical_destination = Path(destination)

        source_excludes = [
            _tracked_paths.relativize(e)
            for e in _expand_paths(exclude_patterns, source)
        ]

        if source.is_dir():
            # Always run the copy; `copied or copy(...)` would short-circuit
            # and skip every directory after the first successful copy.
            if _copy_dir_to_existing_dir(
                source, canonical_destination, excludes=source_excludes, merge=merge
            ):
                copied = True
        elif source not in source_excludes:
            # copy individual file
            _copy_or_merge_file(source, canonical_destination, merge)
            copied = True

    if not copied:
        if required:
            raise MissingSourceError(
                f"No files in sources {sources} were copied. Does the source "
                f"contain files?"
            )
        else:
            logger.warning(
                f"No files in sources {sources} were copied. Does the source "
                f"contain files?"
            )

    return copied


def _replace_in_file(path, expr, replacement):
    """Applies ``expr`` -> ``replacement`` to the file at ``path`` in place.

    The file is first processed as text. If it cannot be decoded it is
    re-processed as bytes using a bytes version of the same pattern.

    Returns:
        The number of substitutions made.
    """
    try:
        with path.open("r+") as fh:
            return _replace_in_file_handle(fh, expr, replacement)
    except UnicodeDecodeError:
        pass  # It's a binary file.  Try again with a binary regular expression.
    # re.UNICODE is not allowed on bytes patterns, so strip it before
    # recompiling.
    flags = expr.flags & ~re.UNICODE
    expr = re.compile(expr.pattern.encode(), flags)
    with path.open("rb+") as fh:
        return _replace_in_file_handle(fh, expr, replacement.encode())


def _replace_in_file_handle(fh, expr, replacement):
    """Rewrites the open read/write handle ``fh`` with substitutions applied.

    Returns:
        The number of substitutions made; the file is left untouched when
        that number is zero.
    """
    content = fh.read()
    content, count = expr.subn(replacement, content)

    # Don't bother writing the file if we didn't change
    # anything.
    if count:
        fh.seek(0)
        fh.write(content)
        fh.truncate()
    return count


def replace(
    sources: ListOfPathsOrStrs, before: str, after: str, flags: int = re.MULTILINE
) -> int:
    """Replaces occurrences of before with after in all the given sources.

    Args:
        sources: Glob pattern(s), relative to the current directory, of the
            files to edit. Only writable regular files are touched.
        before: Regular expression to search for.
        after: Replacement text (``re.sub`` replacement syntax).
        flags: Flags used to compile ``before``.

    Returns:
        The number of times the text was found and replaced across all files.
    """
    expr = re.compile(before, flags=flags or 0)
    # Materialize the matches: a generator is always truthy, which would make
    # the "no files found" check below unreachable.
    paths = list(_filter_files(_expand_paths(sources, ".")))

    if not paths:
        logger.warning(f"No files were found in sources {sources} for replace()")

    count_replaced = 0
    for path in paths:
        replaced = _replace_in_file(path, expr, after)
        count_replaced += replaced
        if replaced:
            logger.info(f"Replaced {before!r} in {path}.")

    if not count_replaced:
        logger.warning(
            f"No replacements made in {sources} for pattern {before}, maybe "
            "replacement is no longer needed?"
        )
    return count_replaced


def get_staging_dirs(
    default_version: Optional[str] = None, staging_path: Optional[str] = None
) -> List[Path]:
    """Returns the list of directories, one per version, copied from
    https://github.com/googleapis/googleapis-gen. Will return in lexical sorting
    order with the exception of the default_version which will be last (if specified).

    Every returned directory is also registered with ``_tracked_paths.add``.

    Args:
        default_version: the default version of the API. The directory for this version
            will be the last item in the returned list if specified.
        staging_path: the path to the staging directory. Defaults to
            "owl-bot-staging".

    Returns: the list of staging directories, or the empty list if the
        staging directory does not exist.
    """

    staging = Path(staging_path) if staging_path else Path(_DEFAULT_STAGING_PATH)
    if not staging.is_dir():
        return []

    # Collect the subdirectories of the staging directory, leaving the
    # default version out so it can be appended last.
    versions = sorted(
        v.name for v in staging.iterdir() if v.is_dir() and v.name != default_version
    )
    # The default version is appended whenever it is specified, even if no
    # directory of that name is present in the staging directory.
    if default_version is not None:
        versions.append(default_version)

    staging_dirs = [staging / v for v in versions]
    for staging_dir in staging_dirs:
        _tracked_paths.add(staging_dir)
    return staging_dirs


def remove_staging_dirs(staging_path: Optional[str] = None):
    """Removes all the staging directories.

    Args:
        staging_path: the path to the staging directory. Defaults to
            "owl-bot-staging", matching get_staging_dirs().
    """
    staging = Path(staging_path) if staging_path else Path(_DEFAULT_STAGING_PATH)
    if staging.is_dir():
        shutil.rmtree(staging)