# xref: /aosp_15_r20/external/pytorch/cmake/public/LoadHIP.cmake (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# HIP is considered unavailable until find_package(HIP) succeeds later on.
set(PYTORCH_FOUND_HIP FALSE)

# ROCm installation prefix: taken from $ROCM_PATH when that environment
# variable is defined, otherwise the conventional /opt/rocm.
if(DEFINED ENV{ROCM_PATH})
  set(ROCM_PATH $ENV{ROCM_PATH})
else()
  set(ROCM_PATH /opt/rocm)
endif()

# ROCm header directory: taken from $ROCM_INCLUDE_DIRS when defined,
# otherwise <ROCM_PATH>/include.
if(DEFINED ENV{ROCM_INCLUDE_DIRS})
  set(ROCM_INCLUDE_DIRS $ENV{ROCM_INCLUDE_DIRS})
else()
  set(ROCM_INCLUDE_DIRS ${ROCM_PATH}/include)
endif()

# No ROCm installation on disk: stop processing this file here, leaving
# PYTORCH_FOUND_HIP at FALSE.
if(NOT EXISTS ${ROCM_PATH})
  return()
endif()

# MAGMA_HOME: use the environment value when one is defined; otherwise
# default to <ROCM_PATH>/magma and export that default into the environment
# as well, so the CMake variable and $ENV{MAGMA_HOME} agree afterwards.
if(DEFINED ENV{MAGMA_HOME})
  set(MAGMA_HOME $ENV{MAGMA_HOME})
else()
  set(MAGMA_HOME ${ROCM_PATH}/magma)
  set(ENV{MAGMA_HOME} ${MAGMA_HOME})
endif()

# Determine which GPU architectures to build for (torch_hip_get_arch_list is
# defined outside this file). The result is expanded inside quotes so the
# check also fires when PYTORCH_ROCM_ARCH ends up undefined: with an unquoted
# name, if() compares the literal string "PYTORCH_ROCM_ARCH" against "" for
# an undefined variable and the error would never be raised.
torch_hip_get_arch_list(PYTORCH_ROCM_ARCH)
if("${PYTORCH_ROCM_ARCH}" STREQUAL "")
  message(FATAL_ERROR "No GPU arch specified for ROCm build. Please use PYTORCH_ROCM_ARCH environment variable to specify GPU archs to build for.")
endif()
message("Building PyTorch for GPU arch: ${PYTORCH_ROCM_ARCH}")

# Add HIP to the CMake module path, ahead of any entries already present so
# ROCm's HIP modules are searched first.
set(CMAKE_MODULE_PATH ${ROCM_PATH}/lib/cmake/hip ${CMAKE_MODULE_PATH})

# find_package_and_print_version(<package> [<find_package arguments>...])
#
# Calls find_package(<package> ...), forwarding every extra argument
# (version, REQUIRED, ...) unchanged, then prints
# "<package> VERSION: <value of <package>_VERSION>" (empty when the package
# did not set that variable). This is a macro rather than a function so the
# variables find_package() defines, such as <package>_FOUND, stay visible in
# the caller's scope.
macro(find_package_and_print_version pkg_name)
  find_package("${pkg_name}" ${ARGN})
  message("${pkg_name} VERSION: ${${pkg_name}_VERSION}")
endmacro()

# Find the HIP package. All ROCm-specific configuration below runs only when
# it is found; PYTORCH_FOUND_HIP is set TRUE in that case.
find_package_and_print_version(HIP 1.0)

if(HIP_FOUND)
  set(PYTORCH_FOUND_HIP TRUE)
  set(FOUND_ROCM_VERSION_H FALSE)

  set(PROJECT_RANDOM_BINARY_DIR "${PROJECT_BINARY_DIR}")
  set(file "${PROJECT_BINARY_DIR}/detect_rocm_version.cc")

  # Find ROCm version for checks.
  # ROCm 5.0 and later have a header API for version management. The header
  # is looked for directly in ROCM_INCLUDE_DIRS, then under rocm-core/; the
  # probe source starts with an #include of whichever one exists.
  if(EXISTS ${ROCM_INCLUDE_DIRS}/rocm_version.h)
    set(FOUND_ROCM_VERSION_H TRUE)
    file(WRITE ${file} ""
      "#include <rocm_version.h>\n"
      )
  elseif(EXISTS ${ROCM_INCLUDE_DIRS}/rocm-core/rocm_version.h)
    set(FOUND_ROCM_VERSION_H TRUE)
    file(WRITE ${file} ""
      "#include <rocm-core/rocm_version.h>\n"
      )
  else()
    message("********************* rocm_version.h couldnt be found ******************\n")
  endif()

  if(FOUND_ROCM_VERSION_H)
    # Finish the probe program: it prints "<major>.<minor>.<patch>" from the
    # header's macros, defaulting the patch level to 0 when it is not defined.
    file(APPEND ${file} ""
      "#include <cstdio>\n"
      "#ifndef ROCM_VERSION_PATCH\n"
      "#define ROCM_VERSION_PATCH 0\n"
      "#endif\n"
      "#define STRINGIFYHELPER(x) #x\n"
      "#define STRINGIFY(x) STRINGIFYHELPER(x)\n"
      "int main() {\n"
      "  printf(\"%d.%d.%s\", ROCM_VERSION_MAJOR, ROCM_VERSION_MINOR, STRINGIFY(ROCM_VERSION_PATCH));\n"
      "  return 0;\n"
      "}\n"
      )

    try_run(run_result compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file}
      CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
      RUN_OUTPUT_VARIABLE rocm_version_from_header
      COMPILE_OUTPUT_VARIABLE output_var
      )
    # We expect the compile to be successful if the include directory exists.
    # The captured output is quoted so any ';' in it is printed verbatim
    # rather than being swallowed by list splitting.
    if(NOT compile_result)
      message(FATAL_ERROR "Caffe2: Couldn't determine version from header: " "${output_var}")
    endif()
    # The probe's main() always returns 0; any other run result (including
    # FAILED_TO_RUN) means the captured version string cannot be trusted.
    if(NOT run_result EQUAL 0)
      message(FATAL_ERROR "Caffe2: Couldn't determine version from header: probe exited with '${run_result}'")
    endif()
    message(STATUS "Caffe2: Header version is: " "${rocm_version_from_header}")
    set(ROCM_VERSION_DEV_RAW ${rocm_version_from_header})
    message("\n***** ROCm version from rocm_version.h ****\n")
  endif()

  # Split "<major>.<minor>.<patch>[anything]" into its components.
  # - The input is quoted so that an unset ROCM_VERSION_DEV_RAW (no
  #   rocm_version.h found above) still gives a well-formed command instead
  #   of a "needs at least 5 arguments" configure error.
  # - Dots are written as "[.]": inside a quoted CMake argument "\." is an
  #   escape that collapses to a bare "." (any character) before the regex
  #   engine ever sees it.
  string(REGEX MATCH "^([0-9]+)[.]([0-9]+)[.]([0-9]+).*$" ROCM_VERSION_DEV_MATCH "${ROCM_VERSION_DEV_RAW}")

  if(ROCM_VERSION_DEV_MATCH)
    set(ROCM_VERSION_DEV_MAJOR ${CMAKE_MATCH_1})
    set(ROCM_VERSION_DEV_MINOR ${CMAKE_MATCH_2})
    set(ROCM_VERSION_DEV_PATCH ${CMAKE_MATCH_3})
    set(ROCM_VERSION_DEV "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}.${ROCM_VERSION_DEV_PATCH}")
    # Integer encoding: major*10000 + minor*100 + patch.
    math(EXPR ROCM_VERSION_DEV_INT "(${ROCM_VERSION_DEV_MAJOR}*10000) + (${ROCM_VERSION_DEV_MINOR}*100) + ${ROCM_VERSION_DEV_PATCH}")
  endif()

  message("ROCM_VERSION_DEV: ${ROCM_VERSION_DEV}")
  message("ROCM_VERSION_DEV_MAJOR: ${ROCM_VERSION_DEV_MAJOR}")
  message("ROCM_VERSION_DEV_MINOR: ${ROCM_VERSION_DEV_MINOR}")
  message("ROCM_VERSION_DEV_PATCH: ${ROCM_VERSION_DEV_PATCH}")
  message("ROCM_VERSION_DEV_INT:   ${ROCM_VERSION_DEV_INT}")

  # HIP version as an integer: major*100 + minor.
  math(EXPR TORCH_HIP_VERSION "(${HIP_VERSION_MAJOR} * 100) + ${HIP_VERSION_MINOR}")
  message("HIP_VERSION_MAJOR: ${HIP_VERSION_MAJOR}")
  message("HIP_VERSION_MINOR: ${HIP_VERSION_MINOR}")
  message("TORCH_HIP_VERSION: ${TORCH_HIP_VERSION}")

  # Informational only: list installed ROCm package versions via dpkg.
  # execute_process() does not abort configuration when a command fails, so
  # this is a no-op apart from error output on systems without dpkg.
  message("\n***** Library versions from dpkg *****\n")
  execute_process(COMMAND dpkg -l COMMAND grep rocm-dev COMMAND awk "{print $2 \" VERSION: \" $3}")
  execute_process(COMMAND dpkg -l COMMAND grep rocm-libs COMMAND awk "{print $2 \" VERSION: \" $3}")
  execute_process(COMMAND dpkg -l COMMAND grep hsakmt-roct COMMAND awk "{print $2 \" VERSION: \" $3}")
  execute_process(COMMAND dpkg -l COMMAND grep rocr-dev COMMAND awk "{print $2 \" VERSION: \" $3}")
  execute_process(COMMAND dpkg -l COMMAND grep -w hcc COMMAND awk "{print $2 \" VERSION: \" $3}")
  execute_process(COMMAND dpkg -l COMMAND grep hip-base COMMAND awk "{print $2 \" VERSION: \" $3}")
  execute_process(COMMAND dpkg -l COMMAND grep hip_hcc COMMAND awk "{print $2 \" VERSION: \" $3}")

  message("\n***** Library versions from cmake find_package *****\n")

  set(CMAKE_HIP_CLANG_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
  set(CMAKE_HIP_CLANG_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
  ### Remove setting of Flags when FindHIP.CMake PR #558 is accepted.###

  # Point each ROCm component's package config lookup at <ROCM_PATH>/lib/cmake.
  set(hip_DIR ${ROCM_PATH}/lib/cmake/hip)
  set(hsa-runtime64_DIR ${ROCM_PATH}/lib/cmake/hsa-runtime64)
  set(AMDDeviceLibs_DIR ${ROCM_PATH}/lib/cmake/AMDDeviceLibs)
  set(amd_comgr_DIR ${ROCM_PATH}/lib/cmake/amd_comgr)
  set(rocrand_DIR ${ROCM_PATH}/lib/cmake/rocrand)
  set(hiprand_DIR ${ROCM_PATH}/lib/cmake/hiprand)
  set(rocblas_DIR ${ROCM_PATH}/lib/cmake/rocblas)
  set(hipblas_DIR ${ROCM_PATH}/lib/cmake/hipblas)
  set(hipblaslt_DIR ${ROCM_PATH}/lib/cmake/hipblaslt)
  set(miopen_DIR ${ROCM_PATH}/lib/cmake/miopen)
  set(rocfft_DIR ${ROCM_PATH}/lib/cmake/rocfft)
  set(hipfft_DIR ${ROCM_PATH}/lib/cmake/hipfft)
  set(hipsparse_DIR ${ROCM_PATH}/lib/cmake/hipsparse)
  set(rccl_DIR ${ROCM_PATH}/lib/cmake/rccl)
  set(rocprim_DIR ${ROCM_PATH}/lib/cmake/rocprim)
  set(hipcub_DIR ${ROCM_PATH}/lib/cmake/hipcub)
  set(rocthrust_DIR ${ROCM_PATH}/lib/cmake/rocthrust)
  set(hipsolver_DIR ${ROCM_PATH}/lib/cmake/hipsolver)
  set(hiprtc_DIR ${ROCM_PATH}/lib/cmake/hiprtc)

  # Every component is mandatory except rccl.
  find_package_and_print_version(hip REQUIRED)
  find_package_and_print_version(hsa-runtime64 REQUIRED)
  find_package_and_print_version(amd_comgr REQUIRED)
  find_package_and_print_version(rocrand REQUIRED)
  find_package_and_print_version(hiprand REQUIRED)
  find_package_and_print_version(rocblas REQUIRED)
  find_package_and_print_version(hipblas REQUIRED)
  find_package_and_print_version(hipblaslt REQUIRED)
  find_package_and_print_version(miopen REQUIRED)
  find_package_and_print_version(hipfft REQUIRED)
  find_package_and_print_version(hipsparse REQUIRED)
  find_package_and_print_version(rccl)
  find_package_and_print_version(rocprim REQUIRED)
  find_package_and_print_version(hipcub REQUIRED)
  find_package_and_print_version(rocthrust REQUIRED)
  find_package_and_print_version(hipsolver REQUIRED)
  find_package_and_print_version(hiprtc REQUIRED)

  find_library(PYTORCH_HIP_LIBRARIES amdhip64 HINTS ${ROCM_PATH}/lib)
  # TODO: miopen_LIBRARIES should return fullpath to the library file,
  # however currently it's just the lib name
  if(TARGET ${miopen_LIBRARIES})
    set(PYTORCH_MIOPEN_LIBRARIES ${miopen_LIBRARIES})
  else()
    find_library(PYTORCH_MIOPEN_LIBRARIES ${miopen_LIBRARIES} HINTS ${ROCM_PATH}/lib)
  endif()
  # TODO: rccl_LIBRARIES should return fullpath to the library file,
  # however currently it's just the lib name
  if(TARGET ${rccl_LIBRARIES})
    set(PYTORCH_RCCL_LIBRARIES ${rccl_LIBRARIES})
  else()
    find_library(PYTORCH_RCCL_LIBRARIES ${rccl_LIBRARIES} HINTS ${ROCM_PATH}/lib)
  endif()
  find_library(ROCM_HIPRTC_LIB hiprtc HINTS ${ROCM_PATH}/lib)
  # roctx is part of roctracer
  find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCM_PATH}/lib)

  # Check whether HIP declares the new data-type enums: the probe compiles
  # only if <hip/library_types.h> provides HIP_R_8F_E4M3_FNUZ.
  # HIP_NEW_TYPE_ENUMS records the outcome.
  set(file "${PROJECT_BINARY_DIR}/hip_new_types.cc")
  file(WRITE ${file} ""
    "#include <hip/library_types.h>\n"
    "int main() {\n"
    "    hipDataType baz = HIP_R_8F_E4M3_FNUZ;\n"
    "    return 0;\n"
    "}\n"
    )

  try_compile(hip_compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file}
    CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
    COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__
    OUTPUT_VARIABLE hip_compile_output)

  if(hip_compile_result)
    set(HIP_NEW_TYPE_ENUMS ON)
    message("HIP is using new type enums")
  else()
    set(HIP_NEW_TYPE_ENUMS OFF)
    message("HIP is NOT using new type enums")
  endif()

endif()