1# mypy: allow-untyped-defs 2import re 3 4import torch 5from torch.utils.hipify.hipify_python import PYTORCH_MAP, PYTORCH_TRIE 6 7 8# It is not a good idea to directly apply hipify_torch to codegen, which will be vulnerable to cases like: 9# "... 10# from ..codecache import CudaKernelParamCache 11# ..." 12# In such cases, we do not need to hipify_torch the orignial class/file name in codegen/codecache 13 14 15def maybe_hipify_code_wrapper(source_codes: str, force_hipify: bool = False) -> str: 16 if torch.version.hip is None and not force_hipify: 17 return source_codes 18 19 def c2_repl(m): 20 return PYTORCH_MAP[m.group(0)] 21 22 # We need to redefine RE_PYTORCH_PREPROCESSOR here since in hipify_torch, 23 # it will apply positive lookbehind (?<=\W) to the pattern to avoid matching 24 # keyword at the beginning of code line. However, this can happen in codegen, 25 # which will cause the pattern to not match. 26 27 # Note that lookahead (?=\W) is still needed to keep hipification idomponent, for example 28 # we need to skip replacing "getStreamFromExternal" in "getStreamFromExternalMasqueradingAsCUDA" 29 RE_PYTORCH_PREPROCESSOR = re.compile(rf"({PYTORCH_TRIE.export_to_regex()})(?=\W)") 30 31 source_codes = RE_PYTORCH_PREPROCESSOR.sub(c2_repl, source_codes) 32 return source_codes 33