xref: /aosp_15_r20/external/pytorch/torch/_inductor/codegen/aoti_hipify_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import re

import torch
from torch.utils.hipify.hipify_python import PYTORCH_MAP, PYTORCH_TRIE
6
7
# It is not a good idea to directly apply hipify_torch to codegen, which will be vulnerable to cases like:
#   "...
#    from ..codecache import CudaKernelParamCache
#   ..."
# In such cases, we do not need to hipify_torch the original class/file name in codegen/codecache
13
14
15def maybe_hipify_code_wrapper(source_codes: str, force_hipify: bool = False) -> str:
16    if torch.version.hip is None and not force_hipify:
17        return source_codes
18
19    def c2_repl(m):
20        return PYTORCH_MAP[m.group(0)]
21
22    # We need to redefine RE_PYTORCH_PREPROCESSOR here since in hipify_torch,
23    # it will apply positive lookbehind (?<=\W) to the pattern to avoid matching
24    # keyword at the beginning of code line. However, this can happen in codegen,
25    # which will cause the pattern to not match.
26
27    # Note that lookahead (?=\W) is still needed to keep hipification idomponent, for example
28    # we need to skip replacing "getStreamFromExternal" in "getStreamFromExternalMasqueradingAsCUDA"
29    RE_PYTORCH_PREPROCESSOR = re.compile(rf"({PYTORCH_TRIE.export_to_regex()})(?=\W)")
30
31    source_codes = RE_PYTORCH_PREPROCESSOR.sub(c2_repl, source_codes)
32    return source_codes
33