xref: /aosp_15_r20/external/pytorch/cmake/public/cuda.cmake (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# ---[ cuda

# Poor man's include guard: torch::cudart is defined later in this file, so if
# it already exists this file has been processed before and there is nothing
# left to do.
if(TARGET torch::cudart)
  return()
endif()
7*da0073e9SAndroid Build Coastguard Worker
# When this was written, the newest official CMake release (3.11.3) had no
# sccache support in its CUDA module (only CMake master did), so our own
# Modules_CUDA_fix copy is added to the module search path.
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/../Modules_CUDA_fix")
11*da0073e9SAndroid Build Coastguard Worker
# cudart is not linked statically: the Python side relies on its dynamic
# linkage (see torch/cuda/__init__.py and its use of cudaGetErrorName).
# Linking cudart statically here while linking libtorch_python.so against a
# dynamic libcudart.so would technically work, but it is wasteful.
#
# On MSVC the option is left alone: switching it off there makes
#   >>> import torch
#   >>> torch.cuda.is_available()
#   >>> torch.cuda.current_device()
# fail with "cuda: unknown error". Background:
#   https://github.com/pytorch/pytorch/issues/20635
#   https://github.com/pytorch/pytorch/issues/17108
if(NOT MSVC)
  set(CUDA_USE_STATIC_CUDA_RUNTIME OFF CACHE INTERNAL "")
endif()
27*da0073e9SAndroid Build Coastguard Worker
# Locate CUDA via FindCUDA. If it is missing, CUDA support is switched off and
# processing of this file stops here.
find_package(CUDA)
if(NOT CUDA_FOUND)
  message(WARNING
    "Caffe2: CUDA cannot be found. Depending on whether you are building "
    "Caffe2 or a Caffe2 dependent library, the next warning / error will "
    "give you more info.")
  set(CAFFE2_USE_CUDA OFF)
  return()
endif()
38*da0073e9SAndroid Build Coastguard Worker
# Enable first-class CUDA language support. CUDAToolkit_ROOT is pointed at the
# toolkit directory FindCUDA reported, for the CUDAToolkit lookup below.
set(CUDAToolkit_ROOT "${CUDA_TOOLKIT_ROOT_DIR}")
# For a Clang-family C++ compiler, hand the C compiler to nvcc as the host
# compiler. This must be done before the CUDA language is enabled, see
# https://cmake.org/cmake/help/v3.15/variable/CMAKE_CUDA_HOST_COMPILER.html
if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
  set(CMAKE_CUDA_HOST_COMPILER "${CMAKE_C_COMPILER}")
endif()
enable_language(CUDA)
# Unless a CUDA standard was chosen explicitly, reuse the C++ standard.
if("${CMAKE_CUDA_STANDARD}" STREQUAL "")
  set(CMAKE_CUDA_STANDARD ${CMAKE_CXX_STANDARD})
endif()
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
52*da0073e9SAndroid Build Coastguard Worker
# CMP0074 makes find_package() honour <PackageName>_ROOT variables such as
# CUDAToolkit_ROOT. PUSH/POP keeps the policy change local to this lookup.
cmake_policy(PUSH)
if(NOT CMAKE_VERSION VERSION_LESS 3.12.0)
  cmake_policy(SET CMP0074 NEW)
endif()
find_package(CUDAToolkit REQUIRED)
cmake_policy(POP)

# The enabled CUDA compiler and FindCUDAToolkit must report the same version.
if(NOT CMAKE_CUDA_COMPILER_VERSION VERSION_EQUAL CUDAToolkit_VERSION)
  message(FATAL_ERROR "Found two conflicting CUDA versions:\n"
                      "V${CMAKE_CUDA_COMPILER_VERSION} in '${CUDA_INCLUDE_DIRS}' and\n"
                      "V${CUDAToolkit_VERSION} in '${CUDAToolkit_INCLUDE_DIRS}'")
endif()
68*da0073e9SAndroid Build Coastguard Worker
# Report what FindCUDA found, then enforce the minimum supported toolkit.
message(STATUS "Caffe2: CUDA detected: ${CUDA_VERSION}")
message(STATUS "Caffe2: CUDA nvcc is: ${CUDA_NVCC_EXECUTABLE}")
message(STATUS "Caffe2: CUDA toolkit directory: ${CUDA_TOOLKIT_ROOT_DIR}")
if(CUDA_VERSION VERSION_LESS 11.0)
  message(FATAL_ERROR "PyTorch requires CUDA 11.0 or above.")
endif()
75*da0073e9SAndroid Build Coastguard Worker
if(CUDA_FOUND)
  # Sometimes, we may mismatch nvcc with the CUDA headers we are
  # compiling with, e.g., if a ccache nvcc is fed to us by CUDA_NVCC_EXECUTABLE
  # but the PATH is not consistent with CUDA_HOME.  It's better safe
  # than sorry: make sure everything is consistent.
  #
  # A tiny program printing CUDA_VERSION from <cuda.h> as "major.minor" is
  # compiled and run, and its output is compared with FindCUDA's
  # CUDA_VERSION_STRING. The check is skipped when CMAKE_CROSSCOMPILING is set.
  if(MSVC AND CMAKE_GENERATOR MATCHES "Visual Studio")
    # When using Visual Studio, it attempts to lock the whole binary dir when
    # `try_run` is called, which will cause the build to fail; use a randomly
    # named subdirectory instead.
    string(RANDOM BUILD_SUFFIX)
    set(PROJECT_RANDOM_BINARY_DIR "${PROJECT_BINARY_DIR}/${BUILD_SUFFIX}")
  else()
    set(PROJECT_RANDOM_BINARY_DIR "${PROJECT_BINARY_DIR}")
  endif()
  set(file "${PROJECT_BINARY_DIR}/detect_cuda_version.cc")
  file(WRITE "${file}" ""
    "#include <cuda.h>\n"
    "#include <cstdio>\n"
    "int main() {\n"
    "  printf(\"%d.%d\", CUDA_VERSION / 1000, (CUDA_VERSION / 10) % 100);\n"
    "  return 0;\n"
    "}\n"
    )
  if(NOT CMAKE_CROSSCOMPILING)
    try_run(run_result compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file}
      CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${CUDA_INCLUDE_DIRS}"
      LINK_LIBRARIES ${CUDA_LIBRARIES}
      RUN_OUTPUT_VARIABLE cuda_version_from_header
      COMPILE_OUTPUT_VARIABLE output_var
      )
    if(NOT compile_result)
      message(FATAL_ERROR "Caffe2: Couldn't determine version from header: " ${output_var})
    endif()
    # If the probe fails to run, its output is empty. That would otherwise
    # surface below as a misleading version-mismatch error, so report it
    # directly.
    if(NOT run_result EQUAL 0)
      message(FATAL_ERROR "Caffe2: Couldn't run the CUDA header version probe (exit status: ${run_result})")
    endif()
    message(STATUS "Caffe2: Header version is: " ${cuda_version_from_header})
    # Quoted so an empty CUDA_VERSION_STRING still yields a valid comparison.
    if(NOT "${cuda_version_from_header}" STREQUAL "${CUDA_VERSION_STRING}")
      # Record FindCUDA's version in a plain variable before its cache entries
      # are dropped; the error message below quotes it.
      set(cuda_version_from_findcuda "${CUDA_VERSION_STRING}")
      # Force CUDA to be processed for again next time
      # TODO: I'm not sure if this counts as an implementation detail of
      # FindCUDA
      unset(CUDA_TOOLKIT_ROOT_DIR_INTERNAL CACHE)
      # Not strictly necessary, but for good luck.
      unset(CUDA_VERSION CACHE)
      # Error out
      message(FATAL_ERROR "FindCUDA says CUDA version is ${cuda_version_from_findcuda} (usually determined by nvcc), "
        "but the CUDA headers say the version is ${cuda_version_from_header}.  This often occurs "
        "when you set both CUDA_HOME and CUDA_NVCC_EXECUTABLE to "
        "non-standard locations, without also setting PATH to point to the correct nvcc.  "
        "Perhaps, try re-running this command again with PATH=${CUDA_TOOLKIT_ROOT_DIR}/bin:$PATH.  "
        "See above log messages for more diagnostics, and see https://github.com/pytorch/pytorch/issues/8092 for more details.")
    endif()
  endif()
endif()
127*da0073e9SAndroid Build Coastguard Worker
# ---[ CUDA libraries wrapper

# find libnvrtc.so
set(CUDA_NVRTC_LIB "${CUDA_nvrtc_LIBRARY}" CACHE FILEPATH "")
# Unless CUDA_NVRTC_SHORTHASH was supplied, set it to the first 8 hex digits
# of the SHA-256 of the NVRTC library file. On any failure, warn and fall back
# to the placeholder "XXXXXXXX".
if(CUDA_NVRTC_LIB AND NOT CUDA_NVRTC_SHORTHASH)
  find_package(Python COMPONENTS Interpreter)
  if(Python_Interpreter_FOUND)
    # execute_process() runs at configure time and expects a program path; it
    # does not resolve imported targets such as Python::Interpreter, so use
    # the interpreter path reported by FindPython.
    execute_process(
      COMMAND "${Python_EXECUTABLE}" -c
      "import hashlib;hash=hashlib.sha256();hash.update(open('${CUDA_NVRTC_LIB}','rb').read());print(hash.hexdigest()[:8])"
      RESULT_VARIABLE _retval
      OUTPUT_VARIABLE CUDA_NVRTC_SHORTHASH)
  else()
    set(_retval "Python interpreter not found")
  endif()
  if(NOT _retval EQUAL 0)
    message(WARNING "Failed to compute shorthash for libnvrtc.so")
    set(CUDA_NVRTC_SHORTHASH "XXXXXXXX")
  else()
    string(STRIP "${CUDA_NVRTC_SHORTHASH}" CUDA_NVRTC_SHORTHASH)
    message(STATUS "${CUDA_NVRTC_LIB} shorthash is ${CUDA_NVRTC_SHORTHASH}")
  endif()
endif()
147*da0073e9SAndroid Build Coastguard Worker
# Modern imported wrapper targets for the CUDA libraries.
# When CAFFE2_STATIC_LINK_CUDA is set, several wrappers below link the static
# variants (CUDA::*_static). That flag is meant for binary builds only;
# end-users should never have it set.

# cuda: the CUDA driver library.
add_library(caffe2::cuda INTERFACE IMPORTED)
target_link_libraries(caffe2::cuda INTERFACE CUDA::cuda_driver)

# cudart: static or shared runtime, selected by CAFFE2_STATIC_LINK_CUDA.
add_library(torch::cudart INTERFACE IMPORTED)
if(NOT CAFFE2_STATIC_LINK_CUDA)
  target_link_libraries(torch::cudart INTERFACE CUDA::cudart)
else()
  target_link_libraries(torch::cudart INTERFACE CUDA::cudart_static)
endif()
171*da0073e9SAndroid Build Coastguard Worker
# nvToolsExt: use the NVTX3 headers bundled under third_party/NVTX when they
# are present; otherwise fall back to the toolkit's CUDA::nvToolsExt.
find_path(nvtx3_dir NAMES nvtx3 PATHS "${PROJECT_SOURCE_DIR}/third_party/NVTX/c/include" NO_DEFAULT_PATH)
find_package_handle_standard_args(nvtx3 DEFAULT_MSG nvtx3_dir)
if(NOT nvtx3_FOUND)
  message(WARNING "Cannot find NVTX3, find old NVTX instead")
  add_library(torch::nvtoolsext INTERFACE IMPORTED)
  target_link_libraries(torch::nvtoolsext INTERFACE CUDA::nvToolsExt)
else()
  add_library(torch::nvtx3 INTERFACE IMPORTED)
  set_target_properties(torch::nvtx3 PROPERTIES
    INTERFACE_INCLUDE_DIRECTORIES "${nvtx3_dir}"
    INTERFACE_COMPILE_DEFINITIONS TORCH_CUDA_USE_NVTX3)
endif()
184*da0073e9SAndroid Build Coastguard Worker
185*da0073e9SAndroid Build Coastguard Worker
# cublas
# cuBLAS and cuBLASLt are always linked dynamically. A static-CUDA build on a
# non-Windows host additionally pulls in the static CUDA runtime and librt.
add_library(caffe2::cublas INTERFACE IMPORTED)
set_property(
    TARGET caffe2::cublas PROPERTY INTERFACE_LINK_LIBRARIES
    CUDA::cublas CUDA::cublasLt)
if(CAFFE2_STATIC_LINK_CUDA AND NOT WIN32)
  set_property(
      TARGET caffe2::cublas APPEND PROPERTY INTERFACE_LINK_LIBRARIES
      CUDA::cudart_static rt)
endif()
201*da0073e9SAndroid Build Coastguard Worker
# cudnn interface
# Static vs. shared linking follows USE_STATIC_CUDNN.
if(CAFFE2_USE_CUDNN)
  # Seed the CUDNN_STATIC cache option from USE_STATIC_CUDNN (no FORCE, so an
  # existing cache value is kept).
  if(USE_STATIC_CUDNN)
    set(CUDNN_STATIC ON CACHE BOOL "")
  else()
    set(CUDNN_STATIC OFF CACHE BOOL "")
  endif()

  find_package(CUDNN)

  # The target is always created, so code referring to it keeps working.
  # It only receives include/link requirements when cuDNN was found;
  # otherwise it would carry *-NOTFOUND paths.
  add_library(torch::cudnn INTERFACE IMPORTED)
  if(NOT CUDNN_FOUND)
    message(WARNING
      "Cannot find cuDNN library. Turning the option off")
    set(CAFFE2_USE_CUDNN OFF)
  else()
    if(CUDNN_VERSION VERSION_LESS "8.1.0")
      message(FATAL_ERROR "PyTorch requires cuDNN 8.1 and above.")
    endif()
    target_include_directories(torch::cudnn INTERFACE ${CUDNN_INCLUDE_PATH})
    if(CUDNN_STATIC AND NOT WIN32)
      # Keep symbols from libcudnn_static.a out of the exported symbol set.
      # NOTE(review): this branch links no library itself; presumably the
      # static archive is linked elsewhere -- confirm.
      target_link_options(torch::cudnn INTERFACE
          "-Wl,--exclude-libs,libcudnn_static.a")
    else()
      target_link_libraries(torch::cudnn INTERFACE ${CUDNN_LIBRARY_PATH})
    endif()
  endif()
else()
  message(STATUS "USE_CUDNN is set to 0. Compiling without cuDNN support")
endif()
234*da0073e9SAndroid Build Coastguard Worker
# cuSPARSELt (optional): create torch::cusparselt when the library is found,
# otherwise turn the option off.
if(NOT CAFFE2_USE_CUSPARSELT)
  message(STATUS "USE_CUSPARSELT is set to 0. Compiling without cuSPARSELt support")
else()
  find_package(CUSPARSELT)
  if(CUSPARSELT_FOUND)
    add_library(torch::cusparselt INTERFACE IMPORTED)
    target_include_directories(torch::cusparselt INTERFACE ${CUSPARSELT_INCLUDE_PATH})
    target_link_libraries(torch::cusparselt INTERFACE ${CUSPARSELT_LIBRARY_PATH})
  else()
    message(WARNING
      "Cannot find cuSPARSELt library. Turning the option off")
    set(CAFFE2_USE_CUSPARSELT OFF)
  endif()
endif()
250*da0073e9SAndroid Build Coastguard Worker
# cuDSS (optional): create torch::cudss when the library is found, otherwise
# turn the option off.
if(NOT USE_CUDSS)
  message(STATUS "USE_CUDSS is set to 0. Compiling without cuDSS support")
else()
  find_package(CUDSS)
  if(CUDSS_FOUND)
    add_library(torch::cudss INTERFACE IMPORTED)
    target_include_directories(torch::cudss INTERFACE ${CUDSS_INCLUDE_PATH})
    target_link_libraries(torch::cudss INTERFACE ${CUDSS_LIBRARY_PATH})
  else()
    message(WARNING
      "Cannot find CUDSS library. Turning the option off")
    set(USE_CUDSS OFF)
  endif()
endif()
266*da0073e9SAndroid Build Coastguard Worker
# cufile: optional. Static library only for static-CUDA builds off Windows.
if(NOT CAFFE2_USE_CUFILE)
  message(STATUS "USE_CUFILE is set to 0. Compiling without cuFile support")
else()
  add_library(torch::cufile INTERFACE IMPORTED)
  if(NOT CAFFE2_STATIC_LINK_CUDA OR WIN32)
    target_link_libraries(torch::cufile INTERFACE CUDA::cuFile)
  else()
    target_link_libraries(torch::cufile INTERFACE CUDA::cuFile_static)
  endif()
endif()
282*da0073e9SAndroid Build Coastguard Worker
# curand: static library only for static-CUDA builds off Windows.
add_library(caffe2::curand INTERFACE IMPORTED)
if(NOT CAFFE2_STATIC_LINK_CUDA OR WIN32)
  target_link_libraries(caffe2::curand INTERFACE CUDA::curand)
else()
  target_link_libraries(caffe2::curand INTERFACE CUDA::curand_static)
endif()
294*da0073e9SAndroid Build Coastguard Worker
# cufft: the static no-callback variant only for static-CUDA builds off
# Windows.
add_library(caffe2::cufft INTERFACE IMPORTED)
if(NOT CAFFE2_STATIC_LINK_CUDA OR WIN32)
  target_link_libraries(caffe2::cufft INTERFACE CUDA::cufft)
else()
  target_link_libraries(caffe2::cufft INTERFACE CUDA::cufft_static_nocallback)
endif()
306*da0073e9SAndroid Build Coastguard Worker
# nvrtc: links CUDA::nvrtc together with the caffe2::cuda driver wrapper.
add_library(caffe2::nvrtc INTERFACE IMPORTED)
target_link_libraries(caffe2::nvrtc INTERFACE CUDA::nvrtc caffe2::cuda)
312*da0073e9SAndroid Build Coastguard Worker
# Pass the ONNX namespace definition to nvcc, defaulting to onnx_c2.
if(NOT ONNX_NAMESPACE)
  list(APPEND CUDA_NVCC_FLAGS "-DONNX_NAMESPACE=onnx_c2")
else()
  list(APPEND CUDA_NVCC_FLAGS "-DONNX_NAMESPACE=${ONNX_NAMESPACE}")
endif()
319*da0073e9SAndroid Build Coastguard Worker
# For MSVC with the Ninja generator and no CUDAHOSTCXX in the environment,
# pass --use-local-env so that the VC environment is not activated again.
if(MSVC AND "${CMAKE_GENERATOR}" STREQUAL "Ninja" AND NOT DEFINED ENV{CUDAHOSTCXX})
  list(APPEND CUDA_NVCC_FLAGS "--use-local-env")
endif()
325*da0073e9SAndroid Build Coastguard Worker
# nvcc architecture (gencode) flags come from torch_cuda_get_nvcc_gencode_flag.
# CMake 3.18 added built-in architecture selection, but it is not relied on
# here, so CMAKE_CUDA_ARCHITECTURES is switched off.
torch_cuda_get_nvcc_gencode_flag(NVCC_FLAGS_EXTRA)
set(CMAKE_CUDA_ARCHITECTURES OFF)
list(APPEND CUDA_NVCC_FLAGS ${NVCC_FLAGS_EXTRA})
message(STATUS "Added CUDA NVCC flags for: ${NVCC_FLAGS_EXTRA}")
332*da0073e9SAndroid Build Coastguard Worker
# Silence nvcc diagnostics that show up in third-party headers (boost, glog,
# gflags, opencv, ...). The suppressions are joined with commas into the single
# argument that follows -Xcudafe.
foreach(diag
    cc_clobber_ignored
    field_without_dll_interface
    base_class_has_different_dll_interface
    dll_interface_conflict_none_assumed
    dll_interface_conflict_dllexport_assumed
    bad_friend_decl)
  list(APPEND SUPPRESS_WARNING_FLAGS "--diag_suppress=${diag}")
endforeach()
list(JOIN SUPPRESS_WARNING_FLAGS "," SUPPRESS_WARNING_FLAGS)
list(APPEND CUDA_NVCC_FLAGS -Xcudafe ${SUPPRESS_WARNING_FLAGS})
344*da0073e9SAndroid Build Coastguard Worker
# Host flags listed here are presumably kept from being propagated to nvcc by
# the bundled FindCUDA (-Werror must not reach device compilation) -- confirm.
set(CUDA_PROPAGATE_HOST_FLAGS_BLOCKLIST "-Werror")
# On MSVC: make cross-execution-space calls an error and disable host/device
# move forwarding.
if(MSVC)
  list(APPEND CUDA_NVCC_FLAGS
    "--Werror" "cross-execution-space-call"
    "--no-host-device-move-forward")
endif()
350*da0073e9SAndroid Build Coastguard Worker
# Debug and Release symbol support
if(MSVC)
  # Select the MSVC C runtime for host code: static (/MT, /MTd) when
  # CAFFE2_USE_MSVC_STATIC_RUNTIME is set, dynamic (/MD, /MDd) otherwise.
  # The option is tested by name; `if(${...})` would dereference its value a
  # second time.
  if(CAFFE2_USE_MSVC_STATIC_RUNTIME)
    string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -Xcompiler /MTd")
    string(APPEND CMAKE_CUDA_FLAGS_MINSIZEREL " -Xcompiler /MT")
    string(APPEND CMAKE_CUDA_FLAGS_RELEASE " -Xcompiler /MT")
    string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -Xcompiler /MT")
  else()
    string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -Xcompiler /MDd")
    string(APPEND CMAKE_CUDA_FLAGS_MINSIZEREL " -Xcompiler /MD")
    string(APPEND CMAKE_CUDA_FLAGS_RELEASE " -Xcompiler /MD")
    string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -Xcompiler /MD")
  endif()
  # When -Zi debug info is requested, also pass -FS to the host compiler
  # (MSVC's serialized PDB writes).
  if(CUDA_NVCC_FLAGS MATCHES "Zi")
    list(APPEND CUDA_NVCC_FLAGS "-Xcompiler" "-FS")
  endif()
elseif(CUDA_DEVICE_DEBUG)
  list(APPEND CUDA_NVCC_FLAGS "-g" "-G")  # -G enables device code debugging symbols
endif()
370*da0073e9SAndroid Build Coastguard Worker
# --expt-relaxed-constexpr: suppresses Eigen warnings.
# --expt-extended-lambda:   supports lambdas on the device.
list(APPEND CUDA_NVCC_FLAGS "--expt-relaxed-constexpr" "--expt-extended-lambda")
376*da0073e9SAndroid Build Coastguard Worker
# Append every CUDA_NVCC_FLAGS entry to CMAKE_CUDA_FLAGS, separated by spaces.
# An entry that itself contains a space is rejected, presumably because the
# space-separated CMAKE_CUDA_FLAGS string could not keep it intact.
foreach(FLAG ${CUDA_NVCC_FLAGS})
  if("${FLAG}" MATCHES " ")
    message(FATAL_ERROR "Found spaces in CUDA_NVCC_FLAGS entry '${FLAG}'")
  endif()
  string(APPEND CMAKE_CUDA_FLAGS " ${FLAG}")
endforeach()
384