/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PJRT_TRANSPOSE_KERNELS_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_TRANSPOSE_KERNELS_H_

#include <array>
#include <cstdint>

#include "third_party/eigen3/Eigen/Core"

namespace xla {

// Generic transpose kernel.
//
// All of the kernels that follow in this file are optimized versions of this
// generic kernel, specialized to particular block sizes and data types.
//
// The transpose kernel requires its input to be contiguous in one of the two
// dimensions being transposed, and the output to be contiguous in the other
// dimension.
//
// lda, ldb are strides in bytes.
template <typename T, int bs>
struct TransposeMicroKernel {
  static void Apply(const char* __restrict a, int64_t lda, char* __restrict b,
                    int64_t ldb) {
    for (int i = 0; i < bs; ++i) {
      for (int j = 0; j < bs; ++j) {
        *reinterpret_cast<T*>(b + i * ldb + j * sizeof(T)) =
            *reinterpret_cast<T const*>(a + j * lda + i * sizeof(T));
      }
    }
  }
};
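
// Example usage (an illustrative sketch, not part of the kernel API):
// transposing a 4x4 block of floats held in row-major buffers. The strides are
// given in bytes, so for a contiguous 4x4 float matrix both strides are
// 4 * sizeof(float).
//
//   float src[4][4] = {...};
//   float dst[4][4];
//   TransposeMicroKernel<float, /*bs=*/4>::Apply(
//       reinterpret_cast<const char*>(&src[0][0]), /*lda=*/4 * sizeof(float),
//       reinterpret_cast<char*>(&dst[0][0]), /*ldb=*/4 * sizeof(float));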

// TODO(phawkins): it would be nice to remove the use of Eigen here, and instead
// allow for runtime dispatch of, say, AVX or AVX2 kernels where they are
// supported. On the other hand, using Eigen makes for easier cross-platform
// portability.
#ifdef EIGEN_VECTORIZE_AVX

template <>
struct TransposeMicroKernel<uint8_t, /*bs=*/4> {
  static void Apply(const char* __restrict a, int64_t lda, char* __restrict b,
                    int64_t ldb) {
    // Pack the four source rows into one vector. _mm_set_epi32 places its
    // first argument in the most significant lane, so row 0 occupies bytes
    // 12..15 and row 3 occupies bytes 0..3.
    __m128i x = _mm_set_epi32(*reinterpret_cast<const uint32_t*>(a + lda * 0),
                              *reinterpret_cast<const uint32_t*>(a + lda * 1),
                              *reinterpret_cast<const uint32_t*>(a + lda * 2),
                              *reinterpret_cast<const uint32_t*>(a + lda * 3));
    // The shuffle gathers one source column into each 32-bit output lane, so
    // after _mm_shuffle_epi8 lane i of `x` holds transposed row i.
    __m128i mask =
        _mm_setr_epi8(12, 8, 4, 0, 13, 9, 5, 1, 14, 10, 6, 2, 15, 11, 7, 3);
    x = _mm_shuffle_epi8(x, mask);
    *reinterpret_cast<uint32_t*>(b + ldb * 0) = _mm_extract_epi32(x, 0);
    *reinterpret_cast<uint32_t*>(b + ldb * 1) = _mm_extract_epi32(x, 1);
    *reinterpret_cast<uint32_t*>(b + ldb * 2) = _mm_extract_epi32(x, 2);
    *reinterpret_cast<uint32_t*>(b + ldb * 3) = _mm_extract_epi32(x, 3);
  }
};

// TODO(phawkins): add an 8x8 byte transpose kernel.

// TODO(phawkins): Eigen doesn't have a SSE/AVX byte Packet16c type. Add one
// and call it here rather than using AVX intrinsics.
template <>
struct TransposeMicroKernel<uint8_t, /*bs=*/16> {
  static void Apply(const char* __restrict a, int64_t lda, char* __restrict b,
                    int64_t ldb) {
    std::array<__m128i, 16> packet;
    for (int i = 0; i < 16; ++i) {
      packet[i] =
          _mm_loadu_si128(reinterpret_cast<const __m128i*>(a + lda * i));
    }

    // If we number the elements in the input thus:
    // packet[ 0] = {00, 01, 02, 03, 04, 05, 06, 07, 08, 09, 0a, 0b, 0c,
    //               0d, 0e, 0f}
    // packet[ 1] = {10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1a, 1b, 1c,
    //               1d, 1e, 1f}
    // ...
    // packet[15] = {f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, fa, fb, fc,
    //               fd, fe, ff},
    //
    // the desired output is:
    // packet[ 0] = {00, 10, 20, 30, 40, 50, 60, 70, 80, 90, a0, b0, c0,
    //               d0, e0, f0}
    // packet[ 1] = {01, 11, 21, 31, 41, 51, 61, 71, 81, 91, a1, b1, c1,
    //               d1, e1, f1}
    // ...
    // packet[15] = {0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, af, bf, cf,
    //               df, ef, ff}.
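    //
    // The transpose proceeds in four rounds of unpack instructions, at 8-,
    // 16-, 32-, and 64-bit granularity; each round doubles the length of the
    // transposed runs until every vector holds one full output row.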
    // 00 10 01 11 02 12 03 13 04 14 05 15 06 16 07 17
    __m128i t0 = _mm_unpacklo_epi8(packet[0], packet[1]);
    // 08 18 09 19 0a 1a 0b 1b 0c 1c 0d 1d 0e 1e 0f 1f
    __m128i t1 = _mm_unpackhi_epi8(packet[0], packet[1]);
    // 20 30 21 31 22 32 ...                     27 37
    __m128i t2 = _mm_unpacklo_epi8(packet[2], packet[3]);
    // 28 38 29 39 2a 3a ...                     2f 3f
    __m128i t3 = _mm_unpackhi_epi8(packet[2], packet[3]);
    // 40 50 41 51 42 52 ...                     47 57
    __m128i t4 = _mm_unpacklo_epi8(packet[4], packet[5]);
    // 48 58 49 59 4a 5a ...                     4f 5f
    __m128i t5 = _mm_unpackhi_epi8(packet[4], packet[5]);
    __m128i t6 = _mm_unpacklo_epi8(packet[6], packet[7]);
    __m128i t7 = _mm_unpackhi_epi8(packet[6], packet[7]);
    __m128i t8 = _mm_unpacklo_epi8(packet[8], packet[9]);
    __m128i t9 = _mm_unpackhi_epi8(packet[8], packet[9]);
    __m128i ta = _mm_unpacklo_epi8(packet[10], packet[11]);
    __m128i tb = _mm_unpackhi_epi8(packet[10], packet[11]);
    __m128i tc = _mm_unpacklo_epi8(packet[12], packet[13]);
    __m128i td = _mm_unpackhi_epi8(packet[12], packet[13]);
    __m128i te = _mm_unpacklo_epi8(packet[14], packet[15]);
    __m128i tf = _mm_unpackhi_epi8(packet[14], packet[15]);

    // 00 10 20 30 01 11 21 31 02 12 22 32 03 13 23 33
    __m128i s0 = _mm_unpacklo_epi16(t0, t2);
    __m128i s1 = _mm_unpackhi_epi16(t0, t2);  // 04 14 24 34 ...
    __m128i s2 = _mm_unpacklo_epi16(t1, t3);  // 08 18 28 38 ...
    __m128i s3 = _mm_unpackhi_epi16(t1, t3);  // 0c 1c 2c 3c ...
    // 40 50 60 70 41 51 61 71 42 52 62 72 43 53 63 73
    __m128i s4 = _mm_unpacklo_epi16(t4, t6);
    __m128i s5 = _mm_unpackhi_epi16(t4, t6);  // 44 54 64 74 ...
    __m128i s6 = _mm_unpacklo_epi16(t5, t7);
    __m128i s7 = _mm_unpackhi_epi16(t5, t7);
    __m128i s8 = _mm_unpacklo_epi16(t8, ta);
    __m128i s9 = _mm_unpackhi_epi16(t8, ta);
    __m128i sa = _mm_unpacklo_epi16(t9, tb);
    __m128i sb = _mm_unpackhi_epi16(t9, tb);
    __m128i sc = _mm_unpacklo_epi16(tc, te);
    __m128i sd = _mm_unpackhi_epi16(tc, te);
    __m128i se = _mm_unpacklo_epi16(td, tf);
    __m128i sf = _mm_unpackhi_epi16(td, tf);

    // 00 10 20 30 40 50 60 70 01 11 21 31 41 51 61 71
    __m128i u0 = _mm_unpacklo_epi32(s0, s4);
    // 02 12 22 32 42 52 62 72 03 13 23 33 43 53 63 73
    __m128i u1 = _mm_unpackhi_epi32(s0, s4);
    __m128i u2 = _mm_unpacklo_epi32(s1, s5);
    __m128i u3 = _mm_unpackhi_epi32(s1, s5);
    __m128i u4 = _mm_unpacklo_epi32(s2, s6);
    __m128i u5 = _mm_unpackhi_epi32(s2, s6);
    __m128i u6 = _mm_unpacklo_epi32(s3, s7);
    __m128i u7 = _mm_unpackhi_epi32(s3, s7);
    __m128i u8 = _mm_unpacklo_epi32(s8, sc);
    __m128i u9 = _mm_unpackhi_epi32(s8, sc);
    __m128i ua = _mm_unpacklo_epi32(s9, sd);
    __m128i ub = _mm_unpackhi_epi32(s9, sd);
    __m128i uc = _mm_unpacklo_epi32(sa, se);
    __m128i ud = _mm_unpackhi_epi32(sa, se);
    __m128i ue = _mm_unpacklo_epi32(sb, sf);
    __m128i uf = _mm_unpackhi_epi32(sb, sf);

    // 00 10 20 30 40 50 60 70 80 90 a0 b0 c0 d0 e0 f0
    packet[0] = _mm_unpacklo_epi64(u0, u8);
    packet[1] = _mm_unpackhi_epi64(u0, u8);
    packet[2] = _mm_unpacklo_epi64(u1, u9);
    packet[3] = _mm_unpackhi_epi64(u1, u9);
    packet[4] = _mm_unpacklo_epi64(u2, ua);
    packet[5] = _mm_unpackhi_epi64(u2, ua);
    packet[6] = _mm_unpacklo_epi64(u3, ub);
    packet[7] = _mm_unpackhi_epi64(u3, ub);
    packet[8] = _mm_unpacklo_epi64(u4, uc);
    packet[9] = _mm_unpackhi_epi64(u4, uc);
    packet[10] = _mm_unpacklo_epi64(u5, ud);
    packet[11] = _mm_unpackhi_epi64(u5, ud);
    packet[12] = _mm_unpacklo_epi64(u6, ue);
    packet[13] = _mm_unpackhi_epi64(u6, ue);
    packet[14] = _mm_unpacklo_epi64(u7, uf);
    packet[15] = _mm_unpackhi_epi64(u7, uf);
    for (int i = 0; i < 16; ++i) {
      _mm_storeu_si128(reinterpret_cast<__m128i*>(b + ldb * i), packet[i]);
    }
  }
};

// TODO(phawkins): add a 4x4 uint16_t transpose kernel.

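// The remaining kernels delegate to Eigen's vectorized packet transpose: load
// `bs` rows into a PacketBlock, transpose it in registers with ptranspose, and
// store the transposed rows back out.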
template <>
struct TransposeMicroKernel<uint16_t, /*bs=*/8> {
  static void Apply(const char* __restrict a, int64_t lda, char* __restrict b,
                    int64_t ldb) {
    using Eigen::internal::Packet8h;
    using Eigen::internal::PacketBlock;
    constexpr int bs = 8;
    PacketBlock<Packet8h, bs> block;
    for (int i = 0; i < bs; ++i) {
      block.packet[i] = Eigen::internal::ploadu<Packet8h>(
          reinterpret_cast<const Eigen::half*>(a + lda * i));
    }
    Eigen::internal::ptranspose(block);
    for (int i = 0; i < bs; ++i) {
      Eigen::internal::pstoreu<Eigen::half>(
          reinterpret_cast<Eigen::half*>(b + ldb * i), block.packet[i]);
    }
  }
};

template <>
struct TransposeMicroKernel<uint32_t, /*bs=*/4> {
  static void Apply(const char* __restrict a, int64_t lda, char* __restrict b,
                    int64_t ldb) {
    using Eigen::internal::Packet4f;
    using Eigen::internal::PacketBlock;
    constexpr int bs = 4;
    PacketBlock<Packet4f, bs> block;
    for (int i = 0; i < bs; ++i) {
      block.packet[i] = Eigen::internal::ploadu<Packet4f>(
          reinterpret_cast<const float*>(a + lda * i));
    }
    Eigen::internal::ptranspose(block);
    for (int i = 0; i < bs; ++i) {
      Eigen::internal::pstoreu<float>(reinterpret_cast<float*>(b + ldb * i),
                                      block.packet[i]);
    }
  }
};

template <>
struct TransposeMicroKernel<uint32_t, /*bs=*/8> {
  static void Apply(const char* __restrict a, int64_t lda, char* __restrict b,
                    int64_t ldb) {
    using Eigen::internal::Packet8f;
    using Eigen::internal::PacketBlock;
    constexpr int bs = 8;
    PacketBlock<Packet8f, bs> block;
    for (int i = 0; i < bs; ++i) {
      block.packet[i] = Eigen::internal::ploadu<Packet8f>(
          reinterpret_cast<const float*>(a + lda * i));
    }
    Eigen::internal::ptranspose(block);
    for (int i = 0; i < bs; ++i) {
      Eigen::internal::pstoreu<float>(reinterpret_cast<float*>(b + ldb * i),
                                      block.packet[i]);
    }
  }
};

template <>
struct TransposeMicroKernel<uint64_t, /*bs=*/2> {
  static void Apply(const char* __restrict a, int64_t lda, char* __restrict b,
                    int64_t ldb) {
    using Eigen::internal::Packet2d;
    using Eigen::internal::PacketBlock;
    constexpr int bs = 2;
    PacketBlock<Packet2d, bs> block;
    for (int i = 0; i < bs; ++i) {
      block.packet[i] = Eigen::internal::ploadu<Packet2d>(
          reinterpret_cast<const double*>(a + lda * i));
    }
    Eigen::internal::ptranspose(block);
    for (int i = 0; i < bs; ++i) {
      Eigen::internal::pstoreu<double>(reinterpret_cast<double*>(b + ldb * i),
                                       block.packet[i]);
    }
  }
};

template <>
struct TransposeMicroKernel<uint64_t, /*bs=*/4> {
  static void Apply(const char* __restrict a, int64_t lda, char* __restrict b,
                    int64_t ldb) {
    using Eigen::internal::Packet4d;
    using Eigen::internal::PacketBlock;
    constexpr int bs = 4;
    PacketBlock<Packet4d, bs> block;
    for (int i = 0; i < bs; ++i) {
      block.packet[i] = Eigen::internal::ploadu<Packet4d>(
          reinterpret_cast<const double*>(a + lda * i));
    }
    Eigen::internal::ptranspose(block);
    for (int i = 0; i < bs; ++i) {
      Eigen::internal::pstoreu<double>(reinterpret_cast<double*>(b + ldb * i),
                                       block.packet[i]);
    }
  }
};

#endif  // EIGEN_VECTORIZE_AVX

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_PJRT_TRANSPOSE_KERNELS_H_