#!/usr/bin/env python3
# mypy: allow-untyped-defs
""" The Python Hipify script.
##
# Copyright (c) 2015-2016 Advanced Micro Devices, Inc. All rights reserved.
#               2017-2018 Advanced Micro Devices, Inc. and
#                         Facebook Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
"""
import argparse
import fnmatch
import re
import shutil
import sys
import os

from . import constants
from .cuda_to_hip_mappings import CUDA_TO_HIP_MAPPINGS
from .cuda_to_hip_mappings import MATH_TRANSPILATIONS

from typing import Dict, List, Iterator, Optional
from collections.abc import Mapping, Iterable
from enum import Enum

class CurrentState(Enum):
    INITIALIZED = 1
    DONE = 2

class HipifyResult:
    def __init__(self, current_state, hipified_path):
        self.current_state = current_state
        self.hipified_path = hipified_path
        self.status = ""

    def __str__(self):
        return (f"HipifyResult: current_state: {self.current_state}, "
                f"hipified_path: {self.hipified_path}, status: {self.status}")

HipifyFinalResult = Dict[str, HipifyResult]
HIPIFY_C_BREADCRUMB = "// !!! This is a file automatically generated by hipify!!!\n"
HIPIFY_FINAL_RESULT: HipifyFinalResult = {}

# Hardcode the PyTorch template map
"""This dictionary provides the mapping from PyTorch kernel template types
to their actual types."""
PYTORCH_TEMPLATE_MAP = {"Dtype": "scalar_t", "T": "scalar_t"}

__all__ = ['InputError', 'openf', 'bcolors', 'GeneratedFileCleaner', 'match_extensions', 'matched_files_iter',
           'preprocess_file_and_save_result', 'compute_stats', 'add_dim3', 'processKernelLaunches', 'find_closure_group',
           'find_bracket_group', 'find_parentheses_group', 'replace_math_functions', 'hip_header_magic', 'replace_extern_shared',
           'get_hip_file_path', 'is_out_of_place', 'is_pytorch_file', 'is_cusparse_file', 'is_special_file',
           'is_caffe2_gpu_file', 'Trie', 'preprocessor', 'file_specific_replacement', 'file_add_header',
           'fix_static_global_kernels', 'extract_arguments', 'str2bool', 'CurrentState', 'HipifyResult', 'hipify']


class InputError(Exception):
    # Exception raised for errors in the input.

    def __init__(self, message):
        super().__init__(message)
        self.message = message

    def __str__(self):
        return f"Input error: {self.message}"


def openf(filename, mode):
    return open(filename, mode, errors='ignore')


# Color coding for printing
class bcolors:
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'
    ENDC = '\033[0m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'


# To the programmer, the output of hipify is most likely just an intermediate.
# This class allows users of hipify to ask for a cleanup by running hipify
# and the compilation inside a `with` block that instantiates this context
# manager with keep_intermediates=False.
# The main use case is cpp_extensions, specifically the load method.
# It is a good idea to keep intermediates (in case of errors, or to avoid
# recompiling unchanged files), but in cases where you don't want to
# keep them (e.g. in the CI), this can be used to remove files.
class GeneratedFileCleaner:
    """Context Manager to clean up generated files"""
    def __init__(self, keep_intermediates=False):
        self.keep_intermediates = keep_intermediates
        self.files_to_clean = set()
        self.dirs_to_clean = []

    def __enter__(self):
        return self

    def open(self, fn, *args, **kwargs):
        if not os.path.exists(fn):
            self.files_to_clean.add(os.path.abspath(fn))
        return open(fn, *args, **kwargs)

    def makedirs(self, dn, exist_ok=False):
        parent, n = os.path.split(dn)
        if not n:
            parent, n = os.path.split(parent)
        if parent and n and not os.path.exists(parent):
            self.makedirs(parent, exist_ok=True)
        if not os.path.isdir(dn) or not exist_ok:
            os.mkdir(dn)
            self.dirs_to_clean.append(os.path.abspath(dn))

    def __exit__(self, type, value, traceback):
        if not self.keep_intermediates:
            for f in self.files_to_clean:
                os.unlink(f)
            for d in self.dirs_to_clean[::-1]:
                os.rmdir(d)
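
# Illustrative usage sketch (not from the original sources; arguments elided):
# run hipify and the subsequent compilation inside the context manager, so that
# intermediate files and directories created through clean_ctx are removed on exit:
#
#     with GeneratedFileCleaner(keep_intermediates=False) as clean_ctx:
#         hipify(project_directory=..., output_directory=..., clean_ctx=clean_ctx)
#         # ... build the extension here ...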

def match_extensions(filename: str, extensions: Iterable) -> bool:
    """Helper method to see if filename ends with one of the given extensions."""
    return any(filename.endswith(e) for e in extensions)


def _fnmatch(filepath, patterns):
    return any(fnmatch.fnmatch(filepath, pattern) for pattern in patterns)


def matched_files_iter(
        root_path: str,
        includes: Iterable = (),
        ignores: Iterable = (),
        extensions: Iterable = (),
        out_of_place_only: bool = False,
        is_pytorch_extension: bool = False) -> Iterator[str]:

    exact_matches = set(includes)

    # This is a very rough heuristic; really, we want to avoid scanning
    # any file which is not checked into source control, but this script
    # needs to work even if you're not in a Git or Hg checkout, so it is
    # easier to just block the biggest time sinks that won't matter in
    # the end.
    for (abs_dirpath, dirs, filenames) in os.walk(root_path, topdown=True):
        rel_dirpath = os.path.relpath(abs_dirpath, root_path)
        if rel_dirpath == '.':
            # Prune the biggest offenders at the top level
            # (list.remove is O(n), but these lists are small).
            if ".git" in dirs:
                dirs.remove(".git")
            if "build" in dirs:
                dirs.remove("build")
            if "third_party" in dirs:
                dirs.remove("third_party")
                dirs.append("third_party/nvfuser")
        for filename in filenames:
            filepath = os.path.join(abs_dirpath, filename)
            rel_filepath = os.path.join(rel_dirpath, filename)
            # We respect extensions, UNLESS you wrote the entire
            # filename verbatim, in which case we always accept it
            if (
                _fnmatch(filepath, includes)
                and (not _fnmatch(filepath, ignores))
                and (match_extensions(filepath, extensions) or filepath in exact_matches)
            ):
                if not is_pytorch_extension:  # for pytorch extensions, consider all files
                    if not is_pytorch_file(rel_filepath) and not is_caffe2_gpu_file(rel_filepath):
                        continue
                    if out_of_place_only and not is_out_of_place(rel_filepath):
                        continue
                yield filepath


def preprocess_file_and_save_result(
        output_directory: str,
        filepath: str,
        all_files: Iterable,
        header_include_dirs: Iterable,
        stats: Dict[str, List],
        hip_clang_launch: bool,
        is_pytorch_extension: bool,
        clean_ctx: GeneratedFileCleaner,
        show_progress: bool) -> None:
    fin_path = os.path.abspath(os.path.join(output_directory, filepath))
    hipify_result = HipifyResult(current_state=CurrentState.INITIALIZED, hipified_path=fin_path)
    HIPIFY_FINAL_RESULT[fin_path] = hipify_result
    result = preprocessor(output_directory, filepath, all_files, header_include_dirs, stats,
                          hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress)

    # Show what happened
    if show_progress and "ignored" not in result.status:
        print(
            fin_path, "->",
            result.hipified_path, result.status, flush=True)

    HIPIFY_FINAL_RESULT[fin_path] = result


def compute_stats(stats):
    unsupported_calls = {cuda_call for (cuda_call, _filepath) in stats["unsupported_calls"]}

    # Print the number of unsupported calls
    print(f"Total number of unsupported CUDA function calls: {len(unsupported_calls):d}")

    # Print the list of unsupported calls
    print(", ".join(unsupported_calls))

    # Print the number of kernel launches
    print(f"\nTotal number of replaced kernel launches: {len(stats['kernel_launches']):d}")
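
# Illustrative example of add_dim3 below (assumed inputs): given the launch-config
# text "<<<grid, block, 0, stream>>>" and the kernel text
# "my_kernel<<<grid, block, 0, stream>>>(", the result is
# "my_kernel<<<dim3(grid), dim3(block), 0, stream>>>(".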
f"dim3({first_arg_clean})" 258 second_arg_dim3 = f"dim3({second_arg_clean})" 259 260 first_arg_raw_dim3 = first_arg_raw.replace(first_arg_clean, first_arg_dim3) 261 second_arg_raw_dim3 = second_arg_raw.replace(second_arg_clean, second_arg_dim3) 262 cuda_kernel = cuda_kernel.replace(first_arg_raw + second_arg_raw, first_arg_raw_dim3 + second_arg_raw_dim3) 263 return cuda_kernel 264 265 266RE_KERNEL_LAUNCH = re.compile(r'([ ]+)(detail?)::[ ]+\\\n[ ]+') 267 268 269def processKernelLaunches(string, stats): 270 """ Replace the CUDA style Kernel launches with the HIP style kernel launches.""" 271 # Concat the namespace with the kernel names. (Find cleaner way of doing this later). 272 string = RE_KERNEL_LAUNCH.sub(lambda inp: f"{inp.group(1)}{inp.group(2)}::", string) 273 274 def grab_method_and_template(in_kernel): 275 # The positions for relevant kernel components. 276 pos = { 277 "kernel_launch": {"start": in_kernel["start"], "end": in_kernel["end"]}, 278 "kernel_name": {"start": -1, "end": -1}, 279 "template": {"start": -1, "end": -1} 280 } 281 282 # Count for balancing template 283 count = {"<>": 0} 284 285 # Status for whether we are parsing a certain item. 286 START = 0 287 AT_TEMPLATE = 1 288 AFTER_TEMPLATE = 2 289 AT_KERNEL_NAME = 3 290 291 status = START 292 293 # Parse the string character by character 294 for i in range(pos["kernel_launch"]["start"] - 1, -1, -1): 295 char = string[i] 296 297 # Handle Templating Arguments 298 if status in (START, AT_TEMPLATE): 299 if char == ">": 300 if status == START: 301 status = AT_TEMPLATE 302 pos["template"]["end"] = i 303 count["<>"] += 1 304 305 if char == "<": 306 count["<>"] -= 1 307 if count["<>"] == 0 and (status == AT_TEMPLATE): 308 pos["template"]["start"] = i 309 status = AFTER_TEMPLATE 310 311 # Handle Kernel Name 312 if status != AT_TEMPLATE: 313 if string[i].isalnum() or string[i] in {'(', ')', '_', ':', '#'}: 314 if status != AT_KERNEL_NAME: 315 status = AT_KERNEL_NAME 316 pos["kernel_name"]["end"] = i 317 318 # Case: Kernel name starts the string. 319 if i == 0: 320 pos["kernel_name"]["start"] = 0 321 322 # Finished 323 return [(pos["kernel_name"]), (pos["template"]), (pos["kernel_launch"])] 324 325 else: 326 # Potential ending point if we're already traversing a kernel's name. 327 if status == AT_KERNEL_NAME: 328 pos["kernel_name"]["start"] = i 329 330 # Finished 331 return [(pos["kernel_name"]), (pos["template"]), (pos["kernel_launch"])] 332 333 def find_kernel_bounds(string): 334 """Finds the starting and ending points for all kernel launches in the string.""" 335 kernel_end = 0 336 kernel_positions = [] 337 338 # Continue until we cannot find any more kernels anymore. 339 while string.find("<<<", kernel_end) != -1: 340 # Get kernel starting position (starting from the previous ending point) 341 kernel_start = string.find("<<<", kernel_end) 342 343 # Get kernel ending position (adjust end point past the >>>) 344 kernel_end = string.find(">>>", kernel_start) + 3 345 if kernel_end <= 0: 346 raise InputError("no kernel end found") 347 348 # Add to list of traversed kernels 349 kernel_positions.append({"start": kernel_start, "end": kernel_end, 350 "group": string[kernel_start: kernel_end]}) 351 352 return kernel_positions 353 354 # Replace comments and string literals from the code so that find_kernel_bounds does not 355 # wrongly capture kernels in comments and string literals. 356 # This function replaces them with "x" to keep positions. 

    # Replace comments and string literals from the code so that find_kernel_bounds does not
    # wrongly capture kernels in comments and string literals.
    # This function replaces them with "x" to keep positions.
    def mask_comments(string):
        in_comment = ''
        prev_c = ''
        new_string = ''
        for c in string:
            if in_comment == '':
                # Outside comments
                if c == '/' and prev_c == '/':
                    in_comment = '//'
                elif c == '*' and prev_c == '/':
                    in_comment = '/*'
                elif c == '"' and prev_c != '\\' and prev_c != "'":
                    in_comment = '"'
            elif in_comment == '//':
                # In // xxx
                if c == '\r' or c == '\n':
                    in_comment = ''
            elif in_comment == '/*':
                # In /* xxx */
                if c == '/' and prev_c == '*':
                    in_comment = ''
            elif in_comment == '"':
                # In ""
                if c == '"' and prev_c != '\\':
                    in_comment = ''
            prev_c = c
            if in_comment == '':
                new_string += c
            else:
                new_string += 'x'
        return new_string

    # Grab positional ranges of all kernel launches
    get_kernel_positions = list(find_kernel_bounds(mask_comments(string)))
    output_string = string

    # Replace each CUDA kernel with a HIP kernel.
    for kernel in get_kernel_positions:
        # Get kernel components
        params = grab_method_and_template(kernel)

        # Find parenthesis after kernel launch
        parenthesis = string.find("(", kernel["end"])

        # Extract cuda kernel
        cuda_kernel = string[params[0]["start"]:parenthesis + 1]
        kernel_string = string[kernel['start']:kernel['end']]
        end_param_index = 0 if params[1]['end'] == -1 else 1
        kernel_name_with_template = string[params[0]['start']:params[end_param_index]['end'] + 1]
        cuda_kernel_dim3 = add_dim3(kernel_string, cuda_kernel)
        # Keep number of kernel launch params consistent (grid dims, group dims, stream, dynamic shared size)
        num_klp = len(extract_arguments(0, kernel["group"].replace("<<<", "(").replace(">>>", ")")))

        hip_kernel = "hipLaunchKernelGGL(" + cuda_kernel_dim3[0:-1].replace(
            ">>>", ", 0" * (4 - num_klp) + ">>>").replace("<<<", ", ").replace(
            ">>>", ", ").replace(kernel_name_with_template, "(" + kernel_name_with_template + ")")

        # Replace cuda kernel with hip kernel
        output_string = output_string.replace(cuda_kernel, hip_kernel)

        # Update the statistics
        stats["kernel_launches"].append(hip_kernel)

    return output_string


def find_closure_group(input_string, start, group):
    """Generalization for finding a balancing closure group

    if group = ["(", ")"], then finds the first balanced parentheses.
    if group = ["{", "}"], then finds the first balanced bracket.

    Given an input string, a starting position in the input string, and the group type,
    find_closure_group returns the positions of group[0] and group[1] as a tuple.

    Example:
        >>> find_closure_group("(hi)", 0, ["(", ")"])
        (0, 3)
    """

    inside_parenthesis = False
    parens = 0
    pos = start
    p_start, p_end = -1, -1

    while pos < len(input_string):
        if input_string[pos] == group[0]:
            if inside_parenthesis is False:
                inside_parenthesis = True
                parens = 1
                p_start = pos
            else:
                parens += 1
        elif input_string[pos] == group[1] and inside_parenthesis:
            parens -= 1

            if parens == 0:
                p_end = pos
                return p_start, p_end

        pos += 1
    return None, None


def find_bracket_group(input_string, start):
    """Finds the first balanced brace group ("{" ... "}")."""
    return find_closure_group(input_string, start, group=["{", "}"])


def find_parentheses_group(input_string, start):
    """Finds the first balanced parentheses group ("(" ... ")")."""
    return find_closure_group(input_string, start, group=["(", ")"])


RE_ASSERT = re.compile(r"\bassert[ ]*\(")


def replace_math_functions(input_string):
    """FIXME: Temporarily replace std:: invocations of math functions
    with non-std:: versions to prevent linker errors. NOTE: This
    can lead to correctness issues when running tests, since the
    correct version of the math function (exp/expf) might not get
    called. The plan is to remove this function once HIP supports
    std:: math function calls inside device code.
    """
    output_string = input_string
    for func in MATH_TRANSPILATIONS:
        output_string = output_string.replace(fr'{func}(', f'{MATH_TRANSPILATIONS[func]}(')

    return output_string


RE_SYNCTHREADS = re.compile(r":?:?\b(__syncthreads)\b(\w*\()")


def hip_header_magic(input_string):
    """If the file makes kernel builtin calls and does not include the cuda_runtime.h header,
    then automatically add an #include to match the "magic" includes provided by NVCC.
    TODO:
        Update logic to ignore cases where the cuda_runtime.h is included by another file.
    """

    # Copy the input.
    output_string = input_string

    # Check if one of the following headers is already included.
    headers = ["hip/hip_runtime.h", "hip/hip_runtime_api.h"]
    if any(re.search(fr'#include ("{ext}"|<{ext}>)', output_string) for ext in headers):
        return output_string

    # Rough logic to detect if we're inside device code
    hasDeviceLogic: int
    hasDeviceLogic = "hipLaunchKernelGGL" in output_string
    hasDeviceLogic += "__global__" in output_string
    hasDeviceLogic += "__shared__" in output_string
    hasDeviceLogic += RE_SYNCTHREADS.search(output_string) is not None

    # If device logic found, provide the necessary header.
    if hasDeviceLogic:
        output_string = '#include "hip/hip_runtime.h"\n' + input_string

    return output_string


RE_EXTERN_SHARED = re.compile(r"extern\s+([\w\(\)]+)?\s*__shared__\s+([\w:<>\s]+)\s+(\w+)\s*\[\s*\]\s*;")


def replace_extern_shared(input_string):
    """Match extern __shared__ type foo[]; syntax and use HIP_DYNAMIC_SHARED() MACRO instead.
    https://github.com/ROCm-Developer-Tools/HIP/blob/master/docs/markdown/hip_kernel_language.md#__shared__
    Example:
        "extern __shared__ char smemChar[];" => "HIP_DYNAMIC_SHARED( char, smemChar)"
        "extern __shared__ unsigned char smem[];" => "HIP_DYNAMIC_SHARED( unsigned char, smem)"
    """
    output_string = input_string
    output_string = RE_EXTERN_SHARED.sub(
        lambda inp: f"HIP_DYNAMIC_SHARED({inp.group(1) or ''} {inp.group(2)}, {inp.group(3)})", output_string)

    return output_string


def get_hip_file_path(rel_filepath, is_pytorch_extension=False):
    """
    Returns the new name of the hipified file
    """
    # At the moment, some PyTorch source files are HIPified in place. The predicate
    # is_out_of_place tells us if this is the case or not.
    assert not os.path.isabs(rel_filepath)
    if not is_pytorch_extension and not is_out_of_place(rel_filepath):
        return rel_filepath

    dirpath, filename = os.path.split(rel_filepath)
    root, ext = os.path.splitext(filename)

    # Here's the plan:
    #
    # In general, we need to disambiguate the HIPified filename so that
    # it gets a different name from the original filename, so
    # that we don't overwrite the original file.
    #
    # There's a lot of different naming conventions across PyTorch
    # and Caffe2, but the general recipe is to convert occurrences
    # of cuda/gpu to hip, and add hip if there are no occurrences
    # of cuda/gpu anywhere.
    #
    # Concretely, we do the following:
    #
    #   - If there is a directory component named "cuda", replace
    #     it with "hip", AND
    #
    #   - If the file name contains "CUDA", replace it with "HIP", AND
    #
    #   - ALWAYS replace '.cu' with '.hip', because those files
    #     contain CUDA kernels that need to be hipified and processed with
    #     the hip compiler
    #
    #   - If we are not hipifying a PyTorch extension, and the parent
    #     directory name did not change as a result of the above
    #     transformations, insert "hip" in the file path
    #     as the direct parent folder of the file
    #
    #   - If we are hipifying a PyTorch extension, and the parent directory
    #     name as well as the filename (incl. extension) did not change as
    #     a result of the above transformations, insert "_hip" in the filename
    #
    # This isn't set in stone; we might adjust this to support other
    # naming conventions.
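    #
    # A few concrete illustrations of the rules above (the paths are only examples):
    #   aten/src/ATen/native/cuda/Foo.cu  -> aten/src/ATen/native/hip/Foo.hip
    #   aten/src/ATen/cuda/CUDABlas.cpp   -> aten/src/ATen/hip/HIPBlas.cpp
    #   caffe2/operators/my_op_gpu.cc     -> caffe2/operators/hip/my_op_gpu.cc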

    if ext == '.cu':
        ext = '.hip'

    orig_filename = filename
    orig_dirpath = dirpath

    dirpath = dirpath.replace('cuda', 'hip')
    dirpath = dirpath.replace('CUDA', 'HIP')
    dirpath = dirpath.replace('THC', 'THH')

    root = root.replace('cuda', 'hip')
    root = root.replace('CUDA', 'HIP')
    # Special case to handle caffe2/core/THCCachingAllocator
    if dirpath != "caffe2/core":
        root = root.replace('THC', 'THH')

    if not is_pytorch_extension and dirpath == orig_dirpath:
        dirpath = os.path.join(dirpath, 'hip')

    if is_pytorch_extension and dirpath == orig_dirpath and (root + ext) == orig_filename:
        root = root + "_hip"

    return os.path.join(dirpath, root + ext)


def is_out_of_place(rel_filepath):
    assert not os.path.isabs(rel_filepath)
    if rel_filepath.startswith("torch/"):
        return False
    if rel_filepath.startswith("third_party/nvfuser/"):
        return False
    if rel_filepath.startswith("tools/autograd/templates/"):
        return False
    return True


# Keep this synchronized with includes/ignores in build_amd.py
def is_pytorch_file(rel_filepath):
    assert not os.path.isabs(rel_filepath)
    if rel_filepath.startswith("aten/"):
        if rel_filepath.startswith("aten/src/ATen/core/"):
            return False
        return True
    if rel_filepath.startswith("torch/"):
        return True
    if rel_filepath.startswith("third_party/nvfuser/"):
        return True
    if rel_filepath.startswith("tools/autograd/templates/"):
        return True
    return False


def is_cusparse_file(rel_filepath):
    if is_pytorch_file(rel_filepath):
        return "sparse" in rel_filepath.lower()
    return False


def is_special_file(rel_filepath):
    if is_pytorch_file(rel_filepath):
        if "sparse" in rel_filepath.lower():
            return True
        elif "linalg" in rel_filepath.lower():
            if "batchlinearalgebralibblas" in rel_filepath.lower():
                return False  # don't use "special" mappings for this specific linalg cublas file
            return True
    return False

def is_caffe2_gpu_file(rel_filepath):
    assert not os.path.isabs(rel_filepath)
    if rel_filepath.startswith("c10/cuda"):
        return True
    filename = os.path.basename(rel_filepath)
    _, ext = os.path.splitext(filename)
    return ('gpu' in filename or ext in ['.cu', '.cuh']) and ('cudnn' not in filename)

class TrieNode:
    """A Trie node whose children are represented as a dictionary of char: TrieNode.
    A special char '' represents the end of a word.
    """

    def __init__(self):
        self.children = {}

class Trie:
    """Creates a Trie out of a list of words. The trie can be exported to a Regex pattern.
    The corresponding Regex should match much faster than a simple Regex union."""

    def __init__(self):
        """Initialize the trie with an empty root node."""
        self.root = TrieNode()

    def add(self, word):
        """Add a word to the Trie. """
        node = self.root

        for char in word:
            node.children.setdefault(char, TrieNode())
            node = node.children[char]
        node.children[''] = True  # Mark the end of the word

    def dump(self):
        """Return the root node of Trie. """
        return self.root

    def quote(self, char):
        """ Escape a char for regex. """
        return re.escape(char)

    def search(self, word):
        """Search whether word is present in the Trie.
        Returns True if it is, otherwise False."""
        node = self.root
        for char in word:
            if char in node.children:
                node = node.children[char]
            else:
                return False

        # make sure the end-of-word marker is present
        return '' in node.children

    def _pattern(self, root):
        """Convert a Trie into a regular expression pattern"""
        node = root

        if "" in node.children and len(node.children.keys()) == 1:
            return None

        alt = []  # store alternative patterns
        cc = []   # store chars for a character class
        q = 0     # whether this node also marks the end of a word
        for char in sorted(node.children.keys()):
            if isinstance(node.children[char], TrieNode):
                try:
                    recurse = self._pattern(node.children[char])
                    alt.append(self.quote(char) + recurse)
                except Exception:
                    # recurse is None for a leaf-only child; the TypeError raised
                    # above routes the char into the character class instead.
                    cc.append(self.quote(char))
            else:
                q = 1
        cconly = not len(alt) > 0

        if len(cc) > 0:
            if len(cc) == 1:
                alt.append(cc[0])
            else:
                alt.append('[' + ''.join(cc) + ']')

        if len(alt) == 1:
            result = alt[0]
        else:
            result = "(?:" + "|".join(alt) + ")"

        if q:
            if cconly:
                result += "?"
            else:
                result = f"(?:{result})?"
        return result

    def pattern(self):
        """Export the Trie to a regex pattern."""
        return self._pattern(self.root)

    def export_to_regex(self):
        """Export the Trie to a regex pattern."""
        return self._pattern(self.root)
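
# Illustrative sketch of how the Trie is used (assumed words, not the real mappings):
#
#     trie = Trie()
#     trie.add("cudaMalloc")
#     trie.add("cudaFree")
#     trie.export_to_regex()   # -> "cuda(?:Free|Malloc)"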

CAFFE2_TRIE = Trie()
CAFFE2_MAP = {}
PYTORCH_TRIE = Trie()
PYTORCH_MAP: Dict[str, object] = {}

# In PyTorch, we map cuBLAS->rocBLAS and cuSPARSE->hipSPARSE. Note the prefix, roc versus hip.
# The 'hip' APIs offer a more direct CUDA-friendly mapping, but calling rocBLAS directly has better performance.
# Unfortunately, the roc* types and hip* types differ, i.e., rocblas_float_complex versus hipComplex.
# In the case of SPARSE, we must use the hip types for complex instead of the roc types,
# but the pytorch mappings assume roc. Therefore, we create a new SPARSE mapping that has a higher priority.
# Its mappings will trigger first, and only when a miss occurs will the lower-priority pytorch mapping take place.
# When a file contains "sparse" in the filename, a mapping marked with API_SPARSE is preferred over other choices.
# Similarly, "linalg" files require rocBLAS -> hipSOLVER so they also need special handling.
PYTORCH_SPECIAL_MAP = {}

for mapping in CUDA_TO_HIP_MAPPINGS:
    assert isinstance(mapping, Mapping)
    for src, value in mapping.items():
        dst = value[0]
        meta_data = value[1:]
        if constants.API_CAFFE2 not in meta_data:
            PYTORCH_TRIE.add(src)
            # if src is already in PYTORCH_MAP and dst belongs to API_SPECIAL,
            # do not overwrite PYTORCH_MAP; store dst separately
            if constants.API_SPECIAL in meta_data and PYTORCH_MAP.get(src, ""):
                PYTORCH_SPECIAL_MAP[src] = dst
            else:
                PYTORCH_MAP[src] = dst
        if constants.API_PYTORCH not in meta_data and constants.API_SPECIAL not in meta_data:
            CAFFE2_TRIE.add(src)
            CAFFE2_MAP[src] = dst

RE_CAFFE2_PREPROCESSOR = re.compile(CAFFE2_TRIE.export_to_regex())
RE_PYTORCH_PREPROCESSOR = re.compile(fr'(?<=\W)({PYTORCH_TRIE.export_to_regex()})(?=\W)')

RE_QUOTE_HEADER = re.compile(r'#include "([^"]+)"')
RE_ANGLE_HEADER = re.compile(r'#include <([^>]+)>')
RE_THC_GENERIC_FILE = re.compile(r'#define THC_GENERIC_FILE "([^"]+)"')
RE_CU_SUFFIX = re.compile(r'\.cu\b')  # be careful not to pick up .cuh

"""
Returns a HipifyResult object with the following details:
    "hipified_path" : absolute path of the hipified source file
    "status"        : "ok"      if the hipified file was written out
                      "skipped" if an identical hipified file already existed
                                or the hipified file could not be written out
                      "ignored" if the source file was itself a hipified file
                                or was not meant to be hipified
    "current_state" : CurrentState.INITIALIZED when the source file is first ready to be hipified
                      CurrentState.DONE        when the source file is done with the hipification process
"""


def preprocessor(
        output_directory: str,
        filepath: str,
        all_files: Iterable,
        header_include_dirs: Iterable,
        stats: Dict[str, List],
        hip_clang_launch: bool,
        is_pytorch_extension: bool,
        clean_ctx: GeneratedFileCleaner,
        show_progress: bool) -> HipifyResult:
    """ Executes the CUDA -> HIP conversion on the specified file.
""" 816 fin_path = os.path.abspath(os.path.join(output_directory, filepath)) 817 hipify_result = HIPIFY_FINAL_RESULT[fin_path] 818 if filepath not in all_files: 819 hipify_result.hipified_path = None 820 hipify_result.status = "[ignored, not to be hipified]" 821 hipify_result.current_state = CurrentState.DONE 822 return hipify_result 823 824 rel_filepath = os.path.relpath(filepath, output_directory) 825 826 with open(fin_path, encoding='utf-8') as fin: 827 if fin.readline() == HIPIFY_C_BREADCRUMB: 828 hipify_result.hipified_path = None 829 hipify_result.status = "[ignored, input is hipified output]" 830 hipify_result.current_state = CurrentState.DONE 831 return hipify_result 832 fin.seek(0) 833 output_source = fin.read() 834 835 orig_output_source = output_source 836 837 # get_hip_file_path needs a relative path to work correctly 838 fout_path = os.path.abspath(os.path.join(output_directory, get_hip_file_path(rel_filepath, is_pytorch_extension))) 839 if not os.path.exists(os.path.dirname(fout_path)): 840 clean_ctx.makedirs(os.path.dirname(fout_path)) 841 842 # unsupported_calls statistics reporting is broken atm 843 def pt_repl(m): 844 return PYTORCH_MAP[m.group(0)] 845 846 def pt_special_repl(m): 847 # checks SPECIAL map first, and if a miss occurs, falls back to pytorch mappings 848 return PYTORCH_SPECIAL_MAP.get(m.group(0), pt_repl(m)) 849 850 851 if is_pytorch_extension: 852 output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source) 853 else: 854 if is_special_file(rel_filepath): 855 output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_special_repl, output_source) 856 elif is_pytorch_file(rel_filepath): 857 output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source) 858 else: 859 def c2_repl(m): 860 return CAFFE2_MAP[m.group(0)] 861 output_source = RE_CAFFE2_PREPROCESSOR.sub(c2_repl, output_source) 862 863 # Header rewrites 864 def mk_repl(templ, include_current_dir=True): 865 def repl(m): 866 f = m.group(1) 867 dirpath, filename = os.path.split(f) 868 if ( 869 f.startswith(("ATen/cuda", 870 "ATen/native/cuda", 871 "ATen/native/nested/cuda", 872 "ATen/native/quantized/cuda", 873 "ATen/native/sparse/cuda", 874 "ATen/native/transformers/cuda", 875 "THC/")) or 876 (f.startswith("THC") and not f.startswith("THCP")) 877 ): 878 return templ.format(get_hip_file_path(m.group(1), is_pytorch_extension)) 879 # if filename is one of the files being hipified for this extension 880 if (is_pytorch_extension and any(s.endswith(filename) for s in all_files)): 881 header_dir = None 882 header_filepath = None 883 # If include_current_dir True, look first in same dir as the including source file 884 if include_current_dir: 885 header_dir_to_check = os.path.dirname(fin_path) 886 header_path_to_check = os.path.abspath(os.path.join(header_dir_to_check, f)) 887 if os.path.exists(header_path_to_check): 888 header_dir = header_dir_to_check 889 header_filepath = header_path_to_check 890 # If not found, look in include dirs one by one and first match wins 891 if header_filepath is None: 892 for header_include_dir in header_include_dirs: 893 header_dir_to_check = os.path.join(output_directory, header_include_dir) 894 header_path_to_check = os.path.abspath(os.path.join(header_dir_to_check, f)) 895 if os.path.exists(header_path_to_check): 896 header_dir = header_dir_to_check 897 header_filepath = header_path_to_check 898 # If header file not found, keep as is 899 if header_filepath is None: 900 return m.group(0) 901 # Hipify header file first if needed 902 if header_filepath not in HIPIFY_FINAL_RESULT: 
                    preprocess_file_and_save_result(output_directory,
                                                    header_filepath,
                                                    all_files, header_include_dirs, stats, hip_clang_launch,
                                                    is_pytorch_extension, clean_ctx, show_progress)
                elif header_filepath in HIPIFY_FINAL_RESULT:
                    header_result = HIPIFY_FINAL_RESULT[header_filepath]
                    if header_result.current_state == CurrentState.INITIALIZED:
                        # get_hip_file_path needs a relative path to work correctly
                        header_rel_path = os.path.relpath(header_filepath, output_directory)
                        header_fout_path = os.path.abspath(os.path.join(output_directory,
                                                           get_hip_file_path(header_rel_path, is_pytorch_extension)))
                        header_result.hipified_path = header_fout_path
                        HIPIFY_FINAL_RESULT[header_filepath] = header_result
                        return templ.format(os.path.relpath(header_fout_path if header_fout_path is not None
                                                            else header_filepath, header_dir))
                hipified_header_filepath = HIPIFY_FINAL_RESULT[header_filepath].hipified_path
                return templ.format(os.path.relpath(hipified_header_filepath if hipified_header_filepath is not None
                                                    else header_filepath, header_dir))

            return m.group(0)
        return repl

    output_source = RE_QUOTE_HEADER.sub(mk_repl('#include "{0}"', True), output_source)
    output_source = RE_ANGLE_HEADER.sub(mk_repl('#include <{0}>', False), output_source)
    output_source = RE_THC_GENERIC_FILE.sub(mk_repl('#define THC_GENERIC_FILE "{0}"'), output_source)

    # CMakeLists.txt rewrites
    if filepath.endswith('CMakeLists.txt'):
        output_source = output_source.replace('CUDA', 'HIP')
        output_source = output_source.replace('THC', 'THH')
        output_source = RE_CU_SUFFIX.sub('.hip', output_source)

    # Perform Kernel Launch Replacements
    if not hip_clang_launch:
        output_source = processKernelLaunches(output_source, stats)

    # Replace std:: with non-std:: versions
    if (filepath.endswith((".cu", ".cuh"))) and "PowKernel" not in filepath:
        output_source = replace_math_functions(output_source)

    # Include header if device code is contained.
    output_source = hip_header_magic(output_source)

    # Replace the extern __shared__
    # NOTE: No longer needed after transition from hcc to hipclang.
    # output_source = replace_extern_shared(output_source)

    # Don't write out identical hipified files for extensions if dirpath has not changed
    if (
        is_pytorch_extension
        and orig_output_source == output_source
        and os.path.dirname(fin_path) == os.path.dirname(fout_path)
    ):
        hipify_result.hipified_path = fin_path
        hipify_result.status = "[skipped, no changes]"
        hipify_result.current_state = CurrentState.DONE
        return hipify_result

    # Add hipify breadcrumb for C-style files to avoid re-hipification
    if fin_path != fout_path and match_extensions(fin_path, (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".hpp")):
        output_source = HIPIFY_C_BREADCRUMB + output_source

    do_write = True
    if os.path.exists(fout_path):
        with open(fout_path, encoding='utf-8') as fout_old:
            do_write = fout_old.read() != output_source
    if do_write:
        try:
            with clean_ctx.open(fout_path, 'w', encoding='utf-8') as fout:
                fout.write(output_source)
            hipify_result.hipified_path = fout_path
            hipify_result.status = "[ok]"
            hipify_result.current_state = CurrentState.DONE
            return hipify_result
        except PermissionError as e:
            print(f'{bcolors.WARNING}Failed to save {fout_path} with "{e.strerror}", leaving {fin_path} unchanged.{bcolors.ENDC}',
                  file=sys.stderr)
            hipify_result.hipified_path = fin_path
            hipify_result.status = "[skipped, no permissions]"
            hipify_result.current_state = CurrentState.DONE
            return hipify_result
    else:
        hipify_result.hipified_path = fout_path
        hipify_result.status = "[skipped, already hipified]"
        hipify_result.current_state = CurrentState.DONE
        return hipify_result


def file_specific_replacement(filepath, search_string, replace_string, strict=False):
    with openf(filepath, "r+") as f:
        contents = f.read()
        if strict:
            contents = re.sub(fr'\b({re.escape(search_string)})\b', lambda x: replace_string, contents)
        else:
            contents = contents.replace(search_string, replace_string)
        f.seek(0)
        f.write(contents)
        f.truncate()


def file_add_header(filepath, header):
    with openf(filepath, "r+") as f:
        contents = f.read()
        if header[0] != "<" and header[-1] != ">":
            header = f'"{header}"'
        contents = (f'#include {header} \n') + contents
        f.seek(0)
        f.write(contents)
        f.truncate()


def fix_static_global_kernels(in_txt):
    """Static global kernels in HIP result in a compilation error."""
    in_txt = in_txt.replace(" __global__ static", "__global__")
    return in_txt


RE_INCLUDE = re.compile(r"#include .*\n")


def extract_arguments(start, string):
    """ Return the list of arguments in the upcoming function parameter closure.
    Example:
        string (input): '(blocks, threads, 0, THCState_getCurrentStream(state))'
        arguments (output):
            '[{'start': 1, 'end': 7},
              {'start': 8, 'end': 16},
              {'start': 17, 'end': 19},
              {'start': 20, 'end': 53}]'
    """

    arguments = []
    closures = {
        "<": 0,
        "(": 0
    }
    current_position = start
    argument_start_pos = current_position + 1

    # Search for final parenthesis
    while current_position < len(string):
        if string[current_position] == "(":
            closures["("] += 1
        elif string[current_position] == ")":
            closures["("] -= 1
        elif string[current_position] == "<":
            closures["<"] += 1
        elif string[current_position] == ">" and string[current_position - 1] != "-" and closures["<"] > 0:
            closures["<"] -= 1

        # Finished all arguments
        if closures["("] == 0 and closures["<"] == 0:
            # Add final argument
            arguments.append({"start": argument_start_pos, "end": current_position})
            break

        # Finished current argument
        if closures["("] == 1 and closures["<"] == 0 and string[current_position] == ",":
            arguments.append({"start": argument_start_pos, "end": current_position})
            argument_start_pos = current_position + 1

        current_position += 1

    return arguments


def str2bool(v):
    """ArgumentParser doesn't support type=bool. Thus, this helper method will convert
    from possible string types to True / False."""
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def hipify(
    project_directory: str,
    show_detailed: bool = False,
    extensions: Iterable = (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".in", ".hpp"),
    header_extensions: Iterable = (".cuh", ".h", ".hpp"),
    output_directory: str = "",
    header_include_dirs: Iterable = (),
    includes: Iterable = ('*',),
    extra_files: Iterable = (),
    out_of_place_only: bool = False,
    ignores: Iterable = (),
    show_progress: bool = True,
    hip_clang_launch: bool = False,
    is_pytorch_extension: bool = False,
    hipify_extra_files_only: bool = False,
    clean_ctx: Optional[GeneratedFileCleaner] = None
) -> HipifyFinalResult:
    if project_directory == "":
        project_directory = os.getcwd()

    # Verify the project directory exists.
    if not os.path.exists(project_directory):
        print("The project folder specified does not exist.")
        sys.exit(1)

    # If no output directory was given, provide a default one.
    if not output_directory:
        project_directory = project_directory.rstrip("/")
        output_directory = project_directory + "_amd"

    if project_directory != output_directory:
        includes = [include.replace(project_directory, output_directory) for include in includes]
        ignores = [ignore.replace(project_directory, output_directory) for ignore in ignores]
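    # For example (illustrative values only): with project_directory "/src/proj" and the
    # default output_directory "/src/proj_amd", an include glob such as "/src/proj/csrc/*"
    # is rewritten to "/src/proj_amd/csrc/*" so that it matches files in the copied tree.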

    # Copy from project directory to output directory if not done already.
    if not os.path.exists(output_directory):
        shutil.copytree(project_directory, output_directory)

    all_files = list(matched_files_iter(output_directory, includes=includes,
                                        ignores=ignores, extensions=extensions,
                                        out_of_place_only=out_of_place_only,
                                        is_pytorch_extension=is_pytorch_extension))
    all_files_set = set(all_files)
    for f in extra_files:
        if not os.path.isabs(f):
            f = os.path.join(output_directory, f)
        if f not in all_files_set:
            all_files.append(f)

    # List all files in header_include_paths to ensure they are hipified
    from pathlib import Path
    for header_include_dir in header_include_dirs:
        if os.path.isabs(header_include_dir):
            header_include_dir_path = Path(header_include_dir)
        else:
            header_include_dir_path = Path(os.path.join(output_directory, header_include_dir))
        for path in header_include_dir_path.rglob('*'):
            if (
                path.is_file()
                and _fnmatch(str(path), includes)
                and (not _fnmatch(str(path), ignores))
                and match_extensions(path.name, header_extensions)
            ):
                all_files.append(str(path))

    if clean_ctx is None:
        clean_ctx = GeneratedFileCleaner(keep_intermediates=True)

    # Preprocessing statistics.
    stats: Dict[str, List] = {"unsupported_calls": [], "kernel_launches": []}

    for filepath in (all_files if not hipify_extra_files_only else extra_files):
        preprocess_file_and_save_result(output_directory, filepath, all_files, header_include_dirs,
                                        stats, hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress)

    print(bcolors.OKGREEN + "Successfully preprocessed all matching files." + bcolors.ENDC, file=sys.stderr)

    # Show detailed summary
    if show_detailed:
        compute_stats(stats)

    return HIPIFY_FINAL_RESULT
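

# Illustrative usage sketch (assumed, hypothetical paths; kept as a comment so that
# importing this module stays side-effect free):
#
#     from torch.utils.hipify.hipify_python import hipify
#
#     results = hipify(
#         project_directory="/path/to/my_extension",   # hypothetical path
#         output_directory="/path/to/my_extension",    # hipify in place for extensions
#         includes=["/path/to/my_extension/*"],
#         is_pytorch_extension=True,
#         show_progress=False,
#     )
#     for src, res in results.items():
#         print(src, "->", res.hipified_path, res.status)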