/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PJRT_TRANSPOSE_KERNELS_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_TRANSPOSE_KERNELS_H_

#include <array>
#include <cstdint>

#include "third_party/eigen3/Eigen/Core"

namespace xla {

// Generic transpose kernel.
//
// All of the kernels that follow in this file are optimized versions of this
// generic kernel, specialized to particular block sizes and data types.
//
// The transpose kernel requires its input to be contiguous in one of the two
// dimensions being transposed, and the output to be contiguous in the other
// dimension.
//
// lda and ldb are strides in bytes.
template <typename T, int bs>
struct TransposeMicroKernel {
  static void Apply(const char* __restrict a, int64_t lda, char* __restrict b,
                    int64_t ldb) {
    for (int i = 0; i < bs; ++i) {
      for (int j = 0; j < bs; ++j) {
        *reinterpret_cast<T*>(b + i * ldb + j * sizeof(T)) =
            *reinterpret_cast<const T*>(a + j * lda + i * sizeof(T));
      }
    }
  }
};

// TODO(phawkins): it would be nice to remove the use of Eigen here, and
// instead allow for runtime dispatch of, say, AVX or AVX2 kernels where they
// are supported. On the other hand, using Eigen makes for easier
// cross-platform portability.
#ifdef EIGEN_VECTORIZE_AVX

template <>
struct TransposeMicroKernel<uint8_t, /*bs=*/4> {
  static void Apply(const char* __restrict a, int64_t lda, char* __restrict b,
                    int64_t ldb) {
    // Gather the four 4-byte rows into one 128-bit register, then use a byte
    // shuffle to rearrange them into the four transposed columns.
    __m128i x = _mm_set_epi32(*reinterpret_cast<const uint32_t*>(a + lda * 0),
                              *reinterpret_cast<const uint32_t*>(a + lda * 1),
                              *reinterpret_cast<const uint32_t*>(a + lda * 2),
                              *reinterpret_cast<const uint32_t*>(a + lda * 3));
    __m128i mask =
        _mm_setr_epi8(12, 8, 4, 0, 13, 9, 5, 1, 14, 10, 6, 2, 15, 11, 7, 3);
    x = _mm_shuffle_epi8(x, mask);
    *reinterpret_cast<uint32_t*>(b + ldb * 0) = _mm_extract_epi32(x, 0);
    *reinterpret_cast<uint32_t*>(b + ldb * 1) = _mm_extract_epi32(x, 1);
    *reinterpret_cast<uint32_t*>(b + ldb * 2) = _mm_extract_epi32(x, 2);
    *reinterpret_cast<uint32_t*>(b + ldb * 3) = _mm_extract_epi32(x, 3);
  }
};
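
// A minimal usage sketch of the kernel above (illustrative only; `src`, `dst`,
// and the 16x16 dimensions are hypothetical, not part of this header).
// Apply() transposes one bs x bs tile; lda and ldb are the byte strides of
// the surrounding source and destination buffers.
//
//   uint8_t src[16][16];  // source rows are contiguous
//   uint8_t dst[16][16];  // destination rows are contiguous
//   // Writes dst[i][j] = src[j][i] for 0 <= i, j < 4.
//   TransposeMicroKernel<uint8_t, /*bs=*/4>::Apply(
//       reinterpret_cast<const char*>(&src[0][0]), /*lda=*/16,
//       reinterpret_cast<char*>(&dst[0][0]), /*ldb=*/16);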

// TODO(phawkins): add an 8x8 byte transpose kernel.

// TODO(phawkins): Eigen doesn't have an SSE/AVX byte Packet16c type. Add one
// and call it here rather than using AVX intrinsics.
template <>
struct TransposeMicroKernel<uint8_t, /*bs=*/16> {
  static void Apply(const char* __restrict a, int64_t lda, char* __restrict b,
                    int64_t ldb) {
    std::array<__m128i, 16> packet;
    for (int i = 0; i < 16; ++i) {
      packet[i] =
          _mm_loadu_si128(reinterpret_cast<const __m128i*>(a + lda * i));
    }

    // If we number the elements in the input thus:
    // packet[ 0] = {00, 01, 02, 03, 04, 05, 06, 07, 08, 09, 0a, 0b, 0c,
    //               0d, 0e, 0f}
    // packet[ 1] = {10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1a, 1b, 1c,
    //               1d, 1e, 1f}
    // ...
    // packet[15] = {f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, fa, fb, fc,
    //               fd, fe, ff},
    //
    // the desired output is:
    // packet[ 0] = {00, 10, 20, 30, 40, 50, 60, 70, 80, 90, a0, b0, c0,
    //               d0, e0, f0}
    // packet[ 1] = {01, 11, 21, 31, 41, 51, 61, 71, 81, 91, a1, b1, c1,
    //               d1, e1, f1}
    // ...
    // packet[15] = {0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, af, bf, cf,
    //               df, ef, ff}.

    // Stage 1: interleave adjacent pairs of rows at byte granularity.
    // 00 10 01 11 02 12 03 13 04 14 05 15 06 16 07 17
    __m128i t0 = _mm_unpacklo_epi8(packet[0], packet[1]);
    // 08 18 09 19 0a 1a 0b 1b 0c 1c 0d 1d 0e 1e 0f 1f
    __m128i t1 = _mm_unpackhi_epi8(packet[0], packet[1]);
    // 20 30 21 31 22 32 ... 27 37
    __m128i t2 = _mm_unpacklo_epi8(packet[2], packet[3]);
    // 28 38 29 39 2a 3a ... 2f 3f
    __m128i t3 = _mm_unpackhi_epi8(packet[2], packet[3]);
    // 40 50 41 51 42 52 ... 47 57
    __m128i t4 = _mm_unpacklo_epi8(packet[4], packet[5]);
    // 48 58 49 59 4a 5a ...
    __m128i t5 = _mm_unpackhi_epi8(packet[4], packet[5]);
    __m128i t6 = _mm_unpacklo_epi8(packet[6], packet[7]);
    __m128i t7 = _mm_unpackhi_epi8(packet[6], packet[7]);
    __m128i t8 = _mm_unpacklo_epi8(packet[8], packet[9]);
    __m128i t9 = _mm_unpackhi_epi8(packet[8], packet[9]);
    __m128i ta = _mm_unpacklo_epi8(packet[10], packet[11]);
    __m128i tb = _mm_unpackhi_epi8(packet[10], packet[11]);
    __m128i tc = _mm_unpacklo_epi8(packet[12], packet[13]);
    __m128i td = _mm_unpackhi_epi8(packet[12], packet[13]);
    __m128i te = _mm_unpacklo_epi8(packet[14], packet[15]);
    __m128i tf = _mm_unpackhi_epi8(packet[14], packet[15]);

    // Stage 2: interleave the two-row groups at 16-bit granularity to form
    // four-row groups.
    // 00 10 20 30 01 11 21 31 02 12 22 32 03 13 23 33
    __m128i s0 = _mm_unpacklo_epi16(t0, t2);
    __m128i s1 = _mm_unpackhi_epi16(t0, t2);  // 04 14 24 34 ...
    __m128i s2 = _mm_unpacklo_epi16(t1, t3);  // 08 18 28 38 ...
    __m128i s3 = _mm_unpackhi_epi16(t1, t3);  // 0c 1c 2c 3c ...
    // 40 50 60 70 41 51 61 71 42 52 62 72 43 53 63 73
    __m128i s4 = _mm_unpacklo_epi16(t4, t6);
    __m128i s5 = _mm_unpackhi_epi16(t4, t6);  // 44 54 64 74 ...
    __m128i s6 = _mm_unpacklo_epi16(t5, t7);
    __m128i s7 = _mm_unpackhi_epi16(t5, t7);
    __m128i s8 = _mm_unpacklo_epi16(t8, ta);
    __m128i s9 = _mm_unpackhi_epi16(t8, ta);
    __m128i sa = _mm_unpacklo_epi16(t9, tb);
    __m128i sb = _mm_unpackhi_epi16(t9, tb);
    __m128i sc = _mm_unpacklo_epi16(tc, te);
    __m128i sd = _mm_unpackhi_epi16(tc, te);
    __m128i se = _mm_unpacklo_epi16(td, tf);
    __m128i sf = _mm_unpackhi_epi16(td, tf);

    // Stage 3: interleave the four-row groups at 32-bit granularity to form
    // eight-row groups.
    // 00 10 20 30 40 50 60 70 01 11 21 31 41 51 61 71
    __m128i u0 = _mm_unpacklo_epi32(s0, s4);
    // 02 12 22 32 42 52 62 72 03 13 23 33 43 53 63 73
    __m128i u1 = _mm_unpackhi_epi32(s0, s4);
    __m128i u2 = _mm_unpacklo_epi32(s1, s5);
    __m128i u3 = _mm_unpackhi_epi32(s1, s5);
    __m128i u4 = _mm_unpacklo_epi32(s2, s6);
    __m128i u5 = _mm_unpackhi_epi32(s2, s6);
    __m128i u6 = _mm_unpacklo_epi32(s3, s7);
    __m128i u7 = _mm_unpackhi_epi32(s3, s7);
    __m128i u8 = _mm_unpacklo_epi32(s8, sc);
    __m128i u9 = _mm_unpackhi_epi32(s8, sc);
    __m128i ua = _mm_unpacklo_epi32(s9, sd);
    __m128i ub = _mm_unpackhi_epi32(s9, sd);
    __m128i uc = _mm_unpacklo_epi32(sa, se);
    __m128i ud = _mm_unpackhi_epi32(sa, se);
    __m128i ue = _mm_unpacklo_epi32(sb, sf);
    __m128i uf = _mm_unpackhi_epi32(sb, sf);

    // Stage 4: interleave the eight-row groups at 64-bit granularity,
    // yielding the transposed rows.
    packet[0] = _mm_unpacklo_epi64(u0, u8);
    packet[1] = _mm_unpackhi_epi64(u0, u8);
    packet[2] = _mm_unpacklo_epi64(u1, u9);
    packet[3] = _mm_unpackhi_epi64(u1, u9);
    packet[4] = _mm_unpacklo_epi64(u2, ua);
    packet[5] = _mm_unpackhi_epi64(u2, ua);
    packet[6] = _mm_unpacklo_epi64(u3, ub);
    packet[7] = _mm_unpackhi_epi64(u3, ub);
    packet[8] = _mm_unpacklo_epi64(u4, uc);
    packet[9] = _mm_unpackhi_epi64(u4, uc);
    packet[10] = _mm_unpacklo_epi64(u5, ud);
    packet[11] = _mm_unpackhi_epi64(u5, ud);
    packet[12] = _mm_unpacklo_epi64(u6, ue);
    packet[13] = _mm_unpackhi_epi64(u6, ue);
    packet[14] = _mm_unpacklo_epi64(u7, uf);
    packet[15] = _mm_unpackhi_epi64(u7, uf);

    for (int i = 0; i < 16; ++i) {
      _mm_storeu_si128(reinterpret_cast<__m128i*>(b + ldb * i), packet[i]);
    }
  }
};

// TODO(phawkins): add a 4x4 uint16_t transpose kernel.
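
// The remaining kernels reuse Eigen's packet transpose (ptranspose) on packets
// whose lanes match the width of the element type: uint16_t data is moved as
// Packet8h (half), uint32_t as Packet4f/Packet8f (float), and uint64_t as
// Packet2d/Packet4d (double). This is safe because ptranspose only rearranges
// lanes, so the bit patterns pass through unchanged even though the integer
// data is viewed as floating point.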

template <>
struct TransposeMicroKernel<uint16_t, /*bs=*/8> {
  static void Apply(const char* __restrict a, int64_t lda, char* __restrict b,
                    int64_t ldb) {
    using Eigen::internal::Packet8h;
    using Eigen::internal::PacketBlock;
    constexpr int bs = 8;
    PacketBlock<Packet8h, bs> block;
    for (int i = 0; i < bs; ++i) {
      block.packet[i] = Eigen::internal::ploadu<Packet8h>(
          reinterpret_cast<const Eigen::half*>(a + lda * i));
    }
    Eigen::internal::ptranspose(block);
    for (int i = 0; i < bs; ++i) {
      Eigen::internal::pstoreu<Eigen::half>(
          reinterpret_cast<Eigen::half*>(b + ldb * i), block.packet[i]);
    }
  }
};

template <>
struct TransposeMicroKernel<uint32_t, /*bs=*/4> {
  static void Apply(const char* __restrict a, int64_t lda, char* __restrict b,
                    int64_t ldb) {
    using Eigen::internal::Packet4f;
    using Eigen::internal::PacketBlock;
    constexpr int bs = 4;
    PacketBlock<Packet4f, bs> block;
    for (int i = 0; i < bs; ++i) {
      block.packet[i] = Eigen::internal::ploadu<Packet4f>(
          reinterpret_cast<const float*>(a + lda * i));
    }
    Eigen::internal::ptranspose(block);
    for (int i = 0; i < bs; ++i) {
      Eigen::internal::pstoreu<float>(reinterpret_cast<float*>(b + ldb * i),
                                      block.packet[i]);
    }
  }
};

template <>
struct TransposeMicroKernel<uint32_t, /*bs=*/8> {
  static void Apply(const char* __restrict a, int64_t lda, char* __restrict b,
                    int64_t ldb) {
    using Eigen::internal::Packet8f;
    using Eigen::internal::PacketBlock;
    constexpr int bs = 8;
    PacketBlock<Packet8f, bs> block;
    for (int i = 0; i < bs; ++i) {
      block.packet[i] = Eigen::internal::ploadu<Packet8f>(
          reinterpret_cast<const float*>(a + lda * i));
    }
    Eigen::internal::ptranspose(block);
    for (int i = 0; i < bs; ++i) {
      Eigen::internal::pstoreu<float>(reinterpret_cast<float*>(b + ldb * i),
                                      block.packet[i]);
    }
  }
};

template <>
struct TransposeMicroKernel<uint64_t, /*bs=*/2> {
  static void Apply(const char* __restrict a, int64_t lda, char* __restrict b,
                    int64_t ldb) {
    using Eigen::internal::Packet2d;
    using Eigen::internal::PacketBlock;
    constexpr int bs = 2;
    PacketBlock<Packet2d, bs> block;
    for (int i = 0; i < bs; ++i) {
      block.packet[i] = Eigen::internal::ploadu<Packet2d>(
          reinterpret_cast<const double*>(a + lda * i));
    }
    Eigen::internal::ptranspose(block);
    for (int i = 0; i < bs; ++i) {
      Eigen::internal::pstoreu<double>(reinterpret_cast<double*>(b + ldb * i),
                                       block.packet[i]);
    }
  }
};

template <>
struct TransposeMicroKernel<uint64_t, /*bs=*/4> {
  static void Apply(const char* __restrict a, int64_t lda, char* __restrict b,
                    int64_t ldb) {
    using Eigen::internal::Packet4d;
    using Eigen::internal::PacketBlock;
    constexpr int bs = 4;
    PacketBlock<Packet4d, bs> block;
    for (int i = 0; i < bs; ++i) {
      block.packet[i] = Eigen::internal::ploadu<Packet4d>(
          reinterpret_cast<const double*>(a + lda * i));
    }
    Eigen::internal::ptranspose(block);
    for (int i = 0; i < bs; ++i) {
      Eigen::internal::pstoreu<double>(reinterpret_cast<double*>(b + ldb * i),
                                       block.packet[i]);
    }
  }
};

#endif  // EIGEN_VECTORIZE_AVX

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_PJRT_TRANSPOSE_KERNELS_H_