1#!/usr/bin/env python3 2# For the dependencies, see the requirements.txt 3 4import logging 5import re 6import traceback 7from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, Namespace 8from collections import OrderedDict 9from copy import deepcopy 10from dataclasses import dataclass, field 11from itertools import accumulate 12from pathlib import Path 13from subprocess import check_output 14from textwrap import dedent 15from typing import Any, Iterable, Optional, Pattern, TypedDict, Union 16 17import yaml 18from filecache import DAY, filecache 19from gitlab_common import get_token_from_default_dir 20from gql import Client, gql 21from gql.transport.requests import RequestsHTTPTransport 22from graphql import DocumentNode 23 24 25class DagNode(TypedDict): 26 needs: set[str] 27 stage: str 28 # `name` is redundant but is here for retro-compatibility 29 name: str 30 31 32# see create_job_needs_dag function for more details 33Dag = dict[str, DagNode] 34 35 36StageSeq = OrderedDict[str, set[str]] 37 38 39def get_project_root_dir(): 40 root_path = Path(__file__).parent.parent.parent.resolve() 41 gitlab_file = root_path / ".gitlab-ci.yml" 42 assert gitlab_file.exists() 43 44 return root_path 45 46 47@dataclass 48class GitlabGQL: 49 _transport: Any = field(init=False) 50 client: Client = field(init=False) 51 url: str = "https://gitlab.freedesktop.org/api/graphql" 52 token: Optional[str] = None 53 54 def __post_init__(self) -> None: 55 self._setup_gitlab_gql_client() 56 57 def _setup_gitlab_gql_client(self) -> None: 58 # Select your transport with a defined url endpoint 59 headers = {} 60 if self.token: 61 headers["Authorization"] = f"Bearer {self.token}" 62 self._transport = RequestsHTTPTransport(url=self.url, headers=headers) 63 64 # Create a GraphQL client using the defined transport 65 self.client = Client(transport=self._transport, fetch_schema_from_transport=True) 66 67 def query( 68 self, 69 gql_file: Union[Path, str], 70 params: dict[str, Any] = {}, 71 operation_name: Optional[str] = None, 72 paginated_key_loc: Iterable[str] = [], 73 disable_cache: bool = False, 74 ) -> dict[str, Any]: 75 def run_uncached() -> dict[str, Any]: 76 if paginated_key_loc: 77 return self._sweep_pages(gql_file, params, operation_name, paginated_key_loc) 78 return self._query(gql_file, params, operation_name) 79 80 if disable_cache: 81 return run_uncached() 82 83 try: 84 # Create an auxiliary variable to deliver a cached result and enable catching exceptions 85 # Decorate the query to be cached 86 if paginated_key_loc: 87 result = self._sweep_pages_cached( 88 gql_file, params, operation_name, paginated_key_loc 89 ) 90 else: 91 result = self._query_cached(gql_file, params, operation_name) 92 return result # type: ignore 93 except Exception as ex: 94 logging.error(f"Cached query failed with {ex}") 95 # print exception traceback 96 traceback_str = "".join(traceback.format_exception(ex)) 97 logging.error(traceback_str) 98 self.invalidate_query_cache() 99 logging.error("Cache invalidated, retrying without cache") 100 finally: 101 return run_uncached() 102 103 def _query( 104 self, 105 gql_file: Union[Path, str], 106 params: dict[str, Any] = {}, 107 operation_name: Optional[str] = None, 108 ) -> dict[str, Any]: 109 # Provide a GraphQL query 110 source_path: Path = Path(__file__).parent 111 pipeline_query_file: Path = source_path / gql_file 112 113 query: DocumentNode 114 with open(pipeline_query_file, "r") as f: 115 pipeline_query = f.read() 116 query = gql(pipeline_query) 117 118 # Execute the query on the transport 119 return self.client.execute_sync( 120 query, variable_values=params, operation_name=operation_name 121 ) 122 123 @filecache(DAY) 124 def _sweep_pages_cached(self, *args, **kwargs): 125 return self._sweep_pages(*args, **kwargs) 126 127 @filecache(DAY) 128 def _query_cached(self, *args, **kwargs): 129 return self._query(*args, **kwargs) 130 131 def _sweep_pages( 132 self, query, params, operation_name=None, paginated_key_loc: Iterable[str] = [] 133 ) -> dict[str, Any]: 134 """ 135 Retrieve paginated data from a GraphQL API and concatenate the results into a single 136 response. 137 138 Args: 139 query: represents a filepath with the GraphQL query to be executed. 140 params: a dictionary that contains the parameters to be passed to the query. These 141 parameters can be used to filter or modify the results of the query. 142 operation_name: The `operation_name` parameter is an optional parameter that specifies 143 the name of the GraphQL operation to be executed. It is used when making a GraphQL 144 query to specify which operation to execute if there are multiple operations defined 145 in the GraphQL schema. If not provided, the default operation will be executed. 146 paginated_key_loc (Iterable[str]): The `paginated_key_loc` parameter is an iterable of 147 strings that represents the location of the paginated field within the response. It 148 is used to extract the paginated field from the response and append it to the final 149 result. The node has to be a list of objects with a `pageInfo` field that contains 150 at least the `hasNextPage` and `endCursor` fields. 151 152 Returns: 153 a dictionary containing the response from the query with the paginated field 154 concatenated. 155 """ 156 157 def fetch_page(cursor: str | None = None) -> dict[str, Any]: 158 if cursor: 159 params["cursor"] = cursor 160 logging.info( 161 f"Found more than 100 elements, paginating. " 162 f"Current cursor at {cursor}" 163 ) 164 165 return self._query(query, params, operation_name) 166 167 # Execute the initial query 168 response: dict[str, Any] = fetch_page() 169 170 # Initialize an empty list to store the final result 171 final_partial_field: list[dict[str, Any]] = [] 172 173 # Loop until all pages have been retrieved 174 while True: 175 # Get the partial field to be appended to the final result 176 partial_field = response 177 for key in paginated_key_loc: 178 partial_field = partial_field[key] 179 180 # Append the partial field to the final result 181 final_partial_field += partial_field["nodes"] 182 183 # Check if there are more pages to retrieve 184 page_info = partial_field["pageInfo"] 185 if not page_info["hasNextPage"]: 186 break 187 188 # Execute the query with the updated cursor parameter 189 response = fetch_page(page_info["endCursor"]) 190 191 # Replace the "nodes" field in the original response with the final result 192 partial_field["nodes"] = final_partial_field 193 return response 194 195 def invalidate_query_cache(self) -> None: 196 logging.warning("Invalidating query cache") 197 try: 198 self._sweep_pages._db.clear() 199 self._query._db.clear() 200 except AttributeError as ex: 201 logging.warning(f"Could not invalidate cache, maybe it was not used in {ex.args}?") 202 203 204def insert_early_stage_jobs(stage_sequence: StageSeq, jobs_metadata: Dag) -> Dag: 205 pre_processed_dag: dict[str, set[str]] = {} 206 jobs_from_early_stages = list(accumulate(stage_sequence.values(), set.union)) 207 for job_name, metadata in jobs_metadata.items(): 208 final_needs: set[str] = deepcopy(metadata["needs"]) 209 # Pre-process jobs that are not based on needs field 210 # e.g. sanity job in mesa MR pipelines 211 if not final_needs: 212 job_stage: str = jobs_metadata[job_name]["stage"] 213 stage_index: int = list(stage_sequence.keys()).index(job_stage) 214 if stage_index > 0: 215 final_needs |= jobs_from_early_stages[stage_index - 1] 216 pre_processed_dag[job_name] = final_needs 217 218 for job_name, needs in pre_processed_dag.items(): 219 jobs_metadata[job_name]["needs"] = needs 220 221 return jobs_metadata 222 223 224def traverse_dag_needs(jobs_metadata: Dag) -> None: 225 created_jobs = set(jobs_metadata.keys()) 226 for job, metadata in jobs_metadata.items(): 227 final_needs: set = deepcopy(metadata["needs"]) & created_jobs 228 # Post process jobs that are based on needs field 229 partial = True 230 231 while partial: 232 next_depth: set[str] = {n for dn in final_needs if dn in jobs_metadata for n in jobs_metadata[dn]["needs"]} 233 partial: bool = not final_needs.issuperset(next_depth) 234 final_needs = final_needs.union(next_depth) 235 236 jobs_metadata[job]["needs"] = final_needs 237 238 239def extract_stages_and_job_needs( 240 pipeline_jobs: dict[str, Any], pipeline_stages: dict[str, Any] 241) -> tuple[StageSeq, Dag]: 242 jobs_metadata = Dag() 243 # Record the stage sequence to post process deps that are not based on needs 244 # field, for example: sanity job 245 stage_sequence: OrderedDict[str, set[str]] = OrderedDict() 246 for stage in pipeline_stages["nodes"]: 247 stage_sequence[stage["name"]] = set() 248 249 for job in pipeline_jobs["nodes"]: 250 stage_sequence[job["stage"]["name"]].add(job["name"]) 251 dag_job: DagNode = { 252 "name": job["name"], 253 "stage": job["stage"]["name"], 254 "needs": set([j["node"]["name"] for j in job["needs"]["edges"]]), 255 } 256 jobs_metadata[job["name"]] = dag_job 257 258 return stage_sequence, jobs_metadata 259 260 261def create_job_needs_dag(gl_gql: GitlabGQL, params, disable_cache: bool = True) -> Dag: 262 """ 263 This function creates a Directed Acyclic Graph (DAG) to represent a sequence of jobs, where each 264 job has a set of jobs that it depends on (its "needs") and belongs to a certain "stage". 265 The "name" of the job is used as the key in the dictionary. 266 267 For example, consider the following DAG: 268 269 1. build stage: job1 -> job2 -> job3 270 2. test stage: job2 -> job4 271 272 - The job needs for job3 are: job1, job2 273 - The job needs for job4 are: job2 274 - The job2 needs to wait all jobs from build stage to finish. 275 276 The resulting DAG would look like this: 277 278 dag = { 279 "job1": {"needs": set(), "stage": "build", "name": "job1"}, 280 "job2": {"needs": {"job1", "job2", job3"}, "stage": "test", "name": "job2"}, 281 "job3": {"needs": {"job1", "job2"}, "stage": "build", "name": "job3"}, 282 "job4": {"needs": {"job2"}, "stage": "test", "name": "job4"}, 283 } 284 285 To access the job needs, one can do: 286 287 dag["job3"]["needs"] 288 289 This will return the set of jobs that job3 needs: {"job1", "job2"} 290 291 Args: 292 gl_gql (GitlabGQL): The `gl_gql` parameter is an instance of the `GitlabGQL` class, which is 293 used to make GraphQL queries to the GitLab API. 294 params (dict): The `params` parameter is a dictionary that contains the necessary parameters 295 for the GraphQL query. It is used to specify the details of the pipeline for which the 296 job needs DAG is being created. 297 The specific keys and values in the `params` dictionary will depend on 298 the requirements of the GraphQL query being executed 299 disable_cache (bool): The `disable_cache` parameter is a boolean that specifies whether the 300 301 Returns: 302 The final DAG (Directed Acyclic Graph) representing the job dependencies sourced from needs 303 or stages rule. 304 """ 305 stages_jobs_gql = gl_gql.query( 306 "pipeline_details.gql", 307 params=params, 308 paginated_key_loc=["project", "pipeline", "jobs"], 309 disable_cache=disable_cache, 310 ) 311 pipeline_data = stages_jobs_gql["project"]["pipeline"] 312 if not pipeline_data: 313 raise RuntimeError(f"Could not find any pipelines for {params}") 314 315 stage_sequence, jobs_metadata = extract_stages_and_job_needs( 316 pipeline_data["jobs"], pipeline_data["stages"] 317 ) 318 # Fill the DAG with the job needs from stages that don't have any needs but still need to wait 319 # for previous stages 320 final_dag = insert_early_stage_jobs(stage_sequence, jobs_metadata) 321 # Now that each job has its direct needs filled correctly, update the "needs" field for each job 322 # in the DAG by performing a topological traversal 323 traverse_dag_needs(final_dag) 324 325 return final_dag 326 327 328def filter_dag( 329 dag: Dag, job_name_regex: Pattern, include_stage_regex: Pattern, exclude_stage_regex: Pattern 330) -> Dag: 331 filtered_jobs: Dag = Dag({}) 332 for (job, data) in dag.items(): 333 if not job_name_regex.fullmatch(job): 334 continue 335 if not include_stage_regex.fullmatch(data["stage"]): 336 continue 337 if exclude_stage_regex.fullmatch(data["stage"]): 338 continue 339 filtered_jobs[job] = data 340 return filtered_jobs 341 342 343def print_dag(dag: Dag) -> None: 344 for job, data in sorted(dag.items()): 345 print(f"{job}:\n\t{' '.join(data['needs'])}\n") 346 347 348def fetch_merged_yaml(gl_gql: GitlabGQL, params) -> dict[str, Any]: 349 params["content"] = dedent("""\ 350 include: 351 - local: .gitlab-ci.yml 352 """) 353 raw_response = gl_gql.query("job_details.gql", params) 354 ci_config = raw_response["ciConfig"] 355 if merged_yaml := ci_config["mergedYaml"]: 356 return yaml.safe_load(merged_yaml) 357 if "errors" in ci_config: 358 for error in ci_config["errors"]: 359 print(error) 360 361 gl_gql.invalidate_query_cache() 362 raise ValueError( 363 """ 364 Could not fetch any content for merged YAML, 365 please verify if the git SHA exists in remote. 366 Maybe you forgot to `git push`? """ 367 ) 368 369 370def recursive_fill(job, relationship_field, target_data, acc_data: dict, merged_yaml): 371 if relatives := job.get(relationship_field): 372 if isinstance(relatives, str): 373 relatives = [relatives] 374 375 for relative in relatives: 376 parent_job = merged_yaml[relative] 377 acc_data = recursive_fill(parent_job, acc_data, merged_yaml) # type: ignore 378 379 acc_data |= job.get(target_data, {}) 380 381 return acc_data 382 383 384def get_variables(job, merged_yaml, project_path, sha) -> dict[str, str]: 385 p = get_project_root_dir() / ".gitlab-ci" / "image-tags.yml" 386 image_tags = yaml.safe_load(p.read_text()) 387 388 variables = image_tags["variables"] 389 variables |= merged_yaml["variables"] 390 variables |= job["variables"] 391 variables["CI_PROJECT_PATH"] = project_path 392 variables["CI_PROJECT_NAME"] = project_path.split("/")[1] 393 variables["CI_REGISTRY_IMAGE"] = "registry.freedesktop.org/${CI_PROJECT_PATH}" 394 variables["CI_COMMIT_SHA"] = sha 395 396 while recurse_among_variables_space(variables): 397 pass 398 399 return variables 400 401 402# Based on: https://stackoverflow.com/a/2158532/1079223 403def flatten(xs): 404 for x in xs: 405 if isinstance(x, Iterable) and not isinstance(x, (str, bytes)): 406 yield from flatten(x) 407 else: 408 yield x 409 410 411def get_full_script(job) -> list[str]: 412 script = [] 413 for script_part in ("before_script", "script", "after_script"): 414 script.append(f"# {script_part}") 415 lines = flatten(job.get(script_part, [])) 416 script.extend(lines) 417 script.append("") 418 419 return script 420 421 422def recurse_among_variables_space(var_graph) -> bool: 423 updated = False 424 for var, value in var_graph.items(): 425 value = str(value) 426 dep_vars = [] 427 if match := re.findall(r"(\$[{]?[\w\d_]*[}]?)", value): 428 all_dep_vars = [v.lstrip("${").rstrip("}") for v in match] 429 # print(value, match, all_dep_vars) 430 dep_vars = [v for v in all_dep_vars if v in var_graph] 431 432 for dep_var in dep_vars: 433 dep_value = str(var_graph[dep_var]) 434 new_value = var_graph[var] 435 new_value = new_value.replace(f"${{{dep_var}}}", dep_value) 436 new_value = new_value.replace(f"${dep_var}", dep_value) 437 var_graph[var] = new_value 438 updated |= dep_value != new_value 439 440 return updated 441 442 443def print_job_final_definition(job_name, merged_yaml, project_path, sha): 444 job = merged_yaml[job_name] 445 variables = get_variables(job, merged_yaml, project_path, sha) 446 447 print("# --------- variables ---------------") 448 for var, value in sorted(variables.items()): 449 print(f"export {var}={value!r}") 450 451 # TODO: Recurse into needs to get full script 452 # TODO: maybe create a extra yaml file to avoid too much rework 453 script = get_full_script(job) 454 print() 455 print() 456 print("# --------- full script ---------------") 457 print("\n".join(script)) 458 459 if image := variables.get("MESA_IMAGE"): 460 print() 461 print() 462 print("# --------- container image ---------------") 463 print(image) 464 465 466def from_sha_to_pipeline_iid(gl_gql: GitlabGQL, params) -> str: 467 result = gl_gql.query("pipeline_utils.gql", params) 468 469 return result["project"]["pipelines"]["nodes"][0]["iid"] 470 471 472def parse_args() -> Namespace: 473 parser = ArgumentParser( 474 formatter_class=ArgumentDefaultsHelpFormatter, 475 description="CLI and library with utility functions to debug jobs via Gitlab GraphQL", 476 epilog=f"""Example: 477 {Path(__file__).name} --print-dag""", 478 ) 479 parser.add_argument("-pp", "--project-path", type=str, default="mesa/mesa") 480 parser.add_argument("--sha", "--rev", type=str, default='HEAD') 481 parser.add_argument( 482 "--regex", 483 type=str, 484 required=False, 485 default=".*", 486 help="Regex pattern for the job name to be considered", 487 ) 488 parser.add_argument( 489 "--include-stage", 490 type=str, 491 required=False, 492 default=".*", 493 help="Regex pattern for the stage name to be considered", 494 ) 495 parser.add_argument( 496 "--exclude-stage", 497 type=str, 498 required=False, 499 default="^$", 500 help="Regex pattern for the stage name to be excluded", 501 ) 502 mutex_group_print = parser.add_mutually_exclusive_group() 503 mutex_group_print.add_argument( 504 "--print-dag", 505 action="store_true", 506 help="Print job needs DAG", 507 ) 508 mutex_group_print.add_argument( 509 "--print-merged-yaml", 510 action="store_true", 511 help="Print the resulting YAML for the specific SHA", 512 ) 513 mutex_group_print.add_argument( 514 "--print-job-manifest", 515 metavar='JOB_NAME', 516 type=str, 517 help="Print the resulting job data" 518 ) 519 parser.add_argument( 520 "--gitlab-token-file", 521 type=str, 522 default=get_token_from_default_dir(), 523 help="force GitLab token, otherwise it's read from $XDG_CONFIG_HOME/gitlab-token", 524 ) 525 526 args = parser.parse_args() 527 args.gitlab_token = Path(args.gitlab_token_file).read_text().strip() 528 return args 529 530 531def main(): 532 args = parse_args() 533 gl_gql = GitlabGQL(token=args.gitlab_token) 534 535 sha = check_output(['git', 'rev-parse', args.sha]).decode('ascii').strip() 536 537 if args.print_dag: 538 iid = from_sha_to_pipeline_iid(gl_gql, {"projectPath": args.project_path, "sha": sha}) 539 dag = create_job_needs_dag( 540 gl_gql, {"projectPath": args.project_path, "iid": iid}, disable_cache=True 541 ) 542 543 dag = filter_dag(dag, re.compile(args.regex), re.compile(args.include_stage), re.compile(args.exclude_stage)) 544 545 print_dag(dag) 546 547 if args.print_merged_yaml or args.print_job_manifest: 548 merged_yaml = fetch_merged_yaml( 549 gl_gql, {"projectPath": args.project_path, "sha": sha} 550 ) 551 552 if args.print_merged_yaml: 553 print(yaml.dump(merged_yaml, indent=2)) 554 555 if args.print_job_manifest: 556 print_job_final_definition( 557 args.print_job_manifest, merged_yaml, args.project_path, sha 558 ) 559 560 561if __name__ == "__main__": 562 main() 563