1#!/usr/bin/env python3
2# mypy: allow-untyped-defs
3""" The Python Hipify script.
4##
5# Copyright (c) 2015-2016 Advanced Micro Devices, Inc. All rights reserved.
6#               2017-2018 Advanced Micro Devices, Inc. and
7#                         Facebook Inc. All rights reserved.
8#
9# Permission is hereby granted, free of charge, to any person obtaining a copy
10# of this software and associated documentation files (the "Software"), to deal
11# in the Software without restriction, including without limitation the rights
12# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13# copies of the Software, and to permit persons to whom the Software is
14# furnished to do so, subject to the following conditions:
15#
16# The above copyright notice and this permission notice shall be included in
17# all copies or substantial portions of the Software.
18#
19# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
22# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
25# THE SOFTWARE.
26"""
27import argparse
28import fnmatch
29import re
30import shutil
31import sys
32import os
33
34from . import constants
35from .cuda_to_hip_mappings import CUDA_TO_HIP_MAPPINGS
36from .cuda_to_hip_mappings import MATH_TRANSPILATIONS
37
38from typing import Dict, List, Iterator, Optional
39from collections.abc import Mapping, Iterable
40from enum import Enum
41
42class CurrentState(Enum):
43    INITIALIZED = 1
44    DONE = 2
45
46class HipifyResult:
47    def __init__(self, current_state, hipified_path):
48        self.current_state = current_state
49        self.hipified_path = hipified_path
50        self.status = ""
51
52    def __str__(self):
53        return (f"HipifyResult:: current_state: {self.current_state}, hipified_path : {self.hipified_path}, status: {self.status}")
54
55HipifyFinalResult = Dict[str, HipifyResult]
56HIPIFY_C_BREADCRUMB = "// !!! This is a file automatically generated by hipify!!!\n"
57HIPIFY_FINAL_RESULT: HipifyFinalResult = {}
58
59# Hardcode the PyTorch template map
60"""This dictionary provides the mapping from PyTorch kernel template types
61to their actual types."""
62PYTORCH_TEMPLATE_MAP = {"Dtype": "scalar_t", "T": "scalar_t"}
63
64__all__ = ['InputError', 'openf', 'bcolors', 'GeneratedFileCleaner', 'match_extensions', 'matched_files_iter',
65           'preprocess_file_and_save_result', 'compute_stats', 'add_dim3', 'processKernelLaunches', 'find_closure_group',
66           'find_bracket_group', 'find_parentheses_group', 'replace_math_functions', 'hip_header_magic', 'replace_extern_shared',
67           'get_hip_file_path', 'is_out_of_place', 'is_pytorch_file', 'is_cusparse_file', 'is_special_file', 'is_caffe2_gpu_file',
68           'Trie', 'preprocessor', 'file_specific_replacement', 'file_add_header',
69           'fix_static_global_kernels', 'extract_arguments', 'str2bool', 'CurrentState', 'HipifyResult', 'hipify']
70
71
72class InputError(Exception):
73    # Exception raised for errors in the input.
74
75    def __init__(self, message):
76        super().__init__(message)
77        self.message = message
78
79    def __str__(self):
80        return f"Input error: {self.message}"
81
82
83def openf(filename, mode):
84    return open(filename, mode, errors='ignore')
85
86
87# Color coding for printing
88class bcolors:
89    HEADER = '\033[95m'
90    OKBLUE = '\033[94m'
91    OKGREEN = '\033[92m'
92    WARNING = '\033[93m'
93    FAIL = '\033[91m'
94    ENDC = '\033[0m'
95    BOLD = '\033[1m'
96    UNDERLINE = '\033[4m'
97
98
99# To the programmer, the output of hipify is most likely an intermediate artifact.
100# This class lets users of hipify request a cleanup by running hipify and the
101# compilation inside a `with` block that instantiates this context manager
102# with keep_intermediates=False.
103# The main use case is cpp_extensions, specifically its load method.
104# It is usually a good idea to keep intermediates (to diagnose errors or to
105# avoid recompiling unchanged files), but when you don't want to keep them
106# (e.g. in CI), this can be used to remove the generated files (see the usage sketch after the class).
107class GeneratedFileCleaner:
108    """Context Manager to clean up generated files"""
109    def __init__(self, keep_intermediates=False):
110        self.keep_intermediates = keep_intermediates
111        self.files_to_clean = set()
112        self.dirs_to_clean = []
113
114    def __enter__(self):
115        return self
116
117    def open(self, fn, *args, **kwargs):
118        if not os.path.exists(fn):
119            self.files_to_clean.add(os.path.abspath(fn))
120        return open(fn, *args, **kwargs)
121
122    def makedirs(self, dn, exist_ok=False):
123        parent, n = os.path.split(dn)
124        if not n:
125            parent, n = os.path.split(parent)
126        if parent and n and not os.path.exists(parent):
127            self.makedirs(parent, exist_ok=True)
128        if not os.path.isdir(dn) or not exist_ok:
129            os.mkdir(dn)
130            self.dirs_to_clean.append(os.path.abspath(dn))
131
132    def __exit__(self, type, value, traceback):
133        if not self.keep_intermediates:
134            for f in self.files_to_clean:
135                os.unlink(f)
136            for d in self.dirs_to_clean[::-1]:
137                os.rmdir(d)
138
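# A minimal, illustrative sketch (not executed here) of how a caller might
# combine GeneratedFileCleaner with hipify() defined below; the "." paths are
# placeholders, not paths used by this module:
#
#     with GeneratedFileCleaner(keep_intermediates=False) as clean_ctx:
#         hipify(project_directory=".", output_directory=".",
#                is_pytorch_extension=True, clean_ctx=clean_ctx)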
139
140def match_extensions(filename: str, extensions: Iterable) -> bool:
141    """Helper method to see if filename ends with certain extension"""
142    return any(filename.endswith(e) for e in extensions)
143
144
145def _fnmatch(filepath, patterns):
146    return any(fnmatch.fnmatch(filepath, pattern) for pattern in patterns)
147
148
149def matched_files_iter(
150        root_path: str,
151        includes: Iterable = (),
152        ignores: Iterable = (),
153        extensions: Iterable = (),
154        out_of_place_only: bool = False,
155        is_pytorch_extension: bool = False) -> Iterator[str]:
156
157    exact_matches = set(includes)
158
159    # This is a very rough heuristic; ideally we would avoid scanning
160    # any file that is not checked into source control, but this script
161    # needs to work even if you're not in a Git or Hg checkout, so it is
162    # easier to just block the biggest time sinks that won't matter in
163    # the end.
164    for (abs_dirpath, dirs, filenames) in os.walk(root_path, topdown=True):
165        rel_dirpath = os.path.relpath(abs_dirpath, root_path)
166        if rel_dirpath == '.':
167            # Prune these directories in place so os.walk skips them (list.remove is O(n), but fine here)
168            if ".git" in dirs:
169                dirs.remove(".git")
170            if "build" in dirs:
171                dirs.remove("build")
172            if "third_party" in dirs:
173                dirs.remove("third_party")
174                dirs.append("third_party/nvfuser")
175        for filename in filenames:
176            filepath = os.path.join(abs_dirpath, filename)
177            rel_filepath = os.path.join(rel_dirpath, filename)
178            # We respect extensions, UNLESS you wrote the entire
179            # filename verbatim, in which case we always accept it
180            if (
181                _fnmatch(filepath, includes)
182                and (not _fnmatch(filepath, ignores))
183                and (match_extensions(filepath, extensions) or filepath in exact_matches)
184            ):
185                if not is_pytorch_extension:  # for pytorch extensions, consider all files
186                    if not is_pytorch_file(rel_filepath) and not is_caffe2_gpu_file(rel_filepath):
187                        continue
188                    if out_of_place_only and not is_out_of_place(rel_filepath):
189                        continue
190                yield filepath
191
192
193def preprocess_file_and_save_result(
194        output_directory: str,
195        filepath: str,
196        all_files: Iterable,
197        header_include_dirs: Iterable,
198        stats: Dict[str, List],
199        hip_clang_launch: bool,
200        is_pytorch_extension: bool,
201        clean_ctx: GeneratedFileCleaner,
202        show_progress: bool) -> None:
203    fin_path = os.path.abspath(os.path.join(output_directory, filepath))
204    hipify_result = HipifyResult(current_state=CurrentState.INITIALIZED, hipified_path=fin_path)
205    HIPIFY_FINAL_RESULT[fin_path] = hipify_result
206    result = preprocessor(output_directory, filepath, all_files, header_include_dirs, stats,
207                          hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress)
208
209    # Show what happened
210    if show_progress and "ignored" not in result.status:
211        print(
212            fin_path, "->",
213            result.hipified_path, result.status, flush=True)
214
215    HIPIFY_FINAL_RESULT[fin_path] = result
216
217
218def compute_stats(stats):
219    unsupported_calls = {cuda_call for (cuda_call, _filepath) in stats["unsupported_calls"]}
220
221    # Print the number of unsupported calls
222    print(f"Total number of unsupported CUDA function calls: {len(unsupported_calls):d}")
223
224    # Print the list of unsupported calls
225    print(", ".join(unsupported_calls))
226
227    # Print the number of kernel launches
228    print(f"\nTotal number of replaced kernel launches: {len(stats['kernel_launches']):d}")
229
230
231def add_dim3(kernel_string, cuda_kernel):
232    '''Wraps the grid and block dims (the first two <<<...>>> arguments) in dim3().'''
233    count = 0
234    closure = 0
235    kernel_string = kernel_string.replace("<<<", "").replace(">>>", "")
236    arg_locs: List[Dict[str, int]] = [{} for _ in range(2)]
237    arg_locs[count]['start'] = 0
238    for ind, c in enumerate(kernel_string):
239        if count > 1:
240            break
241        if c == "(":
242            closure += 1
243        elif c == ")":
244            closure -= 1
245        if (c == "," or ind == len(kernel_string) - 1) and closure == 0:
246            arg_locs[count]['end'] = ind + (c != ",")
247            count += 1
248            if count < 2:
249                arg_locs[count]['start'] = ind + 1
250
251    first_arg_raw = kernel_string[arg_locs[0]['start']:arg_locs[0]['end'] + 1]
252    second_arg_raw = kernel_string[arg_locs[1]['start']:arg_locs[1]['end']]
253
254    first_arg_clean = kernel_string[arg_locs[0]['start']:arg_locs[0]['end']].replace("\n", "").strip(" ")
255    second_arg_clean = kernel_string[arg_locs[1]['start']:arg_locs[1]['end']].replace("\n", "").strip(" ")
256
257    first_arg_dim3 = f"dim3({first_arg_clean})"
258    second_arg_dim3 = f"dim3({second_arg_clean})"
259
260    first_arg_raw_dim3 = first_arg_raw.replace(first_arg_clean, first_arg_dim3)
261    second_arg_raw_dim3 = second_arg_raw.replace(second_arg_clean, second_arg_dim3)
262    cuda_kernel = cuda_kernel.replace(first_arg_raw + second_arg_raw, first_arg_raw_dim3 + second_arg_raw_dim3)
263    return cuda_kernel
264
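# Illustrative input/output for add_dim3, using a hypothetical launch
# "kernel<<<grid, block, 0, stream>>>(...)" as it is split up by
# processKernelLaunches below:
#
#     add_dim3("<<<grid, block, 0, stream>>>",
#              "kernel<<<grid, block, 0, stream>>>(")
#     # -> "kernel<<<dim3(grid), dim3(block), 0, stream>>>("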
265
266RE_KERNEL_LAUNCH = re.compile(r'([ ]+)(detail?)::[ ]+\\\n[ ]+')
267
268
269def processKernelLaunches(string, stats):
270    """ Replace the CUDA style Kernel launches with the HIP style kernel launches."""
271    # Concat the namespace with the kernel names. (Find cleaner way of doing this later).
272    string = RE_KERNEL_LAUNCH.sub(lambda inp: f"{inp.group(1)}{inp.group(2)}::", string)
273
274    def grab_method_and_template(in_kernel):
275        # The positions for relevant kernel components.
276        pos = {
277            "kernel_launch": {"start": in_kernel["start"], "end": in_kernel["end"]},
278            "kernel_name": {"start": -1, "end": -1},
279            "template": {"start": -1, "end": -1}
280        }
281
282        # Count for balancing template
283        count = {"<>": 0}
284
285        # Status for whether we are parsing a certain item.
286        START = 0
287        AT_TEMPLATE = 1
288        AFTER_TEMPLATE = 2
289        AT_KERNEL_NAME = 3
290
291        status = START
292
293        # Parse the string character by character
294        for i in range(pos["kernel_launch"]["start"] - 1, -1, -1):
295            char = string[i]
296
297            # Handle Templating Arguments
298            if status in (START, AT_TEMPLATE):
299                if char == ">":
300                    if status == START:
301                        status = AT_TEMPLATE
302                        pos["template"]["end"] = i
303                    count["<>"] += 1
304
305                if char == "<":
306                    count["<>"] -= 1
307                    if count["<>"] == 0 and (status == AT_TEMPLATE):
308                        pos["template"]["start"] = i
309                        status = AFTER_TEMPLATE
310
311            # Handle Kernel Name
312            if status != AT_TEMPLATE:
313                if string[i].isalnum() or string[i] in {'(', ')', '_', ':', '#'}:
314                    if status != AT_KERNEL_NAME:
315                        status = AT_KERNEL_NAME
316                        pos["kernel_name"]["end"] = i
317
318                    # Case: Kernel name starts the string.
319                    if i == 0:
320                        pos["kernel_name"]["start"] = 0
321
322                        # Finished
323                        return [(pos["kernel_name"]), (pos["template"]), (pos["kernel_launch"])]
324
325                else:
326                    # Potential ending point if we're already traversing a kernel's name.
327                    if status == AT_KERNEL_NAME:
328                        pos["kernel_name"]["start"] = i
329
330                        # Finished
331                        return [(pos["kernel_name"]), (pos["template"]), (pos["kernel_launch"])]
332
333    def find_kernel_bounds(string):
334        """Finds the starting and ending points for all kernel launches in the string."""
335        kernel_end = 0
336        kernel_positions = []
337
338        # Continue until no more kernel launches are found.
339        while string.find("<<<", kernel_end) != -1:
340            # Get kernel starting position (starting from the previous ending point)
341            kernel_start = string.find("<<<", kernel_end)
342
343            # Get kernel ending position (adjust end point past the >>>)
344            kernel_end = string.find(">>>", kernel_start) + 3
345            if kernel_end <= 0:
346                raise InputError("no kernel end found")
347
348            # Add to list of traversed kernels
349            kernel_positions.append({"start": kernel_start, "end": kernel_end,
350                                     "group": string[kernel_start: kernel_end]})
351
352        return kernel_positions
353
354    # Mask comments and string literals in the code so that find_kernel_bounds does not
355    # wrongly capture kernel launches inside comments or string literals.
356    # Masked characters are replaced with "x" so that positions are preserved.
357    def mask_comments(string):
358        in_comment = ''
359        prev_c = ''
360        new_string = ''
361        for c in string:
362            if in_comment == '':
363                # Outside comments
364                if c == '/' and prev_c == '/':
365                    in_comment = '//'
366                elif c == '*' and prev_c == '/':
367                    in_comment = '/*'
368                elif c == '"' and prev_c != '\\' and prev_c != "'":
369                    in_comment = '"'
370            elif in_comment == '//':
371                # In // xxx
372                if c == '\r' or c == '\n':
373                    in_comment = ''
374            elif in_comment == '/*':
375                # In /* xxx */
376                if c == '/' and prev_c == '*':
377                    in_comment = ''
378            elif in_comment == '"':
379                # In ""
380                if c == '"' and prev_c != '\\':
381                    in_comment = ''
382            prev_c = c
383            if in_comment == '':
384                new_string += c
385            else:
386                new_string += 'x'
387        return new_string
388
389    # Grab positional ranges of all kernel launches
390    get_kernel_positions = list(find_kernel_bounds(mask_comments(string)))
391    output_string = string
392
393    # Replace each CUDA kernel with a HIP kernel.
394    for kernel in get_kernel_positions:
395        # Get kernel components
396        params = grab_method_and_template(kernel)
397
398        # Find parenthesis after kernel launch
399        parenthesis = string.find("(", kernel["end"])
400
401        # Extract cuda kernel
402        cuda_kernel = string[params[0]["start"]:parenthesis + 1]
403        kernel_string = string[kernel['start']:kernel['end']]
404        end_param_index = 0 if params[1]['end'] == -1 else 1
405        kernel_name_with_template = string[params[0]['start']:params[end_param_index]['end'] + 1]
406        cuda_kernel_dim3 = add_dim3(kernel_string, cuda_kernel)
407        # Keep number of kernel launch params consistent (grid dims, group dims, stream, dynamic shared size)
408        num_klp = len(extract_arguments(0, kernel["group"].replace("<<<", "(").replace(">>>", ")")))
409
410        hip_kernel = "hipLaunchKernelGGL(" + cuda_kernel_dim3[0:-1].replace(
411            ">>>", ", 0" * (4 - num_klp) + ">>>").replace("<<<", ", ").replace(
412            ">>>", ", ").replace(kernel_name_with_template, "(" + kernel_name_with_template + ")")
413
414        # Replace cuda kernel with hip kernel
415        output_string = output_string.replace(cuda_kernel, hip_kernel)
416
417        # Update the statistics
418        stats["kernel_launches"].append(hip_kernel)
419
420    return output_string
421
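# End-to-end illustration of processKernelLaunches on a hypothetical launch:
#
#     "kernel<<<grid, block, 0, stream>>>(a, b)"
#     # becomes
#     "hipLaunchKernelGGL((kernel), dim3(grid), dim3(block), 0, stream, a, b)"
#
# Launches with fewer than four <<<...>>> arguments are padded with ", 0" so the
# generated call always passes grid dims, block dims, shared-mem size and stream.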
422
423def find_closure_group(input_string, start, group):
424    """Generalization for finding a balancing closure group
425
426         if group = ["(", ")"], then finds the first balanced parentheses.
427         if group = ["{", "}"], then finds the first balanced bracket.
428
429    Given an input string, a starting position in the input string, and the group type,
430    find_closure_group returns the positions of group[0] and group[1] as a tuple.
431
432    Example:
433        >>> find_closure_group("(hi)", 0, ["(", ")"])
434        (0, 3)
435    """
436
437    inside_parenthesis = False
438    parens = 0
439    pos = start
440    p_start, p_end = -1, -1
441
442    while pos < len(input_string):
443        if input_string[pos] == group[0]:
444            if inside_parenthesis is False:
445                inside_parenthesis = True
446                parens = 1
447                p_start = pos
448            else:
449                parens += 1
450        elif input_string[pos] == group[1] and inside_parenthesis:
451            parens -= 1
452
453            if parens == 0:
454                p_end = pos
455                return p_start, p_end
456
457        pos += 1
458    return None, None
459
460
461def find_bracket_group(input_string, start):
462    """Finds the first balanced parantheses."""
463    return find_closure_group(input_string, start, group=["{", "}"])
464
465
466def find_parentheses_group(input_string, start):
467    """Finds the first balanced bracket."""
468    return find_closure_group(input_string, start, group=["(", ")"])
469
470
471RE_ASSERT = re.compile(r"\bassert[ ]*\(")
472
473
474def replace_math_functions(input_string):
475    """FIXME: Temporarily replace std:: invocations of math functions
476        with non-std:: versions to prevent linker errors.  NOTE: This
477        can lead to correctness issues when running tests, since the
478        correct version of the math function (exp/expf) might not get
479        called.  The plan is to remove this function once HIP supports
480        std:: math function calls inside device code.
481
482    """
483    output_string = input_string
484    for func in MATH_TRANSPILATIONS:
485        output_string = output_string.replace(fr'{func}(', f'{MATH_TRANSPILATIONS[func]}(')
486
487    return output_string
488
489
490RE_SYNCTHREADS = re.compile(r":?:?\b(__syncthreads)\b(\w*\()")
491
492
493def hip_header_magic(input_string):
494    """If the file makes kernel builtin calls and does not include the cuda_runtime.h header,
495    then automatically add an #include to match the "magic" includes provided by NVCC.
496    TODO:
497        Update logic to ignore cases where the cuda_runtime.h is included by another file.
498    """
499
500    # Copy the input.
501    output_string = input_string
502
503    # Check if one of the following headers is already included.
504    headers = ["hip/hip_runtime.h", "hip/hip_runtime_api.h"]
505    if any(re.search(fr'#include ("{ext}"|<{ext}>)', output_string) for ext in headers):
506        return output_string
507
508    # Rough logic to detect if we're inside device code
509    hasDeviceLogic: int
510    hasDeviceLogic = "hipLaunchKernelGGL" in output_string
511    hasDeviceLogic += "__global__" in output_string
512    hasDeviceLogic += "__shared__" in output_string
513    hasDeviceLogic += RE_SYNCTHREADS.search(output_string) is not None
514
515    # If device logic found, provide the necessary header.
516    if hasDeviceLogic:
517        output_string = '#include "hip/hip_runtime.h"\n' + input_string
518
519    return output_string
520
521
522RE_EXTERN_SHARED = re.compile(r"extern\s+([\w\(\)]+)?\s*__shared__\s+([\w:<>\s]+)\s+(\w+)\s*\[\s*\]\s*;")
523
524
525def replace_extern_shared(input_string):
526    """Match extern __shared__ type foo[]; syntax and use HIP_DYNAMIC_SHARED() MACRO instead.
527       https://github.com/ROCm-Developer-Tools/HIP/blob/master/docs/markdown/hip_kernel_language.md#__shared__
528    Example:
529        "extern __shared__ char smemChar[];" => "HIP_DYNAMIC_SHARED( char, smemChar)"
530        "extern __shared__ unsigned char smem[];" => "HIP_DYNAMIC_SHARED( unsigned char, my_smem)"
531    """
532    output_string = input_string
533    output_string = RE_EXTERN_SHARED.sub(
534        lambda inp: f"HIP_DYNAMIC_SHARED({inp.group(1) or ''} {inp.group(2)}, {inp.group(3)})", output_string)
535
536    return output_string
537
538
539def get_hip_file_path(rel_filepath, is_pytorch_extension=False):
540    """
541    Returns the new name of the hipified file
542    """
543    # At the moment, some PyTorch source files are HIPified in place.  The predicate
544    # is_out_of_place tells us if this is the case or not.
545    assert not os.path.isabs(rel_filepath)
546    if not is_pytorch_extension and not is_out_of_place(rel_filepath):
547        return rel_filepath
548
549    dirpath, filename = os.path.split(rel_filepath)
550    root, ext = os.path.splitext(filename)
551
552    # Here's the plan:
553    #
554    # In general, we need to disambiguate the HIPified filename so that
555    # it gets a different name from the original filename, so
556    # that we don't overwrite the original file
557    #
558    # There's a lot of different naming conventions across PyTorch
559    # and Caffe2, but the general recipe is to convert occurrences
560    # of cuda/gpu to hip, and add hip if there are no occurrences
561    # of cuda/gpu anywhere.
562    #
563    # Concretely, we do the following:
564    #
565    #   - If there is a directory component named "cuda", replace
566    #     it with "hip", AND
567    #
568    #   - If the file name contains "CUDA", replace it with "HIP", AND
569    #
570    #   - ALWAYS replace '.cu' with '.hip', because those files
571    #     contain CUDA kernels that need to be hipified and processed with
572    #     the HIP compiler
573    #
574    #   - If we are not hipifying a PyTorch extension, and the parent
575    #     directory name did not change as a result of the above
576    #     transformations, insert "hip" in the file path
577    #     as the direct parent folder of the file
578    #
579    #   - If we are hipifying a PyTorch extension, and the parent directory
580    #     name as well as the filename (incl. extension) did not change as
581    #     a result of the above transformations, insert "_hip" in the filename
582    #
583    # This isn't set in stone; we might adjust this to support other
584    # naming conventions.
585
586    if ext == '.cu':
587        ext = '.hip'
588
589    orig_filename = filename
590    orig_dirpath = dirpath
591
592    dirpath = dirpath.replace('cuda', 'hip')
593    dirpath = dirpath.replace('CUDA', 'HIP')
594    dirpath = dirpath.replace('THC', 'THH')
595
596    root = root.replace('cuda', 'hip')
597    root = root.replace('CUDA', 'HIP')
598    # Special case to handle caffe2/core/THCCachingAllocator
599    if dirpath != "caffe2/core":
600        root = root.replace('THC', 'THH')
601
602    if not is_pytorch_extension and dirpath == orig_dirpath:
603        dirpath = os.path.join(dirpath, 'hip')
604
605    if is_pytorch_extension and dirpath == orig_dirpath and (root + ext) == orig_filename:
606        root = root + "_hip"
607
608    return os.path.join(dirpath, root + ext)
609
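# Illustrative examples of the renaming rules above, with is_pytorch_extension=False
# (hypothetical paths):
#
#     aten/src/ATen/native/cuda/Foo.cu -> aten/src/ATen/native/hip/Foo.hip
#     aten/src/ATen/cuda/CUDABlas.cpp  -> aten/src/ATen/hip/HIPBlas.cpp
#     caffe2/operators/relu_op.cu      -> caffe2/operators/hip/relu_op.hip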
610
611def is_out_of_place(rel_filepath):
612    assert not os.path.isabs(rel_filepath)
613    if rel_filepath.startswith("torch/"):
614        return False
615    if rel_filepath.startswith("third_party/nvfuser/"):
616        return False
617    if rel_filepath.startswith("tools/autograd/templates/"):
618        return False
619    return True
620
621
622# Keep this synchronized with includes/ignores in build_amd.py
623def is_pytorch_file(rel_filepath):
624    assert not os.path.isabs(rel_filepath)
625    if rel_filepath.startswith("aten/"):
626        if rel_filepath.startswith("aten/src/ATen/core/"):
627            return False
628        return True
629    if rel_filepath.startswith("torch/"):
630        return True
631    if rel_filepath.startswith("third_party/nvfuser/"):
632        return True
633    if rel_filepath.startswith("tools/autograd/templates/"):
634        return True
635    return False
636
637
638def is_cusparse_file(rel_filepath):
639    if is_pytorch_file(rel_filepath):
640        return "sparse" in rel_filepath.lower()
641    return False
642
643
644def is_special_file(rel_filepath):
645    if is_pytorch_file(rel_filepath):
646        if "sparse" in rel_filepath.lower():
647            return True
648        elif "linalg" in rel_filepath.lower():
649            if "batchlinearalgebralibblas" in rel_filepath.lower():
650                return False  # don't use "special" mappings for this specific linalg cublas file
651            return True
652    return False
653
654def is_caffe2_gpu_file(rel_filepath):
655    assert not os.path.isabs(rel_filepath)
656    if rel_filepath.startswith("c10/cuda"):
657        return True
658    filename = os.path.basename(rel_filepath)
659    _, ext = os.path.splitext(filename)
660    return ('gpu' in filename or ext in ['.cu', '.cuh']) and ('cudnn' not in filename)
661
662class TrieNode:
663    """A Trie node whose children are represented as a directory of char: TrieNode.
664       A special char '' represents end of word
665    """
666
667    def __init__(self):
668        self.children = {}
669
670class Trie:
671    """Creates a Trie out of a list of words. The trie can be exported to a Regex pattern.
672    The corresponding Regex should match much faster than a simple Regex union."""
673
674    def __init__(self):
675        """Initialize the trie with an empty root node."""
676        self.root = TrieNode()
677
678    def add(self, word):
679        """Add a word to the Trie. """
680        node = self.root
681
682        for char in word:
683            node.children.setdefault(char, TrieNode())
684            node = node.children[char]
685        node.children[''] = True    # Mark the end of the word
686
687    def dump(self):
688        """Return the root node of Trie. """
689        return self.root
690
691    def quote(self, char):
692        """ Escape a char for regex. """
693        return re.escape(char)
694
695    def search(self, word):
696        """Search whether word is present in the Trie.
697        Returns True if it is, else False."""
698        node = self.root
699        for char in word:
700            if char in node.children:
701                node = node.children[char]
702            else:
703                return False
704
705        # make sure to check the end-of-word marker present
706        return '' in node.children
707
708    def _pattern(self, root):
709        """Convert a Trie into a regular expression pattern"""
710        node = root
711
712        if "" in node.children and len(node.children.keys()) == 1:
713            return None
714
715        alt = []    # store alternative patterns
716        cc = []     # store char to char classes
717        q = 0       # for node representing the end of word
718        for char in sorted(node.children.keys()):
719            if isinstance(node.children[char], TrieNode):
720                try:
721                    recurse = self._pattern(node.children[char])
722                    alt.append(self.quote(char) + recurse)
723                except Exception:
724                    cc.append(self.quote(char))
725            else:
726                q = 1
727        cconly = not len(alt) > 0
728
729        if len(cc) > 0:
730            if len(cc) == 1:
731                alt.append(cc[0])
732            else:
733                alt.append('[' + ''.join(cc) + ']')
734
735        if len(alt) == 1:
736            result = alt[0]
737        else:
738            result = "(?:" + "|".join(alt) + ")"
739
740        if q:
741            if cconly:
742                result += "?"
743            else:
744                result = f"(?:{result})?"
745        return result
746
747    def pattern(self):
748        """Export the Trie to a regex pattern."""
749        return self._pattern(self.root)
750
751    def export_to_regex(self):
752        """Export the Trie to a regex pattern."""
753        return self._pattern(self.root)
754
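# A minimal sketch of how the Trie is used below, with hypothetical words:
#
#     t = Trie()
#     for word in ("cudaMalloc", "cudaMemcpy"):
#         t.add(word)
#     pattern = t.export_to_regex()   # "cudaM(?:alloc|emcpy)"
#     assert t.search("cudaMalloc") and not t.search("cuda")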
755CAFFE2_TRIE = Trie()
756CAFFE2_MAP = {}
757PYTORCH_TRIE = Trie()
758PYTORCH_MAP: Dict[str, object] = {}
759
760# In PyTorch, we map cuBLAS->rocBLAS and cuSPARSE->hipSPARSE. Note the prefix, roc versus hip.
761# The 'hip' APIs offer a more direct CUDA-friendly mapping, but calling rocBLAS directly has better performance.
762# Unfortunately, the roc* types and hip* types differ, i.e., rocblas_float_complex versus hipComplex.
763# In the case of SPARSE, we must use the hip types for complex instead of the roc types,
764# but the pytorch mappings assume roc. Therefore, we create a new SPARSE mapping that has a higher priority.
765# Its mappings will trigger first, and only when a miss occurs will the lower-priority pytorch mapping take place.
766# When a file contains "sparse" in the filename, a mapping marked with API_SPARSE is preferred over other choices.
767# Similarly, "linalg" files require rocBLAS -> hipSOLVER so they also need special handling.
768PYTORCH_SPECIAL_MAP = {}
769
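# Hypothetical illustration of the precedence described above (these are not
# real mapping entries):
#
#     PYTORCH_MAP["cusparseFoo"]         == "rocsparse_foo"
#     PYTORCH_SPECIAL_MAP["cusparseFoo"] == "hipsparseFoo"
#
# When hipifying PyTorch itself, a file matching is_special_file() consults
# PYTORCH_SPECIAL_MAP first and falls back to PYTORCH_MAP on a miss
# (see pt_special_repl in preprocessor below).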
770for mapping in CUDA_TO_HIP_MAPPINGS:
771    assert isinstance(mapping, Mapping)
772    for src, value in mapping.items():
773        dst = value[0]
774        meta_data = value[1:]
775        if constants.API_CAFFE2 not in meta_data:
776            PYTORCH_TRIE.add(src)
777            # if src is already in PYTORCH_MAP and dst belongs to API_SPECIAL
778            # do not overwrite PYTORCH_MAP, store dst separately
779            if constants.API_SPECIAL in meta_data and PYTORCH_MAP.get(src, ""):
780                PYTORCH_SPECIAL_MAP[src] = dst
781            else:
782                PYTORCH_MAP[src] = dst
783        if constants.API_PYTORCH not in meta_data and constants.API_SPECIAL not in meta_data:
784            CAFFE2_TRIE.add(src)
785            CAFFE2_MAP[src] = dst
786RE_CAFFE2_PREPROCESSOR = re.compile(CAFFE2_TRIE.export_to_regex())
787RE_PYTORCH_PREPROCESSOR = re.compile(fr'(?<=\W)({PYTORCH_TRIE.export_to_regex()})(?=\W)')
788
789RE_QUOTE_HEADER = re.compile(r'#include "([^"]+)"')
790RE_ANGLE_HEADER = re.compile(r'#include <([^>]+)>')
791RE_THC_GENERIC_FILE = re.compile(r'#define THC_GENERIC_FILE "([^"]+)"')
792RE_CU_SUFFIX = re.compile(r'\.cu\b')  # be careful not to pick up .cuh
793
794"""
795Returns a HipifyResult object with the following details:
796    "hipified_path" : absolute path of hipified source file
797    "status"        : "ok"      if hipified file was written out
798                      "skipped" if an identical hipified file already existed or hipified file couldn't be written out
799                      "ignored" if the source file was a hipified file itself or not meant to be hipified
800    "current_state" : CurrentState.INITIALIZED if source file is first ready to be hipified
801                      CurrentState.DONE if source file is done with hipification process
802"""
803
804
805def preprocessor(
806        output_directory: str,
807        filepath: str,
808        all_files: Iterable,
809        header_include_dirs: Iterable,
810        stats: Dict[str, List],
811        hip_clang_launch: bool,
812        is_pytorch_extension: bool,
813        clean_ctx: GeneratedFileCleaner,
814        show_progress: bool) -> HipifyResult:
815    """ Executes the CUDA -> HIP conversion on the specified file. """
816    fin_path = os.path.abspath(os.path.join(output_directory, filepath))
817    hipify_result = HIPIFY_FINAL_RESULT[fin_path]
818    if filepath not in all_files:
819        hipify_result.hipified_path = None
820        hipify_result.status = "[ignored, not to be hipified]"
821        hipify_result.current_state = CurrentState.DONE
822        return hipify_result
823
824    rel_filepath = os.path.relpath(filepath, output_directory)
825
826    with open(fin_path, encoding='utf-8') as fin:
827        if fin.readline() == HIPIFY_C_BREADCRUMB:
828            hipify_result.hipified_path = None
829            hipify_result.status = "[ignored, input is hipified output]"
830            hipify_result.current_state = CurrentState.DONE
831            return hipify_result
832        fin.seek(0)
833        output_source = fin.read()
834
835    orig_output_source = output_source
836
837    # get_hip_file_path needs a relative path to work correctly
838    fout_path = os.path.abspath(os.path.join(output_directory, get_hip_file_path(rel_filepath, is_pytorch_extension)))
839    if not os.path.exists(os.path.dirname(fout_path)):
840        clean_ctx.makedirs(os.path.dirname(fout_path))
841
842    # unsupported_calls statistics reporting is broken atm
843    def pt_repl(m):
844        return PYTORCH_MAP[m.group(0)]
845
846    def pt_special_repl(m):
847        # checks SPECIAL map first, and if a miss occurs, falls back to pytorch mappings
848        return PYTORCH_SPECIAL_MAP.get(m.group(0), pt_repl(m))
849
850
851    if is_pytorch_extension:
852        output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source)
853    else:
854        if is_special_file(rel_filepath):
855            output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_special_repl, output_source)
856        elif is_pytorch_file(rel_filepath):
857            output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source)
858        else:
859            def c2_repl(m):
860                return CAFFE2_MAP[m.group(0)]
861            output_source = RE_CAFFE2_PREPROCESSOR.sub(c2_repl, output_source)
862
863    # Header rewrites
864    def mk_repl(templ, include_current_dir=True):
865        def repl(m):
866            f = m.group(1)
867            dirpath, filename = os.path.split(f)
868            if (
869                f.startswith(("ATen/cuda",
870                              "ATen/native/cuda",
871                              "ATen/native/nested/cuda",
872                              "ATen/native/quantized/cuda",
873                              "ATen/native/sparse/cuda",
874                              "ATen/native/transformers/cuda",
875                              "THC/")) or
876                (f.startswith("THC") and not f.startswith("THCP"))
877            ):
878                return templ.format(get_hip_file_path(m.group(1), is_pytorch_extension))
879            # if filename is one of the files being hipified for this extension
880            if (is_pytorch_extension and any(s.endswith(filename) for s in all_files)):
881                header_dir = None
882                header_filepath = None
883                # If include_current_dir is True, look first in the same dir as the including source file
884                if include_current_dir:
885                    header_dir_to_check = os.path.dirname(fin_path)
886                    header_path_to_check = os.path.abspath(os.path.join(header_dir_to_check, f))
887                    if os.path.exists(header_path_to_check):
888                        header_dir = header_dir_to_check
889                        header_filepath = header_path_to_check
890                # If not found, look in include dirs one by one and first match wins
891                if header_filepath is None:
892                    for header_include_dir in header_include_dirs:
893                        header_dir_to_check = os.path.join(output_directory, header_include_dir)
894                        header_path_to_check = os.path.abspath(os.path.join(header_dir_to_check, f))
895                        if os.path.exists(header_path_to_check):
896                            header_dir = header_dir_to_check
897                            header_filepath = header_path_to_check
898                # If header file not found, keep as is
899                if header_filepath is None:
900                    return m.group(0)
901                # Hipify header file first if needed
902                if header_filepath not in HIPIFY_FINAL_RESULT:
903                    preprocess_file_and_save_result(output_directory,
904                                                    header_filepath,
905                                                    all_files, header_include_dirs, stats, hip_clang_launch,
906                                                    is_pytorch_extension, clean_ctx, show_progress)
907                elif header_filepath in HIPIFY_FINAL_RESULT:
908                    header_result = HIPIFY_FINAL_RESULT[header_filepath]
909                    if header_result.current_state == CurrentState.INITIALIZED:
910                        # get_hip_file_path needs a relative path to work correctly
911                        header_rel_path = os.path.relpath(header_filepath, output_directory)
912                        header_fout_path = os.path.abspath(os.path.join(output_directory,
913                                                                        get_hip_file_path(header_rel_path, is_pytorch_extension)))
914                        header_result.hipified_path = header_fout_path
915                        HIPIFY_FINAL_RESULT[header_filepath] = header_result
916                        return templ.format(os.path.relpath(header_fout_path if header_fout_path is not None
917                                                            else header_filepath, header_dir))
918                hipified_header_filepath = HIPIFY_FINAL_RESULT[header_filepath].hipified_path
919                return templ.format(os.path.relpath(hipified_header_filepath if hipified_header_filepath is not None
920                                                    else header_filepath, header_dir))
921
922            return m.group(0)
923        return repl
924    output_source = RE_QUOTE_HEADER.sub(mk_repl('#include "{0}"', True), output_source)
925    output_source = RE_ANGLE_HEADER.sub(mk_repl('#include <{0}>', False), output_source)
926    output_source = RE_THC_GENERIC_FILE.sub(mk_repl('#define THC_GENERIC_FILE "{0}"'), output_source)
927
928    # CMakeLists.txt rewrites
929    if filepath.endswith('CMakeLists.txt'):
930        output_source = output_source.replace('CUDA', 'HIP')
931        output_source = output_source.replace('THC', 'THH')
932        output_source = RE_CU_SUFFIX.sub('.hip', output_source)
933
934    # Perform Kernel Launch Replacements
935    if not hip_clang_launch:
936        output_source = processKernelLaunches(output_source, stats)
937
938    # Replace std:: with non-std:: versions
939    if (filepath.endswith((".cu", ".cuh"))) and "PowKernel" not in filepath:
940        output_source = replace_math_functions(output_source)
941
942    # Include header if device code is contained.
943    output_source = hip_header_magic(output_source)
944
945    # Replace the extern __shared__
946    # NOTE: No longer needed after transition from hcc to hipclang.
947    # output_source = replace_extern_shared(output_source)
948
949    # Don't write out identical hipified files for extensions if dirpath has not changed
950    if (
951        is_pytorch_extension
952        and orig_output_source == output_source
953        and os.path.dirname(fin_path) == os.path.dirname(fout_path)
954    ):
955        hipify_result.hipified_path = fin_path
956        hipify_result.status = "[skipped, no changes]"
957        hipify_result.current_state = CurrentState.DONE
958        return hipify_result
959
960    # Add hipify breadcrumb for C-style files to avoid re-hipification
961    if fin_path != fout_path and match_extensions(fin_path, (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".hpp")):
962        output_source = HIPIFY_C_BREADCRUMB + output_source
963
964    do_write = True
965    if os.path.exists(fout_path):
966        with open(fout_path, encoding='utf-8') as fout_old:
967            do_write = fout_old.read() != output_source
968    if do_write:
969        try:
970            with clean_ctx.open(fout_path, 'w', encoding='utf-8') as fout:
971                fout.write(output_source)
972            hipify_result.hipified_path = fout_path
973            hipify_result.status = "[ok]"
974            hipify_result.current_state = CurrentState.DONE
975            return hipify_result
976        except PermissionError as e:
977            print(f'{bcolors.WARNING}Failed to save {fout_path} with "{e.strerror}", leaving {fin_path} unchanged.{bcolors.ENDC}',
978                  file=sys.stderr)
979            hipify_result.hipified_path = fin_path
980            hipify_result.status = "[skipped, no permissions]"
981            hipify_result.current_state = CurrentState.DONE
982            return hipify_result
983    else:
984        hipify_result.hipified_path = fout_path
985        hipify_result.status = "[skipped, already hipified]"
986        hipify_result.current_state = CurrentState.DONE
987        return hipify_result
988
989def file_specific_replacement(filepath, search_string, replace_string, strict=False):
990    with openf(filepath, "r+") as f:
991        contents = f.read()
992        if strict:
993            contents = re.sub(fr'\b({re.escape(search_string)})\b', lambda x: replace_string, contents)
994        else:
995            contents = contents.replace(search_string, replace_string)
996        f.seek(0)
997        f.write(contents)
998        f.truncate()
999
1000
1001def file_add_header(filepath, header):
1002    with openf(filepath, "r+") as f:
1003        contents = f.read()
1004        if header[0] != "<" and header[-1] != ">":
1005            header = f'"{header}"'
1006        contents = (f'#include {header} \n') + contents
1007        f.seek(0)
1008        f.write(contents)
1009        f.truncate()
1010
1011
1012def fix_static_global_kernels(in_txt):
1013    """Static global kernels in HIP results in a compilation error."""
1014    in_txt = in_txt.replace(" __global__ static", "__global__")
1015    return in_txt
1016
1017
1018RE_INCLUDE = re.compile(r"#include .*\n")
1019
1020
1021def extract_arguments(start, string):
1022    """ Return the list of arguments in the upcoming function parameter closure.
1023        Example:
1024        string (input): '(blocks, threads, 0, THCState_getCurrentStream(state))'
1025        arguments (output):
1026            '[{'start': 1, 'end': 7},
1027            {'start': 8, 'end': 16},
1028            {'start': 17, 'end': 19},
1029            {'start': 20, 'end': 53}]'
1030    """
1031
1032    arguments = []
1033    closures = {
1034        "<": 0,
1035        "(": 0
1036    }
1037    current_position = start
1038    argument_start_pos = current_position + 1
1039
1040    # Search for final parenthesis
1041    while current_position < len(string):
1042        if string[current_position] == "(":
1043            closures["("] += 1
1044        elif string[current_position] == ")":
1045            closures["("] -= 1
1046        elif string[current_position] == "<":
1047            closures["<"] += 1
1048        elif string[current_position] == ">" and string[current_position - 1] != "-" and closures["<"] > 0:
1049            closures["<"] -= 1
1050
1051        # Finished all arguments
1052        if closures["("] == 0 and closures["<"] == 0:
1053            # Add final argument
1054            arguments.append({"start": argument_start_pos, "end": current_position})
1055            break
1056
1057        # Finished current argument
1058        if closures["("] == 1 and closures["<"] == 0 and string[current_position] == ",":
1059            arguments.append({"start": argument_start_pos, "end": current_position})
1060            argument_start_pos = current_position + 1
1061
1062        current_position += 1
1063
1064    return arguments
1065
1066
1067def str2bool(v):
1068    """ArgumentParser doesn't support type=bool. Thus, this helper method will convert
1069    from possible string types to True / False."""
1070    if v.lower() in ('yes', 'true', 't', 'y', '1'):
1071        return True
1072    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
1073        return False
1074    else:
1075        raise argparse.ArgumentTypeError('Boolean value expected.')
1076
1077
1078def hipify(
1079    project_directory: str,
1080    show_detailed: bool = False,
1081    extensions: Iterable = (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".in", ".hpp"),
1082    header_extensions: Iterable = (".cuh", ".h", ".hpp"),
1083    output_directory: str = "",
1084    header_include_dirs: Iterable = (),
1085    includes: Iterable = ('*',),
1086    extra_files: Iterable = (),
1087    out_of_place_only: bool = False,
1088    ignores: Iterable = (),
1089    show_progress: bool = True,
1090    hip_clang_launch: bool = False,
1091    is_pytorch_extension: bool = False,
1092    hipify_extra_files_only: bool = False,
1093    clean_ctx: Optional[GeneratedFileCleaner] = None
1094) -> HipifyFinalResult:
1095    if project_directory == "":
1096        project_directory = os.getcwd()
1097
1098    # Verify the project directory exists.
1099    if not os.path.exists(project_directory):
1100        print("The project folder specified does not exist.")
1101        sys.exit(1)
1102
1103    # If no output directory, provide a default one.
1104    if not output_directory:
1105        project_directory = project_directory.rstrip("/")
1106        output_directory = project_directory + "_amd"
1107
1108    if project_directory != output_directory:
1109        includes = [include.replace(project_directory, output_directory) for include in includes]
1110        ignores = [ignore.replace(project_directory, output_directory) for ignore in ignores]
1111
1112    # Copy from project directory to output directory if not done already.
1113    if not os.path.exists(output_directory):
1114        shutil.copytree(project_directory, output_directory)
1115
1116    all_files = list(matched_files_iter(output_directory, includes=includes,
1117                                        ignores=ignores, extensions=extensions,
1118                                        out_of_place_only=out_of_place_only,
1119                                        is_pytorch_extension=is_pytorch_extension))
1120    all_files_set = set(all_files)
1121    for f in extra_files:
1122        if not os.path.isabs(f):
1123            f = os.path.join(output_directory, f)
1124        if f not in all_files_set:
1125            all_files.append(f)
1126
1127    # List all files in header_include_dirs to ensure they are hipified
1128    from pathlib import Path
1129    for header_include_dir in header_include_dirs:
1130        if os.path.isabs(header_include_dir):
1131            header_include_dir_path = Path(header_include_dir)
1132        else:
1133            header_include_dir_path = Path(os.path.join(output_directory, header_include_dir))
1134        for path in header_include_dir_path.rglob('*'):
1135            if (
1136                path.is_file()
1137                and _fnmatch(str(path), includes)
1138                and (not _fnmatch(str(path), ignores))
1139                and match_extensions(path.name, header_extensions)
1140            ):
1141                all_files.append(str(path))
1142
1143    if clean_ctx is None:
1144        clean_ctx = GeneratedFileCleaner(keep_intermediates=True)
1145
1146    # Preprocessing statistics.
1147    stats: Dict[str, List] = {"unsupported_calls": [], "kernel_launches": []}
1148
1149    for filepath in (all_files if not hipify_extra_files_only else extra_files):
1150        preprocess_file_and_save_result(output_directory, filepath, all_files, header_include_dirs,
1151                                        stats, hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress)
1152
1153    print(bcolors.OKGREEN + "Successfully preprocessed all matching files." + bcolors.ENDC, file=sys.stderr)
1154
1155    # Show detailed summary
1156    if show_detailed:
1157        compute_stats(stats)
1158
1159    return HIPIFY_FINAL_RESULT
1160