xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/cuda/mem_eff_attention/debug_utils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 #pragma once
9 #include <cfloat>
10 #include <cstdio>
11 #include <cmath>
12 
13 ////////////////////////////////////////////////////////////////////////////////
14 // Debugging functions
15 ////////////////////////////////////////////////////////////////////////////////
// NaN & inf detection
// Asserts that no element of `frag` is NaN or infinite.
// `frag` must provide `size()` and `operator[]`, with elements convertible to
// `float`. `assert` must be in scope at the expansion site (e.g. <cassert>);
// like any `assert`, the checks disappear when NDEBUG is defined.
// The NaN test runs first so that a NaN trips its own assertion instead of
// being reported by the finiteness assertion (which also fails for NaN).
#define NANCHECK(frag)                                \
  {                                                   \
    for (int _i = 0; _i < int((frag).size()); ++_i) { \
      assert(!std::isnan(float((frag)[_i])));         \
      assert(std::isfinite(float((frag)[_i])));       \
    }                                                 \
  }
24 
25 // Print on the first thread of the first block
// Device-side printf helpers. Change `#if 1` to `#if 0` to turn every macro in
// this group into an empty block.
// In all of them `msg` must be a string literal: it is concatenated with "\n"
// (and, for PRINT_TX_LX, with an index prefix) at preprocessing time.
#if 1
// threadIdx.y / threadIdx.x of the one thread that is allowed to print.
#define PRINT_WARP_ID 0
#define PRINT_LANE_ID 0
// printf-style message emitted only by the thread with
// threadIdx == (PRINT_LANE_ID, PRINT_WARP_ID, 0) in block (0, 0, 0).
#define PRINT_B0_T0(msg, ...)                                         \
  if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 &&        \
      threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \
      threadIdx.z == 0) {                                             \
    printf(msg "\n", ##__VA_ARGS__);                                  \
  }
// Same as PRINT_B0_T0 but without the block test: the selected thread of
// every block prints.
#define PRINT_T0(msg, ...)                                            \
  if (threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \
      threadIdx.z == 0) {                                             \
    printf(msg "\n", ##__VA_ARGS__);                                  \
  }
// Every thread of every block prints `msg`, prefixed with its
// "[block x,y,z][thread x,y,z]" indices. All threads walk the same
// gridDim x blockDim loop nest and execute __syncthreads() once per
// iteration, so this must be expanded where every thread of the block reaches
// it. __syncthreads() is a block-level barrier, so it does not order output
// coming from different blocks.
#define PRINT_TX_LX(msg, ...)                                                 \
  for (int bx = 0; bx < gridDim.x; ++bx) {                                    \
    for (int by = 0; by < gridDim.y; ++by) {                                  \
      for (int bz = 0; bz < gridDim.z; ++bz) {                                \
        for (int tx = 0; tx < blockDim.x; ++tx) {                             \
          for (int ty = 0; ty < blockDim.y; ++ty) {                           \
            for (int tz = 0; tz < blockDim.z; ++tz) {                         \
              __syncthreads();                                                \
              if (blockIdx.x == bx && blockIdx.y == by && blockIdx.z == bz && \
                  threadIdx.x == tx && threadIdx.y == ty &&                   \
                  threadIdx.z == tz) {                                        \
                printf(                                                       \
                    "[%d,%d,%d][%d,%d,%d]" msg "\n",                          \
                    bx,                                                       \
                    by,                                                       \
                    bz,                                                       \
                    tx,                                                       \
                    ty,                                                       \
                    tz,                                                       \
                    ##__VA_ARGS__);                                           \
              }                                                               \
            }                                                                 \
          }                                                                   \
        }                                                                     \
      }                                                                       \
    }                                                                         \
  }
#else
// Disabled variants: take the same arguments as the enabled macros and expand
// to an empty block, so call sites keep compiling and their arguments are
// neither evaluated nor left behind as a stray expression.
#define PRINT_B0_T0(msg, ...) \
  {}
#define PRINT_T0(msg, ...) \
  {}
#define PRINT_TX_LX(msg, ...) \
  {}
#endif
71 
// Minimal non-owning string view: `size` characters starting at `data`.
// `data` is NOT guaranteed to be NUL-terminated at `data + size`.
struct __string_view {
  char const* data;
  std::size_t size;
};
#if __cplusplus >= 201402L
// Returns the spelling of `T` as it appears in this function's
// __PRETTY_FUNCTION__ string: the text after the first '=' (leading spaces
// skipped) up to, but not including, the ']' that closes the bracket assumed
// to be already open at that point. Nested "[...]" pairs inside the type
// spelling are kept.
// The returned view points into the __PRETTY_FUNCTION__ literal, so
// `data[size]` is the closing ']' rather than a NUL.
// If the string has no '=' or no closing ']', the scan stops at the
// terminating NUL and the view covers whatever was found (possibly nothing)
// instead of reading past the end of the literal.
template <class T>
constexpr __string_view __get_type_name() {
  char const* p = __PRETTY_FUNCTION__;
  // Advance to just past the first '='.
  while (*p != '\0' && *p != '=') {
    ++p;
  }
  if (*p == '=') {
    ++p;
  }
  while (*p == ' ') {
    ++p;
  }
  // Find the ']' matching the bracket that encloses the "T = ..." clause.
  char const* end = p;
  int depth = 1;
  while (*end != '\0') {
    if (*end == '[') {
      ++depth;
    } else if (*end == ']' && --depth == 0) {
      break;
    }
    ++end;
  }
  return {p, std::size_t(end - p)};
}
#else
// constexpr functions with loops need C++14; older standards get a
// placeholder name.
template <class T>
constexpr __string_view __get_type_name() {
  return {"unsupported", 11};
}
#endif
105 
106 // Print a given array
// Prints "<name>[start:start+8] - {v0, ..., v7}" through PRINT_B0_T0, where
// v0..v7 are accum[start + 0] .. accum[start + 7] converted to float.
// Exactly 8 elements are read with no bounds check: the caller must make sure
// accum[start + 7] is valid. `name` is passed to a "%s" conversion.
// The expansion ends with its own ';'.
#define PRINT_ACCUM8_T0_L0_START(name, accum, start)  \
  PRINT_B0_T0(                                        \
      "%s[%d:%d] - {%f, %f, %f, %f, %f, %f, %f, %f}", \
      name,                                           \
      int(start),                                     \
      int((start) + 8),                               \
      float((accum)[(start) + 0]),                    \
      float((accum)[(start) + 1]),                    \
      float((accum)[(start) + 2]),                    \
      float((accum)[(start) + 3]),                    \
      float((accum)[(start) + 4]),                    \
      float((accum)[(start) + 5]),                    \
      float((accum)[(start) + 6]),                    \
      float((accum)[(start) + 7]));
// Prints the first 8 elements of `accum`.
#define PRINT_ACCUM8_T0_L0(name, accum) PRINT_ACCUM8_T0_L0_START(name, accum, 0)
// Prints a header line with `name` and the type of `frag`, followed by every
// element of `frag` in rows of 8 (all output goes through PRINT_B0_T0).
// `frag` must provide `size()` and `operator[]`.
// When size() is not a multiple of 8 the leftover elements are printed one
// per line, so nothing past the end of `frag` is read.
// The type name comes from __get_type_name and is printed with "%s"; that
// view is not NUL-terminated at its own length, so the header typically shows
// a few extra characters (e.g. a trailing ']') after the type.
#define PRINT_FRAG_T0_L0(name, frag)                                       \
  {                                                                        \
    auto typeStr = __get_type_name<decltype(frag)>();                      \
    PRINT_B0_T0("printing %s (%s)", name, typeStr.data);                   \
    int const _frag_len = int((frag).size());                              \
    for (int _start = 0; _start < _frag_len; _start += 8) {                \
      if (_start + 8 <= _frag_len) {                                       \
        PRINT_ACCUM8_T0_L0_START("  ", frag, _start);                      \
      } else {                                                             \
        for (int _k = _start; _k < _frag_len; ++_k) {                      \
          PRINT_B0_T0(                                                     \
              "%s[%d:%d] - {%f}", "  ", _k, _k + 1, float((frag)[_k]));    \
        }                                                                  \
      }                                                                    \
    }                                                                      \
    /* __syncthreads(); NANCHECK(frag); */                                 \
  }
// Prints a header line with `name` and `length`, then windows of 8 elements of
// `array` starting at indices 0, incr, 2*incr, ... below `length` (all output
// goes through PRINT_B0_T0). `array` only needs `operator[]`.
// A window that would extend past `length` is replaced by one line per
// remaining element, so no index >= length is read.
// `length` is evaluated once for the header and once for the loop bound;
// `incr` must be positive or the loop does not terminate.
#define PRINT_ARRAY_T0_L0_INCR(name, array, length, incr)                  \
  {                                                                        \
    PRINT_B0_T0("printing %s (len=%d)", name, int(length));                \
    int const _array_len = int(length);                                    \
    for (int _start = 0; _start < _array_len; _start += (incr)) {          \
      if (_start + 8 <= _array_len) {                                      \
        PRINT_ACCUM8_T0_L0_START("  ", array, _start);                     \
      } else {                                                             \
        for (int _k = _start; _k < _array_len; ++_k) {                     \
          PRINT_B0_T0(                                                     \
              "%s[%d:%d] - {%f}", "  ", _k, _k + 1, float((array)[_k]));   \
        }                                                                  \
      }                                                                    \
    }                                                                      \
  }
// Prints `length` elements of `array` in consecutive rows of 8.
#define PRINT_ARRAY_T0_L0(name, array, length) \
  PRINT_ARRAY_T0_L0_INCR(name, array, length, 8)
141 
142 // Print a 4x4 matrix
// Prints the 4x4 window of `ref` whose first element is at
// (start_x, start_y), one line per value of the first coordinate, through
// PRINT_B0_T0. Elements are fetched with ref.at({i, j}) and converted to
// float; no bounds check is done here, so the whole window must be valid for
// `ref`. The expansion ends with its own ';'.
#define PRINT_TENSOR4x4_T0_L0_START(name, ref, start_x, start_y)                                           \
  PRINT_B0_T0(                                                                                             \
      "%s[%d:%d, %d:%d]:\n    %f, %f, %f, %f\n    %f, %f, %f, %f\n    %f, %f, %f, %f\n    %f, %f, %f, %f", \
      name,                                                                                                \
      int(start_x),                                                                                        \
      int((start_x) + 4),                                                                                  \
      int(start_y),                                                                                        \
      int((start_y) + 4),                                                                                  \
      float((ref).at({(start_x) + 0, (start_y) + 0})),                                                     \
      float((ref).at({(start_x) + 0, (start_y) + 1})),                                                     \
      float((ref).at({(start_x) + 0, (start_y) + 2})),                                                     \
      float((ref).at({(start_x) + 0, (start_y) + 3})),                                                     \
      float((ref).at({(start_x) + 1, (start_y) + 0})),                                                     \
      float((ref).at({(start_x) + 1, (start_y) + 1})),                                                     \
      float((ref).at({(start_x) + 1, (start_y) + 2})),                                                     \
      float((ref).at({(start_x) + 1, (start_y) + 3})),                                                     \
      float((ref).at({(start_x) + 2, (start_y) + 0})),                                                     \
      float((ref).at({(start_x) + 2, (start_y) + 1})),                                                     \
      float((ref).at({(start_x) + 2, (start_y) + 2})),                                                     \
      float((ref).at({(start_x) + 2, (start_y) + 3})),                                                     \
      float((ref).at({(start_x) + 3, (start_y) + 0})),                                                     \
      float((ref).at({(start_x) + 3, (start_y) + 1})),                                                     \
      float((ref).at({(start_x) + 3, (start_y) + 2})),                                                     \
      float((ref).at({(start_x) + 3, (start_y) + 3})));
// Prints the 4x4 window of `ref` that starts at (0, 0).
#define PRINT_TENSOR4x4_T0_L0(name, ref) \
  PRINT_TENSOR4x4_T0_L0_START(name, ref, 0, 0)
169 
// Prints "<name>.problem_size: {.m=.., .n=.., .k=..}" through PRINT_B0_T0.
// `ps` must expose m(), n() and k(); each result is converted to int.
// Unlike the other PRINT_* helpers in this file, the expansion does not end
// with ';'.
#define PRINT_PROBLEM_SIZE(name, ps)            \
  PRINT_B0_T0(                                  \
      "%s.problem_size: {.m=%d, .n=%d, .k=%d}", \
      name,                                     \
      int((ps).m()),                            \
      int((ps).n()),                            \
      int((ps).k()))
177 
178 template <typename LambdaIterator, typename LaneOffsetT, typename AccumT>
print_warp_accum(AccumT accum,LaneOffsetT lane_offset,int32_t num_rows,int32_t num_cols)179 CUTLASS_DEVICE void print_warp_accum(
180     AccumT accum,
181     LaneOffsetT lane_offset,
182     int32_t num_rows,
183     int32_t num_cols) {
184   bool is_main = blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 &&
185       threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0;
186   for (int row = 0; row < num_rows; ++row) {
187     for (int col = 0; col < num_cols; ++col) {
188       if (col % 32 == 0) {
189         if (is_main) {
190           printf("\nmat[%3d, %3d:%3d]", row, col, col + 32);
191         }
192         __syncthreads();
193       }
194       LambdaIterator::iterateRows(
195           lane_offset,
196           [&](int accum_m) {},
197           [&](int accum_m, int accum_n, int idx) {
198             if (row == accum_m && col == accum_n &&
199                 (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)) {
200               printf(" %6.1f", float(accum[idx]));
201             }
202           },
203           [&](int accum_m) {});
204       __syncthreads();
205     }
206     if (is_main) {
207       printf("\n");
208     }
209   }
210 }
211