#!/usr/bin/env python3
# For the dependencies, see the requirements.txt

import logging
import re
import traceback
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, Namespace
from collections import OrderedDict
from copy import deepcopy
from dataclasses import dataclass, field
from itertools import accumulate
from pathlib import Path
from subprocess import check_output
from textwrap import dedent
from typing import Any, Iterable, Optional, Pattern, TypedDict, Union

import yaml
from filecache import DAY, filecache
from gitlab_common import get_token_from_default_dir
from gql import Client, gql
from gql.transport.requests import RequestsHTTPTransport
from graphql import DocumentNode


class DagNode(TypedDict):
    needs: set[str]
    stage: str
    # `name` is redundant but is kept for backward compatibility
    name: str


# see create_job_needs_dag function for more details
Dag = dict[str, DagNode]


StageSeq = OrderedDict[str, set[str]]


def get_project_root_dir():
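    """Return the repository root directory, asserting that it contains a .gitlab-ci.yml."""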
    root_path = Path(__file__).parent.parent.parent.resolve()
    gitlab_file = root_path / ".gitlab-ci.yml"
    assert gitlab_file.exists()

    return root_path


@dataclass
class GitlabGQL:
    _transport: Any = field(init=False)
    client: Client = field(init=False)
    url: str = "https://gitlab.freedesktop.org/api/graphql"
    token: Optional[str] = None

    def __post_init__(self) -> None:
        self._setup_gitlab_gql_client()

    def _setup_gitlab_gql_client(self) -> None:
        # Select your transport with a defined url endpoint
        headers = {}
        if self.token:
            headers["Authorization"] = f"Bearer {self.token}"
        self._transport = RequestsHTTPTransport(url=self.url, headers=headers)

        # Create a GraphQL client using the defined transport
        self.client = Client(transport=self._transport, fetch_schema_from_transport=True)

    def query(
        self,
        gql_file: Union[Path, str],
        params: dict[str, Any] = {},
        operation_name: Optional[str] = None,
        paginated_key_loc: Iterable[str] = [],
        disable_cache: bool = False,
    ) -> dict[str, Any]:
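        """
        Run the GraphQL query stored in `gql_file`, optionally caching the result on disk.

        When `paginated_key_loc` is given, the query is swept page by page and the paginated
        nodes are concatenated (see `_sweep_pages`). Unless `disable_cache` is set, results are
        served from a day-long file cache; if the cached path fails for any reason, the cache is
        invalidated and the query is retried without it.
        """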
        def run_uncached() -> dict[str, Any]:
            if paginated_key_loc:
                return self._sweep_pages(gql_file, params, operation_name, paginated_key_loc)
            return self._query(gql_file, params, operation_name)

        if disable_cache:
            return run_uncached()

        try:
            # Go through the file-cached wrappers so repeated queries can reuse results
            if paginated_key_loc:
                result = self._sweep_pages_cached(
                    gql_file, params, operation_name, paginated_key_loc
                )
            else:
                result = self._query_cached(gql_file, params, operation_name)
            return result  # type: ignore
        except Exception as ex:
            logging.error(f"Cached query failed with {ex}")
            # print exception traceback
            traceback_str = "".join(traceback.format_exception(ex))
            logging.error(traceback_str)
            self.invalidate_query_cache()
            logging.error("Cache invalidated, retrying without cache")

        return run_uncached()

    def _query(
        self,
        gql_file: Union[Path, str],
        params: dict[str, Any] = {},
        operation_name: Optional[str] = None,
    ) -> dict[str, Any]:
        # Provide a GraphQL query
        source_path: Path = Path(__file__).parent
        pipeline_query_file: Path = source_path / gql_file

        query: DocumentNode
        with open(pipeline_query_file, "r") as f:
            pipeline_query = f.read()
            query = gql(pipeline_query)

        # Execute the query on the transport
        return self.client.execute_sync(
            query, variable_values=params, operation_name=operation_name
        )

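    # Thin wrappers so that filecache can persist results for a day while the
    # uncached variants remain directly callable.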
    @filecache(DAY)
    def _sweep_pages_cached(self, *args, **kwargs):
        return self._sweep_pages(*args, **kwargs)

    @filecache(DAY)
    def _query_cached(self, *args, **kwargs):
        return self._query(*args, **kwargs)

    def _sweep_pages(
        self, query, params, operation_name=None, paginated_key_loc: Iterable[str] = []
    ) -> dict[str, Any]:
        """
        Retrieve paginated data from a GraphQL API and concatenate the results into a single
        response.

        Args:
            query: represents a filepath with the GraphQL query to be executed.
            params: a dictionary that contains the parameters to be passed to the query. These
                parameters can be used to filter or modify the results of the query.
            operation_name: The `operation_name` parameter is an optional parameter that specifies
                the name of the GraphQL operation to be executed. It is used when making a GraphQL
                query to specify which operation to execute if there are multiple operations defined
                in the GraphQL schema. If not provided, the default operation will be executed.
            paginated_key_loc (Iterable[str]): The `paginated_key_loc` parameter is an iterable of
                strings that represents the location of the paginated field within the response. It
                is used to extract the paginated field from the response and append it to the final
                result. The node has to be a list of objects with a `pageInfo` field that contains
                at least the `hasNextPage` and `endCursor` fields.

        Returns:
            a dictionary containing the response from the query with the paginated field
            concatenated.
        """

        def fetch_page(cursor: str | None = None) -> dict[str, Any]:
            if cursor:
                params["cursor"] = cursor
                logging.info(
                    f"Found more than 100 elements, paginating. "
                    f"Current cursor at {cursor}"
                )

            return self._query(query, params, operation_name)

        # Execute the initial query
        response: dict[str, Any] = fetch_page()

        # Initialize an empty list to store the final result
        final_partial_field: list[dict[str, Any]] = []

        # Loop until all pages have been retrieved
        while True:
            # Get the partial field to be appended to the final result
            partial_field = response
            for key in paginated_key_loc:
                partial_field = partial_field[key]

            # Append the partial field to the final result
            final_partial_field += partial_field["nodes"]

            # Check if there are more pages to retrieve
            page_info = partial_field["pageInfo"]
            if not page_info["hasNextPage"]:
                break

            # Execute the query with the updated cursor parameter
            response = fetch_page(page_info["endCursor"])

        # Replace the "nodes" field in the original response with the final result
        partial_field["nodes"] = final_partial_field
        return response

    def invalidate_query_cache(self) -> None:
        logging.warning("Invalidating query cache")
        try:
            self._sweep_pages_cached._db.clear()
            self._query_cached._db.clear()
        except AttributeError as ex:
            logging.warning(f"Could not invalidate the cache, maybe it was never used? {ex.args}")


def insert_early_stage_jobs(stage_sequence: StageSeq, jobs_metadata: Dag) -> Dag:
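    """
    Add implicit needs for jobs that do not declare any (e.g. the sanity job in Mesa MR
    pipelines): such jobs only start once every job from the previous stages has finished,
    so their "needs" set is filled with all jobs from the earlier stages.
    """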
    pre_processed_dag: dict[str, set[str]] = {}
    jobs_from_early_stages = list(accumulate(stage_sequence.values(), set.union))
    for job_name, metadata in jobs_metadata.items():
        final_needs: set[str] = deepcopy(metadata["needs"])
        # Pre-process jobs that are not based on needs field
        # e.g. sanity job in mesa MR pipelines
        if not final_needs:
            job_stage: str = jobs_metadata[job_name]["stage"]
            stage_index: int = list(stage_sequence.keys()).index(job_stage)
            if stage_index > 0:
                final_needs |= jobs_from_early_stages[stage_index - 1]
        pre_processed_dag[job_name] = final_needs

    for job_name, needs in pre_processed_dag.items():
        jobs_metadata[job_name]["needs"] = needs

    return jobs_metadata


def traverse_dag_needs(jobs_metadata: Dag) -> None:
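    """
    Expand every job's "needs" set transitively, in place: keep folding in the needs of the
    needed jobs until a fixed point is reached, so each job ends up listing all of its direct
    and indirect dependencies.
    """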
    created_jobs = set(jobs_metadata.keys())
    for job, metadata in jobs_metadata.items():
        final_needs: set = deepcopy(metadata["needs"]) & created_jobs
        # Post process jobs that are based on needs field
        partial = True

        while partial:
            next_depth: set[str] = {
                n
                for dn in final_needs
                if dn in jobs_metadata
                for n in jobs_metadata[dn]["needs"]
            }
            partial: bool = not final_needs.issuperset(next_depth)
            final_needs = final_needs.union(next_depth)

        jobs_metadata[job]["needs"] = final_needs


def extract_stages_and_job_needs(
    pipeline_jobs: dict[str, Any], pipeline_stages: dict[str, Any]
) -> tuple[StageSeq, Dag]:
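    """
    Build the stage sequence (stage name -> set of job names, in pipeline order) and the
    initial DAG holding each job's stage and directly declared needs.
    """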
    jobs_metadata = Dag()
    # Record the stage sequence to post process deps that are not based on needs
    # field, for example: sanity job
    stage_sequence: OrderedDict[str, set[str]] = OrderedDict()
    for stage in pipeline_stages["nodes"]:
        stage_sequence[stage["name"]] = set()

    for job in pipeline_jobs["nodes"]:
        stage_sequence[job["stage"]["name"]].add(job["name"])
        dag_job: DagNode = {
            "name": job["name"],
            "stage": job["stage"]["name"],
            "needs": {j["node"]["name"] for j in job["needs"]["edges"]},
        }
        jobs_metadata[job["name"]] = dag_job

    return stage_sequence, jobs_metadata


def create_job_needs_dag(gl_gql: GitlabGQL, params, disable_cache: bool = True) -> Dag:
    """
    This function creates a Directed Acyclic Graph (DAG) to represent a sequence of jobs, where each
    job has a set of jobs that it depends on (its "needs") and belongs to a certain "stage".
    The "name" of the job is used as the key in the dictionary.

    For example, consider the following DAG:

        1. build stage: job1 -> job2 -> job3
        2. test stage: job2 -> job4

    - The job needs for job3 are: job1, job2
    - The job needs for job4 are: job2
    - job2 needs to wait for all jobs from the build stage to finish.

    The resulting DAG would look like this:

        dag = {
            "job1": {"needs": set(), "stage": "build", "name": "job1"},
            "job2": {"needs": {"job1", "job3"}, "stage": "test", "name": "job2"},
            "job3": {"needs": {"job1", "job2"}, "stage": "build", "name": "job3"},
            "job4": {"needs": {"job2"}, "stage": "test", "name": "job4"},
        }

    To access the job needs, one can do:

        dag["job3"]["needs"]

    This will return the set of jobs that job3 needs: {"job1", "job2"}

    Args:
        gl_gql (GitlabGQL): The `gl_gql` parameter is an instance of the `GitlabGQL` class, which is
            used to make GraphQL queries to the GitLab API.
        params (dict): The `params` parameter is a dictionary that contains the necessary parameters
            for the GraphQL query. It is used to specify the details of the pipeline for which the
            job needs DAG is being created.
            The specific keys and values in the `params` dictionary will depend on
            the requirements of the GraphQL query being executed.
        disable_cache (bool): The `disable_cache` parameter is a boolean that specifies whether the
            query results cache should be bypassed, forcing a fresh request to the GitLab API.

    Returns:
        The final DAG (Directed Acyclic Graph) representing the job dependencies sourced from the
        needs field or the stage ordering.
    """
    stages_jobs_gql = gl_gql.query(
        "pipeline_details.gql",
        params=params,
        paginated_key_loc=["project", "pipeline", "jobs"],
        disable_cache=disable_cache,
    )
    pipeline_data = stages_jobs_gql["project"]["pipeline"]
    if not pipeline_data:
        raise RuntimeError(f"Could not find any pipelines for {params}")

    stage_sequence, jobs_metadata = extract_stages_and_job_needs(
        pipeline_data["jobs"], pipeline_data["stages"]
    )
    # Fill the DAG with needs for jobs that don't declare any but still have to wait
    # for the previous stages
    final_dag = insert_early_stage_jobs(stage_sequence, jobs_metadata)
    # Now that each job has its direct needs filled correctly, update the "needs" field for each job
    # in the DAG by performing a topological traversal
    traverse_dag_needs(final_dag)

    return final_dag


def filter_dag(
    dag: Dag, job_name_regex: Pattern, include_stage_regex: Pattern, exclude_stage_regex: Pattern
) -> Dag:
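    """
    Keep only the jobs whose name fully matches `job_name_regex` and whose stage fully matches
    `include_stage_regex` without matching `exclude_stage_regex`.
    """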
    filtered_jobs: Dag = Dag({})
    for job, data in dag.items():
        if not job_name_regex.fullmatch(job):
            continue
        if not include_stage_regex.fullmatch(data["stage"]):
            continue
        if exclude_stage_regex.fullmatch(data["stage"]):
            continue
        filtered_jobs[job] = data
    return filtered_jobs


def print_dag(dag: Dag) -> None:
    for job, data in sorted(dag.items()):
        print(f"{job}:\n\t{' '.join(data['needs'])}\n")


def fetch_merged_yaml(gl_gql: GitlabGQL, params) -> dict[str, Any]:
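    """
    Ask GitLab, via the `ciConfig` GraphQL field, for the fully merged CI configuration of
    .gitlab-ci.yml at the given SHA and return it parsed as a dictionary. Raises ValueError
    when GitLab cannot produce a merged YAML, e.g. when the SHA has not been pushed yet.
    """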
    params["content"] = dedent("""\
    include:
      - local: .gitlab-ci.yml
    """)
    raw_response = gl_gql.query("job_details.gql", params)
    ci_config = raw_response["ciConfig"]
    if merged_yaml := ci_config["mergedYaml"]:
        return yaml.safe_load(merged_yaml)
    if "errors" in ci_config:
        for error in ci_config["errors"]:
            print(error)

    gl_gql.invalidate_query_cache()
    raise ValueError(
        """
    Could not fetch any content for the merged YAML.
    Please verify that the given git SHA exists on the remote.
    Maybe you forgot to `git push`?"""
    )


def recursive_fill(job, relationship_field, target_data, acc_data: dict, merged_yaml):
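    """
    Accumulate the `target_data` mapping (e.g. "variables") of a job and of every job reachable
    through `relationship_field` (e.g. "extends"), with values from the job itself overriding
    the ones inherited from its relatives.
    """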
    if relatives := job.get(relationship_field):
        if isinstance(relatives, str):
            relatives = [relatives]

        for relative in relatives:
            parent_job = merged_yaml[relative]
            acc_data = recursive_fill(
                parent_job, relationship_field, target_data, acc_data, merged_yaml
            )

    acc_data |= job.get(target_data, {})

    return acc_data


def get_variables(job, merged_yaml, project_path, sha) -> dict[str, str]:
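    """
    Build the full variable set for a job: image tags from .gitlab-ci/image-tags.yml, overridden
    by the global `variables` block, overridden by the job's own variables, plus a few CI_*
    built-ins, then expand $VAR/${VAR} references until no substitution is left.
    """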
    p = get_project_root_dir() / ".gitlab-ci" / "image-tags.yml"
    image_tags = yaml.safe_load(p.read_text())

    variables = image_tags["variables"]
    variables |= merged_yaml["variables"]
    variables |= job.get("variables", {})
    variables["CI_PROJECT_PATH"] = project_path
    variables["CI_PROJECT_NAME"] = project_path.split("/")[1]
    variables["CI_REGISTRY_IMAGE"] = "registry.freedesktop.org/${CI_PROJECT_PATH}"
    variables["CI_COMMIT_SHA"] = sha

    while recurse_among_variables_space(variables):
        pass

    return variables


# Based on: https://stackoverflow.com/a/2158532/1079223
def flatten(xs):
    for x in xs:
        if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
            yield from flatten(x)
        else:
            yield x


def get_full_script(job) -> list[str]:
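    """Return before_script, script and after_script flattened into a single list of lines."""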
    script = []
    for script_part in ("before_script", "script", "after_script"):
        script.append(f"# {script_part}")
        lines = flatten(job.get(script_part, []))
        script.extend(lines)
        script.append("")

    return script


def recurse_among_variables_space(var_graph) -> bool:
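    """
    Perform one expansion pass over the variables mapping, replacing $VAR and ${VAR} references
    whose target is also present in the mapping. The return value signals whether another pass
    may still be needed, so callers loop until it returns False.
    """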
    updated = False
    for var, value in var_graph.items():
        value = str(value)
        dep_vars = []
        if match := re.findall(r"(\$[{]?[\w\d_]*[}]?)", value):
            all_dep_vars = [v.lstrip("${").rstrip("}") for v in match]
            # print(value, match, all_dep_vars)
            dep_vars = [v for v in all_dep_vars if v in var_graph]

        for dep_var in dep_vars:
            dep_value = str(var_graph[dep_var])
            new_value = var_graph[var]
            new_value = new_value.replace(f"${{{dep_var}}}", dep_value)
            new_value = new_value.replace(f"${dep_var}", dep_value)
            var_graph[var] = new_value
            updated |= dep_value != new_value

    return updated


def print_job_final_definition(job_name, merged_yaml, project_path, sha):
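    """
    Print an approximation of what the job would run: its resolved variables as export
    statements, the concatenated before_script/script/after_script, and the container image
    when MESA_IMAGE is set.
    """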
    job = merged_yaml[job_name]
    variables = get_variables(job, merged_yaml, project_path, sha)

    print("# --------- variables ---------------")
    for var, value in sorted(variables.items()):
        print(f"export {var}={value!r}")

    # TODO: Recurse into needs to get full script
    # TODO: maybe create an extra yaml file to avoid too much rework
    script = get_full_script(job)
    print()
    print()
    print("# --------- full script ---------------")
    print("\n".join(script))

    if image := variables.get("MESA_IMAGE"):
        print()
        print()
        print("# --------- container image ---------------")
        print(image)


def from_sha_to_pipeline_iid(gl_gql: GitlabGQL, params) -> str:
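    """
    Return the iid of the first pipeline that GitLab reports for the given project path and SHA
    (typically the most recent one).
    """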
    result = gl_gql.query("pipeline_utils.gql", params)

    return result["project"]["pipelines"]["nodes"][0]["iid"]


def parse_args() -> Namespace:
    parser = ArgumentParser(
        formatter_class=ArgumentDefaultsHelpFormatter,
        description="CLI and library with utility functions to debug jobs via Gitlab GraphQL",
        epilog=f"""Example:
        {Path(__file__).name} --print-dag""",
    )
    parser.add_argument("-pp", "--project-path", type=str, default="mesa/mesa")
    parser.add_argument("--sha", "--rev", type=str, default='HEAD')
    parser.add_argument(
        "--regex",
        type=str,
        required=False,
        default=".*",
        help="Regex pattern for the job name to be considered",
    )
    parser.add_argument(
        "--include-stage",
        type=str,
        required=False,
        default=".*",
        help="Regex pattern for the stage name to be considered",
    )
    parser.add_argument(
        "--exclude-stage",
        type=str,
        required=False,
        default="^$",
        help="Regex pattern for the stage name to be excluded",
    )
    mutex_group_print = parser.add_mutually_exclusive_group()
    mutex_group_print.add_argument(
        "--print-dag",
        action="store_true",
        help="Print job needs DAG",
    )
    mutex_group_print.add_argument(
        "--print-merged-yaml",
        action="store_true",
        help="Print the resulting YAML for the specific SHA",
    )
    mutex_group_print.add_argument(
        "--print-job-manifest",
        metavar='JOB_NAME',
        type=str,
        help="Print the resulting job data"
    )
    parser.add_argument(
        "--gitlab-token-file",
        type=str,
        default=get_token_from_default_dir(),
        help="path to a file with the GitLab token; by default it is read from $XDG_CONFIG_HOME/gitlab-token",
    )

    args = parser.parse_args()
    args.gitlab_token = Path(args.gitlab_token_file).read_text().strip()
    return args


def main():
    args = parse_args()
    gl_gql = GitlabGQL(token=args.gitlab_token)

    sha = check_output(['git', 'rev-parse', args.sha]).decode('ascii').strip()

    if args.print_dag:
        iid = from_sha_to_pipeline_iid(gl_gql, {"projectPath": args.project_path, "sha": sha})
        dag = create_job_needs_dag(
            gl_gql, {"projectPath": args.project_path, "iid": iid}, disable_cache=True
        )

        dag = filter_dag(
            dag,
            re.compile(args.regex),
            re.compile(args.include_stage),
            re.compile(args.exclude_stage),
        )

        print_dag(dag)

    if args.print_merged_yaml or args.print_job_manifest:
        merged_yaml = fetch_merged_yaml(
            gl_gql, {"projectPath": args.project_path, "sha": sha}
        )

        if args.print_merged_yaml:
            print(yaml.dump(merged_yaml, indent=2))

        if args.print_job_manifest:
            print_job_final_definition(
                args.print_job_manifest, merged_yaml, args.project_path, sha
            )


if __name__ == "__main__":
    main()