1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8 #pragma once
#include <cassert>
#include <cfloat>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
12
13 ////////////////////////////////////////////////////////////////////////////////
14 // Debugging functions
15 ////////////////////////////////////////////////////////////////////////////////
// NaN & Inf detection.
//
// Asserts that every element of `frag`, converted to float, is neither NaN
// nor infinite. `frag` may be any expression exposing `size()` and
// `operator[]`. NaN is tested first so that a NaN and an infinity trip
// different assertions (an `isfinite` check alone would reject both and hide
// which one occurred).
//
// The checks are plain `assert`s, so they disappear when NDEBUG is defined.
// `frag` is evaluated several times: do not pass an expression with side
// effects.
#define NANCHECK(frag)                                  \
  {                                                     \
    for (int _i = 0; _i < int((frag).size()); ++_i) {   \
      assert(!std::isnan(float((frag)[_i])));           \
      assert(std::isfinite(float((frag)[_i])));         \
    }                                                   \
  }
24
25 // Print on the first thread of the first block
26 #if 1
// Coordinates of the one thread inside a block that is allowed to print.
// PRINT_LANE_ID is compared against threadIdx.x and PRINT_WARP_ID against
// threadIdx.y (NOTE(review): the names suggest blockDim.x is the warp width
// -- confirm against the kernels that use these macros).
#define PRINT_WARP_ID 0
#define PRINT_LANE_ID 0

// printf-style print, with a newline appended, executed only by the selected
// thread (see above, threadIdx.z == 0) of block (0, 0, 0).
//
// Expands to a bare `if (...) { ... }` statement: a trailing semicolon is
// optional, and the macro should not be used as the unbraced body of an
// `if` that has an `else`.
#define PRINT_B0_T0(fmt, ...)                                              \
  if (threadIdx.z == 0 && threadIdx.y == PRINT_WARP_ID &&                  \
      threadIdx.x == PRINT_LANE_ID && blockIdx.z == 0 && blockIdx.y == 0 && \
      blockIdx.x == 0) {                                                   \
    printf(fmt "\n", ##__VA_ARGS__);                                       \
  }
// Same as PRINT_B0_T0 but without the block filter: the selected thread
// (threadIdx == (PRINT_LANE_ID, PRINT_WARP_ID, 0)) of *every* block prints
// the printf-style message followed by a newline.
#define PRINT_T0(fmt, ...)                                \
  if (threadIdx.z == 0 && threadIdx.y == PRINT_WARP_ID && \
      threadIdx.x == PRINT_LANE_ID) {                     \
    printf(fmt "\n", ##__VA_ARGS__);                      \
  }
// Print a printf-style message from every thread of every block, each line
// prefixed with "[blockIdx.x,y,z][threadIdx.x,y,z]".
//
// Every thread walks the full (grid x block) index space and prints only on
// the iteration that matches its own coordinates, with a __syncthreads()
// before each step. Consequences:
//  - it must be invoked by all threads of a block from uniform control flow,
//    otherwise the barrier is undefined behaviour;
//  - the barrier count is gridDim * blockDim per call, so this is only
//    suitable for tiny debug launches;
//  - __syncthreads() is a per-block barrier, so output from different blocks
//    is not ordered by this macro.
//
// The loop counters use a `_print_` prefix so that caller arguments named
// `bx`, `tx`, ... refer to the caller's variables rather than being captured
// by the macro's own loop variables.
#define PRINT_TX_LX(msg, ...)                                                 \
  for (int _print_bx = 0; _print_bx < int(gridDim.x); ++_print_bx) {          \
    for (int _print_by = 0; _print_by < int(gridDim.y); ++_print_by) {        \
      for (int _print_bz = 0; _print_bz < int(gridDim.z); ++_print_bz) {      \
        for (int _print_tx = 0; _print_tx < int(blockDim.x); ++_print_tx) {   \
          for (int _print_ty = 0; _print_ty < int(blockDim.y); ++_print_ty) { \
            for (int _print_tz = 0; _print_tz < int(blockDim.z);              \
                 ++_print_tz) {                                               \
              __syncthreads();                                                \
              if (int(blockIdx.x) == _print_bx &&                             \
                  int(blockIdx.y) == _print_by &&                             \
                  int(blockIdx.z) == _print_bz &&                             \
                  int(threadIdx.x) == _print_tx &&                            \
                  int(threadIdx.y) == _print_ty &&                            \
                  int(threadIdx.z) == _print_tz) {                            \
                printf(                                                       \
                    "[%d,%d,%d][%d,%d,%d]" msg "\n",                          \
                    _print_bx,                                                \
                    _print_by,                                                \
                    _print_bz,                                                \
                    _print_tx,                                                \
                    _print_ty,                                                \
                    _print_tz,                                                \
                    ##__VA_ARGS__);                                           \
              }                                                               \
            }                                                                 \
          }                                                                   \
        }                                                                     \
      }                                                                       \
    }                                                                         \
  }
67 #else
// Printing disabled: the macros are function-like and expand to nothing, so
// call sites compile unchanged and their arguments are neither evaluated nor
// left behind as a stray comma expression.
#define PRINT_B0_T0(msg, ...)
#define PRINT_T0(msg, ...)
#define PRINT_TX_LX(msg, ...)
70 #endif
71
// Minimal non-owning (pointer, length) string view returned by
// __get_type_name(). `data` is NOT guaranteed to be NUL-terminated at
// `data[size]`: in the C++14 path it points into the middle of the
// compiler's __PRETTY_FUNCTION__ string, so printing it with a plain "%s"
// also prints whatever follows the type name in that string.
struct __string_view {
  char const* data;
  std::size_t size;
};
#if __cplusplus >= 201402L
// Returns the spelling of type T as it appears in __PRETTY_FUNCTION__.
//
// The signature string is expected to look like "... [with T = <type>]" or
// "... [T = <type>]": the name starts after the first '=' (leading spaces
// skipped) and ends at the ']' that closes the enclosing bracket group;
// nested "[...]" pairs inside the type are kept. If the string does not have
// that shape, {"unknown", 7} is returned instead of scanning past the
// terminating NUL.
template <class T>
constexpr __string_view __get_type_name() {
  char const* p = __PRETTY_FUNCTION__;
  // Locate the first '=' without running off the end of the string.
  while (*p != '\0' && *p != '=') {
    ++p;
  }
  if (*p == '\0') {
    return {"unknown", 7};
  }
  ++p; // step over '='
  while (*p == ' ') {
    ++p;
  }
  // We are already inside one '[' (the one that opens "[with T = ...").
  int depth = 1;
  for (char const* end = p; *end != '\0'; ++end) {
    if (*end == '[') {
      ++depth;
    } else if (*end == ']') {
      --depth;
      if (depth == 0) {
        return {p, std::size_t(end - p)};
      }
    }
  }
  // No matching ']' found.
  return {"unknown", 7};
}
#else
// Pre-C++14 fallback: the parsing loop above is not allowed in a constexpr
// function, so report a fixed placeholder instead.
template <class T>
constexpr __string_view __get_type_name() {
  return {"unsupported", 11};
}
#endif
105
// Print 8 consecutive elements of an indexable object.
//
// PRINT_ACCUM8_T0_L0_START(name, accum, start) prints
//   "<name>[start:start+8] - {e0, ..., e7}"
// via PRINT_B0_T0, i.e. only from the selected thread of block (0, 0, 0).
// Each element is converted to float before printing. The macro always reads
// accum[start] .. accum[start + 7]; the caller must make sure those 8
// elements exist.
//
// `accum` and `start` are parenthesized in the expansion so that expression
// arguments (e.g. `*ptr`, `cond ? a : b`) index the intended element.
// The definition ends with ';' (kept for existing call sites), so a trailing
// semicolon at the call site is optional.
#define PRINT_ACCUM8_T0_L0_START(name, accum, start)  \
  PRINT_B0_T0(                                        \
      "%s[%d:%d] - {%f, %f, %f, %f, %f, %f, %f, %f}", \
      name,                                           \
      int(start),                                     \
      int((start) + 8),                               \
      float((accum)[(start) + 0]),                    \
      float((accum)[(start) + 1]),                    \
      float((accum)[(start) + 2]),                    \
      float((accum)[(start) + 3]),                    \
      float((accum)[(start) + 4]),                    \
      float((accum)[(start) + 5]),                    \
      float((accum)[(start) + 6]),                    \
      float((accum)[(start) + 7]));
// Convenience form: print elements [0, 8).
#define PRINT_ACCUM8_T0_L0(name, accum) PRINT_ACCUM8_T0_L0_START(name, accum, 0)
// Print a whole fragment (anything exposing `size()` and `operator[]`).
//
// Output, emitted through PRINT_B0_T0 (selected thread of block (0, 0, 0)):
//   "printing <name> (<type of frag>)"
// followed by one line per group of 8 elements. When `size()` is not a
// multiple of 8, the remaining elements are printed one per line instead of
// reading past the end of the fragment.
//
// The type string comes from __get_type_name(), whose `data` pointer is not
// NUL-terminated at the end of the type name, so "%s" may print a few extra
// characters from the compiler's signature string.
//
// `frag` is evaluated several times: do not pass an expression with side
// effects.
#define PRINT_FRAG_T0_L0(name, frag)                                    \
  {                                                                     \
    auto _frag_type = __get_type_name<decltype(frag)>();                \
    int const _frag_len = int((frag).size());                           \
    PRINT_B0_T0("printing %s (%s)", name, _frag_type.data);             \
    for (int _start = 0; _start < _frag_len; _start += 8) {             \
      if (_start + 8 <= _frag_len) {                                    \
        PRINT_ACCUM8_T0_L0_START(" ", (frag), _start);                  \
      } else {                                                          \
        /* Partial tail: print element by element to stay in bounds. */ \
        for (int _k = _start; _k < _frag_len; ++_k) {                   \
          PRINT_B0_T0(" [%d] - %f", _k, float((frag)[_k]));             \
        }                                                               \
      }                                                                 \
    }                                                                   \
    /* __syncthreads(); NANCHECK(frag); */                              \
  }
// Print `length` elements of an indexable `array`.
//
// Output, emitted through PRINT_B0_T0 (selected thread of block (0, 0, 0)):
//   "printing <name> (len=<length>)"
// then, for start = 0, incr, 2*incr, ... while start < length, one line with
// the 8 elements [start, start + 8). When fewer than 8 elements remain, they
// are printed one per line instead of reading past `length`.
//
// `incr` must be positive, otherwise the loop never terminates. With
// incr != 8 the 8-element windows overlap or leave gaps; use
// PRINT_ARRAY_T0_L0 for a plain contiguous dump.
// `array`, `length` and `incr` are evaluated several times.
#define PRINT_ARRAY_T0_L0_INCR(name, array, length, incr)               \
  {                                                                     \
    int const _array_len = int(length);                                 \
    PRINT_B0_T0("printing %s (len=%d)", name, _array_len);              \
    for (int _start = 0; _start < _array_len; _start += int(incr)) {    \
      if (_start + 8 <= _array_len) {                                   \
        PRINT_ACCUM8_T0_L0_START(" ", (array), _start);                 \
      } else {                                                          \
        /* Partial tail: print element by element to stay in bounds. */ \
        for (int _k = _start; _k < _array_len; ++_k) {                  \
          PRINT_B0_T0(" [%d] - %f", _k, float((array)[_k]));            \
        }                                                               \
      }                                                                 \
    }                                                                   \
  }
// Contiguous dump: 8 elements per line.
#define PRINT_ARRAY_T0_L0(name, array, length) \
  PRINT_ARRAY_T0_L0_INCR(name, array, length, 8)
141
// Print a 4x4 window of a tensor reference.
//
// `ref` must provide `at({i, j})` returning something convertible to float.
// The window covers first coordinate start_x .. start_x + 3 and second
// coordinate start_y .. start_y + 3; one output row per value of the first
// coordinate. Printing goes through PRINT_B0_T0 (selected thread of block
// (0, 0, 0)). All 16 elements are read unconditionally, so the window must
// lie inside the tensor.
//
// `ref`, `start_x` and `start_y` are parenthesized in the expansion so that
// expression arguments are indexed as intended. The definition ends with ';'
// (kept for existing call sites).
#define PRINT_TENSOR4x4_T0_L0_START(name, ref, start_x, start_y)                                 \
  PRINT_B0_T0(                                                                                   \
      "%s[%d:%d, %d:%d]:\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, %f, %f, %f", \
      name,                                                                                      \
      int(start_x),                                                                              \
      int((start_x) + 4),                                                                        \
      int(start_y),                                                                              \
      int((start_y) + 4),                                                                        \
      float((ref).at({(start_x) + 0, (start_y) + 0})),                                           \
      float((ref).at({(start_x) + 0, (start_y) + 1})),                                           \
      float((ref).at({(start_x) + 0, (start_y) + 2})),                                           \
      float((ref).at({(start_x) + 0, (start_y) + 3})),                                           \
      float((ref).at({(start_x) + 1, (start_y) + 0})),                                           \
      float((ref).at({(start_x) + 1, (start_y) + 1})),                                           \
      float((ref).at({(start_x) + 1, (start_y) + 2})),                                           \
      float((ref).at({(start_x) + 1, (start_y) + 3})),                                           \
      float((ref).at({(start_x) + 2, (start_y) + 0})),                                           \
      float((ref).at({(start_x) + 2, (start_y) + 1})),                                           \
      float((ref).at({(start_x) + 2, (start_y) + 2})),                                           \
      float((ref).at({(start_x) + 2, (start_y) + 3})),                                           \
      float((ref).at({(start_x) + 3, (start_y) + 0})),                                           \
      float((ref).at({(start_x) + 3, (start_y) + 1})),                                           \
      float((ref).at({(start_x) + 3, (start_y) + 2})),                                           \
      float((ref).at({(start_x) + 3, (start_y) + 3})));
// Convenience form: print the window starting at (0, 0).
#define PRINT_TENSOR4x4_T0_L0(name, ref) \
  PRINT_TENSOR4x4_T0_L0_START(name, ref, 0, 0)
169
// Print "<name>.problem_size: {.m=.., .n=.., .k=..}" through PRINT_B0_T0
// (selected thread of block (0, 0, 0)). `ps` must provide m(), n() and k()
// accessors whose results are convertible to int; it is parenthesized in the
// expansion so that expressions such as `*ptr` work as arguments.
// Unlike the other print helpers, this definition has no trailing ';'.
#define PRINT_PROBLEM_SIZE(name, ps)            \
  PRINT_B0_T0(                                  \
      "%s.problem_size: {.m=%d, .n=%d, .k=%d}", \
      name,                                     \
      int((ps).m()),                            \
      int((ps).n()),                            \
      int((ps).k()))
177
// Dump a num_rows x num_cols accumulator tile to stdout, one matrix row per
// output line, with a "mat[row, col:col+32]" label emitted at every 32nd
// column by thread (0, 0, 0) of block (0, 0, 0).
//
// For each (row, col) pair, every thread calls LambdaIterator::iterateRows
// with `lane_offset`; the element callback prints accum[idx] (as float, via
// " %6.1f") when the reported (accum_m, accum_n) equals the current
// (row, col) and the thread belongs to block (0, 0, 0).
// NOTE(review): presumably iterateRows visits each accumulator element held
// by the calling thread exactly once -- confirm against LambdaIterator.
//
// The function contains __syncthreads() barriers, so it must be called by all
// threads of a block with identical num_rows / num_cols. It performs
// num_rows * num_cols iterateRows sweeps and is meant for debugging only.
template <typename LambdaIterator, typename LaneOffsetT, typename AccumT>
CUTLASS_DEVICE void print_warp_accum(
    AccumT accum,
    LaneOffsetT lane_offset,
    int32_t num_rows,
    int32_t num_cols) {
  const bool in_first_block =
      blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0;
  const bool is_first_thread = in_first_block && threadIdx.x == 0 &&
      threadIdx.y == 0 && threadIdx.z == 0;

  for (int r = 0; r < num_rows; ++r) {
    for (int c = 0; c < num_cols; ++c) {
      // A new label starts every 32 columns.
      const bool starts_chunk = (c % 32 == 0);
      if (starts_chunk && is_first_thread) {
        printf("\nmat[%3d, %3d:%3d]", r, c, c + 32);
      }
      if (starts_chunk) {
        // Barrier after the label print, before any value of this chunk.
        __syncthreads();
      }
      LambdaIterator::iterateRows(
          lane_offset,
          [&](int /*accum_m*/) {},
          [&](int accum_m, int accum_n, int idx) {
            // Only the owner of element (r, c) in block (0, 0, 0) prints.
            if (in_first_block && accum_m == r && accum_n == c) {
              printf(" %6.1f", float(accum[idx]));
            }
          },
          [&](int /*accum_m*/) {});
      // One barrier per element, before moving on to the next column.
      __syncthreads();
    }
    if (is_first_thread) {
      printf("\n");
    }
  }
}
211