1*da0073e9SAndroid Build Coastguard Worker #pragma once 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDAMacros.h> 4*da0073e9SAndroid Build Coastguard Worker 5*da0073e9SAndroid Build Coastguard Worker #include <cstdint> 6*da0073e9SAndroid Build Coastguard Worker #include <memory> 7*da0073e9SAndroid Build Coastguard Worker #include <mutex> 8*da0073e9SAndroid Build Coastguard Worker #include <string> 9*da0073e9SAndroid Build Coastguard Worker #include <utility> 10*da0073e9SAndroid Build Coastguard Worker #include <vector> 11*da0073e9SAndroid Build Coastguard Worker 12*da0073e9SAndroid Build Coastguard Worker #ifdef USE_CUDA 13*da0073e9SAndroid Build Coastguard Worker #define TORCH_USE_CUDA_DSA 14*da0073e9SAndroid Build Coastguard Worker #endif 15*da0073e9SAndroid Build Coastguard Worker 16*da0073e9SAndroid Build Coastguard Worker /// Number of assertion failure messages we can store. If this is too small 17*da0073e9SAndroid Build Coastguard Worker /// threads will fail silently. 18*da0073e9SAndroid Build Coastguard Worker constexpr int C10_CUDA_DSA_ASSERTION_COUNT = 10; 19*da0073e9SAndroid Build Coastguard Worker constexpr int C10_CUDA_DSA_MAX_STR_LEN = 512; 20*da0073e9SAndroid Build Coastguard Worker 21*da0073e9SAndroid Build Coastguard Worker namespace c10::cuda { 22*da0073e9SAndroid Build Coastguard Worker 23*da0073e9SAndroid Build Coastguard Worker /// Holds information about any device-side assertions that fail. 24*da0073e9SAndroid Build Coastguard Worker /// Held in managed memory and access by both the CPU and the GPU. 
25*da0073e9SAndroid Build Coastguard Worker struct DeviceAssertionData { 26*da0073e9SAndroid Build Coastguard Worker /// Stringification of the assertion 27*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(*-c-arrays) 28*da0073e9SAndroid Build Coastguard Worker char assertion_msg[C10_CUDA_DSA_MAX_STR_LEN]{}; 29*da0073e9SAndroid Build Coastguard Worker /// File the assertion was in 30*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(*-c-arrays) 31*da0073e9SAndroid Build Coastguard Worker char filename[C10_CUDA_DSA_MAX_STR_LEN]{}; 32*da0073e9SAndroid Build Coastguard Worker /// Name of the function the assertion was in 33*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(*-c-arrays) 34*da0073e9SAndroid Build Coastguard Worker char function_name[C10_CUDA_DSA_MAX_STR_LEN]{}; 35*da0073e9SAndroid Build Coastguard Worker /// Line number the assertion was at 36*da0073e9SAndroid Build Coastguard Worker int line_number{}; 37*da0073e9SAndroid Build Coastguard Worker /// Number uniquely identifying the kernel launch that triggered the assertion 38*da0073e9SAndroid Build Coastguard Worker uint32_t caller{}; 39*da0073e9SAndroid Build Coastguard Worker /// block_id of the thread that failed the assertion 40*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(*-c-arrays) 41*da0073e9SAndroid Build Coastguard Worker int32_t block_id[3]{}; 42*da0073e9SAndroid Build Coastguard Worker /// third_id of the thread that failed the assertion 43*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(*-c-arrays) 44*da0073e9SAndroid Build Coastguard Worker int32_t thread_id[3]{}; 45*da0073e9SAndroid Build Coastguard Worker }; 46*da0073e9SAndroid Build Coastguard Worker 47*da0073e9SAndroid Build Coastguard Worker /// Used to hold assertions generated by the device 48*da0073e9SAndroid Build Coastguard Worker /// Held in managed memory and access by both the CPU and the GPU. 
49*da0073e9SAndroid Build Coastguard Worker struct DeviceAssertionsData { 50*da0073e9SAndroid Build Coastguard Worker /// Total number of assertions found; a subset of thse will be recorded 51*da0073e9SAndroid Build Coastguard Worker /// in `assertions` 52*da0073e9SAndroid Build Coastguard Worker int32_t assertion_count{}; 53*da0073e9SAndroid Build Coastguard Worker /// An array of assertions that will be written to in a race-free manner 54*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(*-c-arrays) 55*da0073e9SAndroid Build Coastguard Worker DeviceAssertionData assertions[C10_CUDA_DSA_ASSERTION_COUNT]{}; 56*da0073e9SAndroid Build Coastguard Worker }; 57*da0073e9SAndroid Build Coastguard Worker 58*da0073e9SAndroid Build Coastguard Worker /// Use to hold info about kernel launches so that we can run kernels 59*da0073e9SAndroid Build Coastguard Worker /// asynchronously and still associate launches with device-side 60*da0073e9SAndroid Build Coastguard Worker /// assertion failures 61*da0073e9SAndroid Build Coastguard Worker struct CUDAKernelLaunchInfo { 62*da0073e9SAndroid Build Coastguard Worker /// Filename of the code where the kernel was launched from 63*da0073e9SAndroid Build Coastguard Worker const char* launch_filename; 64*da0073e9SAndroid Build Coastguard Worker /// Function from which the kernel was launched 65*da0073e9SAndroid Build Coastguard Worker const char* launch_function; 66*da0073e9SAndroid Build Coastguard Worker /// Line number of where the code was launched from 67*da0073e9SAndroid Build Coastguard Worker uint32_t launch_linenum; 68*da0073e9SAndroid Build Coastguard Worker /// Backtrace of where the kernel was launched from, only populated if 69*da0073e9SAndroid Build Coastguard Worker /// CUDAKernelLaunchRegistry::gather_launch_stacktrace is True 70*da0073e9SAndroid Build Coastguard Worker std::string launch_stacktrace; 71*da0073e9SAndroid Build Coastguard Worker /// Kernel that was launched 72*da0073e9SAndroid Build Coastguard Worker 
const char* kernel_name; 73*da0073e9SAndroid Build Coastguard Worker /// Device the kernel was launched on 74*da0073e9SAndroid Build Coastguard Worker int device; 75*da0073e9SAndroid Build Coastguard Worker /// Stream the kernel was launched on 76*da0073e9SAndroid Build Coastguard Worker int32_t stream; 77*da0073e9SAndroid Build Coastguard Worker /// A number that uniquely identifies the kernel launch 78*da0073e9SAndroid Build Coastguard Worker uint64_t generation_number; 79*da0073e9SAndroid Build Coastguard Worker }; 80*da0073e9SAndroid Build Coastguard Worker 81*da0073e9SAndroid Build Coastguard Worker /// Circular buffer used to hold information about kernel launches 82*da0073e9SAndroid Build Coastguard Worker /// this is later used to reconstruct how a device-side kernel assertion failure 83*da0073e9SAndroid Build Coastguard Worker /// occurred CUDAKernelLaunchRegistry is used as a singleton 84*da0073e9SAndroid Build Coastguard Worker class C10_CUDA_API CUDAKernelLaunchRegistry { 85*da0073e9SAndroid Build Coastguard Worker private: 86*da0073e9SAndroid Build Coastguard Worker /// Assume that this is the max number of kernel launches that might ever be 87*da0073e9SAndroid Build Coastguard Worker /// enqueued across all streams on a single device 88*da0073e9SAndroid Build Coastguard Worker static constexpr int max_kernel_launches = 1024; 89*da0073e9SAndroid Build Coastguard Worker /// How many kernel launch infos we've inserted. 
Used to ensure that circular 90*da0073e9SAndroid Build Coastguard Worker /// queue doesn't provide false information by always increasing, but also to 91*da0073e9SAndroid Build Coastguard Worker /// mark where we are inserting into the queue 92*da0073e9SAndroid Build Coastguard Worker #ifdef TORCH_USE_CUDA_DSA 93*da0073e9SAndroid Build Coastguard Worker uint64_t generation_number = 0; 94*da0073e9SAndroid Build Coastguard Worker #endif 95*da0073e9SAndroid Build Coastguard Worker /// Shared mutex between writer and accessor to ensure multi-threaded safety. 96*da0073e9SAndroid Build Coastguard Worker mutable std::mutex read_write_mutex; 97*da0073e9SAndroid Build Coastguard Worker /// Used to ensure prevent race conditions in GPU memory allocation 98*da0073e9SAndroid Build Coastguard Worker mutable std::mutex gpu_alloc_mutex; 99*da0073e9SAndroid Build Coastguard Worker /// Pointer to managed memory keeping track of device-side assertions. There 100*da0073e9SAndroid Build Coastguard Worker /// is one entry for each possible device the process might work with. Unused 101*da0073e9SAndroid Build Coastguard Worker /// entries are nullptrs. We could also use an unordered_set here, but this 102*da0073e9SAndroid Build Coastguard Worker /// vector design will be faster and the wasted memory is small since we 103*da0073e9SAndroid Build Coastguard Worker /// expect the number of GPUs per node will always be small 104*da0073e9SAndroid Build Coastguard Worker std::vector< 105*da0073e9SAndroid Build Coastguard Worker std::unique_ptr<DeviceAssertionsData, void (*)(DeviceAssertionsData*)>> 106*da0073e9SAndroid Build Coastguard Worker uvm_assertions; 107*da0073e9SAndroid Build Coastguard Worker /// A single circular buffer holds information about every kernel launch the 108*da0073e9SAndroid Build Coastguard Worker /// process makes across all devices. 
109*da0073e9SAndroid Build Coastguard Worker std::vector<CUDAKernelLaunchInfo> kernel_launches; 110*da0073e9SAndroid Build Coastguard Worker bool check_env_for_enable_launch_stacktracing() const; 111*da0073e9SAndroid Build Coastguard Worker bool check_env_for_dsa_enabled() const; 112*da0073e9SAndroid Build Coastguard Worker 113*da0073e9SAndroid Build Coastguard Worker public: 114*da0073e9SAndroid Build Coastguard Worker CUDAKernelLaunchRegistry(); 115*da0073e9SAndroid Build Coastguard Worker /// Register a new kernel launch and obtain a generation number back to be 116*da0073e9SAndroid Build Coastguard Worker /// passed to the kernel 117*da0073e9SAndroid Build Coastguard Worker uint32_t insert( 118*da0073e9SAndroid Build Coastguard Worker const char* launch_filename, 119*da0073e9SAndroid Build Coastguard Worker const char* launch_function, 120*da0073e9SAndroid Build Coastguard Worker const uint32_t launch_linenum, 121*da0073e9SAndroid Build Coastguard Worker const char* kernel_name, 122*da0073e9SAndroid Build Coastguard Worker const int32_t stream_id); 123*da0073e9SAndroid Build Coastguard Worker /// Get copies of the kernel launch registry and each device's assertion 124*da0073e9SAndroid Build Coastguard Worker /// failure buffer so they can be inspected without raising race conditions 125*da0073e9SAndroid Build Coastguard Worker std:: 126*da0073e9SAndroid Build Coastguard Worker pair<std::vector<DeviceAssertionsData>, std::vector<CUDAKernelLaunchInfo>> 127*da0073e9SAndroid Build Coastguard Worker snapshot() const; 128*da0073e9SAndroid Build Coastguard Worker /// Get a pointer to the current device's assertion failure buffer. If no such 129*da0073e9SAndroid Build Coastguard Worker /// buffer exists then one is created. 
This means that the first kernel launch 130*da0073e9SAndroid Build Coastguard Worker /// made on each device will be slightly slower because memory allocations are 131*da0073e9SAndroid Build Coastguard Worker /// required 132*da0073e9SAndroid Build Coastguard Worker DeviceAssertionsData* get_uvm_assertions_ptr_for_current_device(); 133*da0073e9SAndroid Build Coastguard Worker /// Gets the global singleton of the registry 134*da0073e9SAndroid Build Coastguard Worker static CUDAKernelLaunchRegistry& get_singleton_ref(); 135*da0073e9SAndroid Build Coastguard Worker /// If not all devices support DSA, we disable it 136*da0073e9SAndroid Build Coastguard Worker const bool do_all_devices_support_managed_memory = false; 137*da0073e9SAndroid Build Coastguard Worker /// Whether or not to gather stack traces when launching kernels 138*da0073e9SAndroid Build Coastguard Worker bool gather_launch_stacktrace = false; 139*da0073e9SAndroid Build Coastguard Worker /// Whether or not host-side DSA is enabled or disabled at run-time 140*da0073e9SAndroid Build Coastguard Worker /// Note: Device-side code cannot be enabled/disabled at run-time 141*da0073e9SAndroid Build Coastguard Worker bool enabled_at_runtime = false; 142*da0073e9SAndroid Build Coastguard Worker /// Whether or not a device has indicated a failure 143*da0073e9SAndroid Build Coastguard Worker bool has_failed() const; 144*da0073e9SAndroid Build Coastguard Worker #ifdef TORCH_USE_CUDA_DSA 145*da0073e9SAndroid Build Coastguard Worker const bool enabled_at_compile_time = true; 146*da0073e9SAndroid Build Coastguard Worker #else 147*da0073e9SAndroid Build Coastguard Worker const bool enabled_at_compile_time = false; 148*da0073e9SAndroid Build Coastguard Worker #endif 149*da0073e9SAndroid Build Coastguard Worker }; 150*da0073e9SAndroid Build Coastguard Worker 151*da0073e9SAndroid Build Coastguard Worker std::string c10_retrieve_device_side_assertion_info(); 152*da0073e9SAndroid Build Coastguard Worker 153*da0073e9SAndroid 
} // namespace c10::cuda

// Kernels launched via TORCH_DSA_KERNEL_LAUNCH all take the same DSA
// parameters. This macro spells that parameter list out once so every such
// kernel declares it identically. Both parameters are [[maybe_unused]] so a
// kernel whose body never asserts does not trigger unused-parameter warnings.
#define TORCH_DSA_KERNEL_ARGS                                              \
  [[maybe_unused]] c10::cuda::DeviceAssertionsData *const assertions_data, \
      [[maybe_unused]] uint32_t assertion_caller_id

// Forwards the parameters declared by TORCH_DSA_KERNEL_ARGS, unchanged and in
// the same order, to another function that accepts them.
#define TORCH_DSA_KERNEL_ARGS_PASS assertions_data, assertion_caller_id