xref: /aosp_15_r20/external/google-cloud-java/owl-bot-postprocessor/synthtool/transforms.py (revision 55e87721aa1bc457b326496a7ca40f3ea1a63287)
1# Copyright 2018 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from pathlib import Path
16import shutil
17from typing import Callable, Iterable, Union, List, Optional
18import os
19import re
20import sys
21
22from synthtool import _tracked_paths
23from synthtool.log import logger
24from synthtool import metadata
25
# A single filesystem location, given either as text or as a pathlib object.
PathOrStr = Union[str, Path]
# Any iterable of such locations; used for the glob/path arguments below.
ListOfPathsOrStrs = Iterable[PathOrStr]
28
29
class MissingSourceError(Exception):
    """Raised by move() when ``required`` is set and no files were copied."""
32
33
def _expand_paths(paths: ListOfPathsOrStrs, root: PathOrStr = None) -> Iterable[Path]:
    """Lazily yield every filesystem entry matched by the given globs/paths.

    Args:
        paths: One pattern or an iterable of patterns. A lone ``str`` or
            ``Path`` is treated as a single pattern. ``None`` yields nothing.
        root: Directory that relative patterns are globbed against; the
            current directory when omitted.

    Yields:
        The matching paths. For ``str`` patterns, a match whose absolute
        path equals that of ``sys.argv[0]`` is skipped; ``Path`` patterns
        are yielded unfiltered.
    """
    if paths is None:
        return []

    patterns = [paths] if isinstance(paths, (str, Path)) else paths
    base = Path("." if root is None else root)

    # Remember the running script so string globs never hand it back for
    # transformation.
    script_name = sys.argv[0]

    for pattern in patterns:
        if not isinstance(pattern, Path):
            for match in base.glob(pattern):
                if match.absolute() != Path(script_name).absolute():
                    yield match
        elif pattern.is_absolute():
            # glob() only accepts relative patterns, so glob from the anchor
            # (drive/root) using the remainder of the path.
            anchor_dir = Path(pattern.anchor)
            relative_pattern = str(pattern.relative_to(pattern.anchor))
            yield from anchor_dir.glob(relative_pattern)
        else:
            yield from base.glob(str(pattern))
66
67
def _filter_files(paths: Iterable[Path]) -> Iterable[Path]:
    """Lazily keep only the paths that are regular files and writable.

    Directories, missing paths and files for which ``os.access(...,
    os.W_OK)`` is false are dropped.
    """

    def _is_writable_file(candidate: Path) -> bool:
        return candidate.is_file() and os.access(candidate, os.W_OK)

    return (candidate for candidate in paths if _is_writable_file(candidate))
72
73
def _merge_file(
    source_path: Path, dest_path: Path, merge: Callable[[str, str, Path], str]
):
    """
    Merge the source file into an already existing destination file.

    ``merge`` is called with the source contents, the current destination
    contents and the destination Path; whatever it returns becomes the new
    destination contents. The destination always receives the source file's
    permission mode. When the merged text equals the existing text the file
    is only touched, not rewritten.
    """
    incoming_text = source_path.read_text()

    with dest_path.open("r+") as dest_handle:
        existing_text = dest_handle.read()
        merged_text = merge(incoming_text, existing_text, dest_path)

        # Mirror the source file's permission mode onto the destination.
        source_mode = os.stat(source_path).st_mode
        os.chmod(dest_path, source_mode)

        if merged_text == existing_text:
            # Contents are unchanged; just refresh the modification time.
            dest_path.touch()
            return

        # Overwrite in place and drop any leftover tail of the old contents.
        dest_handle.seek(0)
        dest_handle.write(merged_text)
        dest_handle.truncate()
101
102
def _copy_dir_to_existing_dir(
    source: Path,
    destination: Path,
    excludes: ListOfPathsOrStrs = None,
    merge: Callable[[str, str, Path], str] = None,
) -> bool:
    """
    Recursively copy the files under ``source`` into ``destination``,
    overwriting (or merging into) files that already exist there.

    Only files are copied, so empty directories are not recreated. A file is
    skipped when an entry of ``excludes`` equals the relativized form of
    either the file or the directory that contains it. When ``merge`` is
    given and the destination file exists, the two are merged; if merging
    raises, the failure is logged and the file is copied instead.

    Returns: True if any files were copied, False otherwise.
    """
    skip_entries = excludes if excludes else []
    any_copied = False

    for current_dir, _subdirs, filenames in os.walk(source):
        for filename in filenames:
            target_dir = destination / str(Path(current_dir).relative_to(source))
            target_file = target_dir / filename

            is_excluded = any(
                Path(entry) == _tracked_paths.relativize(current_dir)
                or Path(entry)
                == _tracked_paths.relativize(Path(current_dir) / filename)
                for entry in skip_entries
            )
            if is_excluded:
                continue

            os.makedirs(str(target_dir), exist_ok=True)
            origin_file = Path(os.path.join(current_dir, filename))

            if merge is None or not target_file.is_file():
                # Nothing to merge with: plain copy preserving metadata.
                shutil.copy2(str(origin_file), str(target_file))
            else:
                try:
                    _merge_file(origin_file, target_file, merge)
                except Exception:
                    logger.exception(
                        "_merge_file failed for %s, fall back to copy",
                        origin_file,
                    )
                    shutil.copy2(str(origin_file), str(target_file))
            any_copied = True

    return any_copied
149
150
def dont_overwrite(
    patterns: ListOfPathsOrStrs,
) -> Callable[[str, str, Path], str]:
    """Returns a merge function that doesn't overwrite the specified files.

    Pass the return value to move() or copy() to avoid overwriting existing
    files.

    Args:
        patterns: One pattern or an iterable of patterns, each tested with
            ``Path.match`` against the destination file path.

    Returns:
        A merge callback that returns the existing destination text for
        files matching any pattern, and the source text for all others.
    """
    # A bare string/Path is a single pattern; iterating a str would
    # otherwise test the path against each individual character.
    if isinstance(patterns, (str, Path)):
        patterns = [patterns]
    # The callback runs once per merged file, so snapshot the patterns now;
    # a one-shot iterator would be empty from the second file onwards.
    preserved_patterns = tuple(str(pattern) for pattern in patterns)

    def merge(source_text: str, destination_text: str, file_path: Path) -> str:
        if any(file_path.match(pattern) for pattern in preserved_patterns):
            logger.debug(f"Preserving existing contents of {file_path}.")
            return destination_text
        return source_text

    return merge
168
169
def move(
    sources: ListOfPathsOrStrs,
    destination: PathOrStr = None,
    excludes: ListOfPathsOrStrs = None,
    merge: Callable[[str, str, Path], str] = None,
    required: bool = False,
) -> bool:
    """
    copy file(s) at source to current directory, preserving file mode.

    Every exclude pattern is also recorded through
    ``metadata.add_pattern_excluded_during_copy``.

    Args:
        sources (ListOfPathsOrStrs): Glob pattern(s) to copy
        destination (PathOrStr): Destination folder for copied files. When
            omitted, each source is copied to ``_tracked_paths.relativize(source)``.
        excludes (ListOfPathsOrStrs): Glob pattern(s) of files to skip. They
            are expanded relative to each source in turn.
        merge (Callable[[str, str, Path], str]): Callback function for merging files
            if there is an existing file.
        required (bool): If required and no source files are copied, throws a MissingSourceError

    Returns:
        True if any files were copied, False otherwise.

    Raises:
        MissingSourceError: if ``required`` is set and nothing was copied.
    """
    copied = False

    # Normalise the caller's excludes once. A lone str/Path is one pattern
    # (not an iterable of characters), and a list can be re-expanded for
    # every source without being consumed or overwritten.
    if not excludes:
        exclude_patterns = []
    elif isinstance(excludes, (str, Path)):
        exclude_patterns = [excludes]
    else:
        exclude_patterns = list(excludes)

    for excluded_pattern in exclude_patterns:
        metadata.add_pattern_excluded_during_copy(str(excluded_pattern))

    for source in _expand_paths(sources):
        if destination is None:
            canonical_destination = _tracked_paths.relativize(source)
        else:
            canonical_destination = Path(destination)

        # Expand the ORIGINAL patterns against this source. Reusing the
        # previous source's expansion would apply the wrong excludes.
        source_excludes = [
            _tracked_paths.relativize(e)
            for e in _expand_paths(exclude_patterns, source)
        ]

        if source.is_dir():
            # Always perform the copy; `copied or copy(...)` would
            # short-circuit and skip every directory after the first success.
            if _copy_dir_to_existing_dir(
                source, canonical_destination, excludes=source_excludes, merge=merge
            ):
                copied = True
        elif source not in source_excludes:
            # copy individual file
            if merge is not None and canonical_destination.is_file():
                try:
                    _merge_file(source, canonical_destination, merge)
                except Exception:
                    logger.exception(
                        "_merge_file failed for %s, fall back to copy", source
                    )
                    shutil.copy2(source, canonical_destination)
            else:
                shutil.copy2(source, canonical_destination)
            copied = True

    if not copied:
        if required:
            raise MissingSourceError(
                f"No files in sources {sources} were copied. Does the source "
                f"contain files?"
            )
        else:
            logger.warning(
                f"No files in sources {sources} were copied. Does the source "
                f"contain files?"
            )

    return copied
239
240
def _replace_in_file(path, expr, replacement):
    """Apply a compiled regex substitution to one file, in place.

    The file is first handled as text. If reading it raises
    UnicodeDecodeError, the substitution is retried in binary mode with the
    pattern and the replacement encoded to bytes.

    Returns the number of substitutions made.
    """
    try:
        with path.open("r+") as text_handle:
            return _replace_in_file_handle(text_handle, expr, replacement)
    except UnicodeDecodeError:
        # Not decodable as text: fall through to the binary attempt below.
        pass

    # Rebuild the pattern as bytes with the UNICODE flag cleared.
    bytes_expr = re.compile(expr.pattern.encode(), expr.flags & ~re.UNICODE)
    with path.open("rb+") as binary_handle:
        return _replace_in_file_handle(
            binary_handle, bytes_expr, replacement.encode()
        )
251
252
def _replace_in_file_handle(fh, expr, replacement):
    """Substitute ``expr`` with ``replacement`` throughout an open file.

    Reads everything from ``fh``, applies ``expr.subn`` and, only when at
    least one substitution happened, rewrites the handle from the start and
    truncates any leftover tail.

    Returns the number of substitutions made.
    """
    original = fh.read()
    updated, hits = expr.subn(replacement, original)

    if not hits:
        # Nothing matched, so leave the file untouched.
        return hits

    fh.seek(0)
    fh.write(updated)
    fh.truncate()
    return hits
264
265
def replace(
    sources: ListOfPathsOrStrs, before: str, after: str, flags: int = re.MULTILINE
) -> int:
    """Replaces occurrences of before with after in all the given sources.

    Args:
        sources: Glob pattern(s), expanded relative to the current directory;
            only writable regular files are processed.
        before: Regular expression to search for.
        after: Replacement text, passed to the compiled pattern's ``subn``.
        flags: ``re`` flags used to compile ``before``.

    Returns:
      The number of times the text was found and replaced across all files.
    """
    expr = re.compile(before, flags=flags or 0)
    # Materialise the matches: a generator object is always truthy, so the
    # emptiness check below would otherwise never fire.
    paths = list(_filter_files(_expand_paths(sources, ".")))

    if not paths:
        logger.warning(f"No files were found in sources {sources} for replace()")

    count_replaced = 0
    for path in paths:
        replaced = _replace_in_file(path, expr, after)
        count_replaced += replaced
        if replaced:
            logger.info(f"Replaced {before!r} in {path}.")

    if not count_replaced:
        logger.warning(
            f"No replacements made in {sources} for pattern {before}, maybe "
            "replacement is no longer needed?"
        )
    return count_replaced
293
294
def get_staging_dirs(
    default_version: Optional[str] = None, staging_path: Optional[str] = None
) -> List[Path]:
    """Returns the list of directories, one per version, copied from
    https://github.com/googleapis/googleapis-gen. Will return in lexical sorting
    order with the exception of the default_version which will be last (if specified).

    Each returned directory is also registered with ``_tracked_paths.add``.

    Args:
      default_version: the default version of the API. The directory for this version
        will be the last item in the returned list if specified and present
        in the staging directory; a default version with no staged directory
        is not added to the result.
      staging_path: the path to the staging directory. Defaults to
        ``owl-bot-staging`` relative to the current directory.

    Returns: the empty list if no file were copied.
    """
    staging = Path(staging_path) if staging_path else Path("owl-bot-staging")
    if not staging.is_dir():
        return []

    # Collect the subdirectories of the staging directory, lexically sorted.
    versions = sorted(entry.name for entry in staging.iterdir() if entry.is_dir())

    # Move the default version to the end, but only when it was actually
    # staged, so callers never receive a directory that does not exist.
    if default_version in versions:
        versions.remove(default_version)
        versions.append(default_version)

    staging_dirs = [staging / version for version in versions]
    for staging_dir in staging_dirs:
        _tracked_paths.add(staging_dir)
    return staging_dirs
328
329
def remove_staging_dirs(staging_path: Optional[str] = None):
    """Removes all the staging directories.

    Args:
      staging_path: the path to the staging directory. Defaults to
        ``owl-bot-staging`` relative to the current directory, the same
        default used by get_staging_dirs().

    Does nothing when the staging directory does not exist.
    """
    staging = Path(staging_path) if staging_path else Path("owl-bot-staging")
    if staging.is_dir():
        shutil.rmtree(staging)
335