1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/layout_util.h"
17
18 #include <sstream>
19
20 #include "tensorflow/compiler/xla/shape_util.h"
21 #include "tensorflow/compiler/xla/test.h"
22 #include "tensorflow/compiler/xla/test_helpers.h"
23
24 namespace xla {
25 namespace {
26
27 class LayoutUtilTest : public ::testing::Test {
28 protected:
MakeShapeWithLayout(PrimitiveType element_type,absl::Span<const int64_t> dimensions,absl::Span<const int64_t> minor_to_major,absl::Span<const DimLevelType> dim_level_types={})29 Shape MakeShapeWithLayout(
30 PrimitiveType element_type, absl::Span<const int64_t> dimensions,
31 absl::Span<const int64_t> minor_to_major,
32 absl::Span<const DimLevelType> dim_level_types = {}) {
33 Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
34 *shape.mutable_layout() =
35 LayoutUtil::MakeLayout(minor_to_major, dim_level_types);
36 return shape;
37 }
38 };
39
// LayoutsInShapesEqual on tuples compares tuple arity and the layouts of
// corresponding leaves; leaves whose dimensions differ but whose layouts agree
// still compare equal.
TEST_F(LayoutUtilTest, TupleLayoutComparison) {
  const Shape inner =
      ShapeUtil::MakeTupleShape({MakeShapeWithLayout(F32, {2, 3}, {0, 1})});
  // Same layout as `inner`, different dimensions.
  const Shape inner_other_dims =
      ShapeUtil::MakeTupleShape({MakeShapeWithLayout(F32, {2, 2}, {0, 1})});

  const Shape empty_tuple = ShapeUtil::MakeTupleShape({});
  const Shape one_tuple = ShapeUtil::MakeTupleShape({inner});
  const Shape two_tuple = ShapeUtil::MakeTupleShape({inner, inner});

  // The empty tuple matches only itself.
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(empty_tuple, empty_tuple));
  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(empty_tuple, one_tuple));
  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(empty_tuple, two_tuple));
  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(one_tuple, empty_tuple));
  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(two_tuple, empty_tuple));

  // Tuples of different arity do not match.
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(one_tuple, one_tuple));
  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(one_tuple, two_tuple));
  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(two_tuple, one_tuple));

  // Differing leaf dimensions do not matter when the layouts agree.
  const Shape mixed_two_tuple =
      ShapeUtil::MakeTupleShape({inner, inner_other_dims});
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(two_tuple, two_tuple));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(two_tuple, mixed_two_tuple));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(mixed_two_tuple, two_tuple));
}
65
// Copying a dense array layout overwrites an existing destination layout,
// installs one where none exists, and clears it when the source has none.
TEST_F(LayoutUtilTest, CopyLayoutDenseArray) {
  Shape from = MakeShapeWithLayout(F32, {2, 3}, {0, 1});
  Shape to = MakeShapeWithLayout(F32, {2, 3}, {1, 0});

  // Overwrite a different, existing layout.
  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(from, to));
  EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(from, &to));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(from, to));

  // Install a layout into a destination that has none.
  to.clear_layout();
  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(from, to));
  EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(from, &to));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(from, to));

  // A layout-less source must clear the destination's layout.
  from.clear_layout();
  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(from, to));
  EXPECT_TRUE(to.has_layout());
  EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(from, &to));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(from, to));
  EXPECT_FALSE(to.has_layout());
}
88
// Copying a sparse (CSR) layout carries its dim-level types along, and a
// layout-less source clears whatever sparse layout the destination had.
TEST_F(LayoutUtilTest, CopyLayoutCSRArray) {
  // Source layout with dense/compressed dim-level types; the checks below
  // confirm LayoutUtil classifies it as CSR.
  Shape from =
      MakeShapeWithLayout(F32, {2, 3}, {1, 0}, {DIM_DENSE, DIM_COMPRESSED});
  Shape to = MakeShapeWithLayout(F32, {2, 3}, {0, 1});

  EXPECT_TRUE(LayoutUtil::IsSparseArray(from));
  EXPECT_FALSE(LayoutUtil::IsSparseArray(to));

  EXPECT_TRUE(LayoutUtil::IsCSRArray(from));
  EXPECT_FALSE(LayoutUtil::IsCSRArray(to));

  // Overwrite a dense layout with the CSR one.
  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(from, to));
  EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(from, &to));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(from, to));
  EXPECT_TRUE(LayoutUtil::IsCSRArray(to));

  // Install the CSR layout into a destination that has no layout.
  to.clear_layout();
  EXPECT_FALSE(LayoutUtil::IsCSRArray(to));
  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(from, to));
  EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(from, &to));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(from, to));
  EXPECT_TRUE(LayoutUtil::IsCSRArray(to));

  // Making dimension 0 the minor one turns the destination into CSC.
  *to.mutable_layout()->mutable_minor_to_major() = {0, 1};
  EXPECT_TRUE(LayoutUtil::IsCSCArray(to));
  EXPECT_FALSE(LayoutUtil::IsCSRArray(to));

  // A layout-less source must clear the destination's layout.
  from.clear_layout();
  EXPECT_FALSE(LayoutUtil::IsCSRArray(from));
  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(from, to));
  EXPECT_TRUE(to.has_layout());
  EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(from, &to));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(from, to));
  EXPECT_FALSE(to.has_layout());
  EXPECT_FALSE(LayoutUtil::IsCSRArray(to));
}
128
// Copying between structurally identical (nested) tuples copies every leaf's
// layout.
TEST_F(LayoutUtilTest, CopyLayoutTuple) {
  // Leaf 0 and the rank-3 leaf of the nested tuple differ from `to` below.
  const Shape from = ShapeUtil::MakeTupleShape(
      {MakeShapeWithLayout(F32, {2, 3}, {0, 1}),
       MakeShapeWithLayout(F32, {42, 123}, {1, 0}),
       ShapeUtil::MakeTupleShape(
           {MakeShapeWithLayout(F32, {}, {}),
            MakeShapeWithLayout(F32, {1, 2, 3}, {0, 2, 1})})});
  Shape to = ShapeUtil::MakeTupleShape(
      {MakeShapeWithLayout(F32, {2, 3}, {1, 0}),
       MakeShapeWithLayout(F32, {42, 123}, {1, 0}),
       ShapeUtil::MakeTupleShape(
           {MakeShapeWithLayout(F32, {}, {}),
            MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})});

  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(from, to));
  EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(from, &to));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(from, to));
}
147
// Shapes of equal rank but different dimensions still accept a layout copy.
TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleSameRank) {
  const Shape from = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1});
  // The destination starts with a two-entry layout on a rank-3 shape; the copy
  // overwrites it regardless.
  Shape to = MakeShapeWithLayout(F32, {2, 3, 5}, {1, 0});
  ASSERT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(from, &to));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(from, to));
}
154
// Copying a layout between arrays of different rank fails with a descriptive
// error.
TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleDifferentRank) {
  const Shape rank3 = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1});
  Shape rank2 = MakeShapeWithLayout(F32, {2, 3}, {1, 0});

  const auto copy_status = LayoutUtil::CopyLayoutBetweenShapes(rank3, &rank2);
  EXPECT_FALSE(copy_status.ok());
  EXPECT_THAT(copy_status.error_message(),
              ::testing::ContainsRegex("cannot copy layout from shape"));
}
163
// Copying between tuples whose structure differs fails: here the nested tuple
// has one element in the source but two in the destination.
TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleTuple) {
  const Shape nested_one = ShapeUtil::MakeTupleShape(
      {MakeShapeWithLayout(F32, {1, 2, 3}, {0, 2, 1})});
  const Shape from = ShapeUtil::MakeTupleShape(
      {MakeShapeWithLayout(F32, {2, 3}, {0, 1}),
       MakeShapeWithLayout(F32, {42, 123}, {1, 0}), nested_one});

  const Shape nested_two = ShapeUtil::MakeTupleShape(
      {MakeShapeWithLayout(F32, {}, {}),
       MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})});
  Shape to = ShapeUtil::MakeTupleShape(
      {MakeShapeWithLayout(F32, {2, 3}, {1, 0}),
       MakeShapeWithLayout(F32, {42, 123}, {1, 0}), nested_two});

  const auto copy_status = LayoutUtil::CopyLayoutBetweenShapes(from, &to);
  EXPECT_FALSE(copy_status.ok());
  EXPECT_THAT(copy_status.error_message(),
              ::testing::ContainsRegex("cannot copy layout from shape"));
}
182
// A source whose layout has more minor_to_major entries than its rank is
// rejected by the copy.
TEST_F(LayoutUtilTest, CopyLayoutBogusLayout) {
  Shape from = ShapeUtil::MakeShape(F32, {2, 3});
  Shape to = ShapeUtil::MakeShape(F32, {2, 3});
  // Four minor_to_major entries on a rank-2 shape: deliberately invalid.
  *from.mutable_layout() = LayoutUtil::MakeLayout({1, 2, 3, 4});

  const auto copy_status = LayoutUtil::CopyLayoutBetweenShapes(from, &to);
  EXPECT_FALSE(copy_status.ok());
  EXPECT_THAT(copy_status.error_message(),
              ::testing::ContainsRegex("layout minor_to_major field contains "
                                       ".* elements, but shape is rank"));
}
196
// Token shapes trivially share a layout, so copying between them is a no-op
// that succeeds.
TEST_F(LayoutUtilTest, CopyTokenLayout) {
  const Shape from = ShapeUtil::MakeTokenShape();
  Shape to = ShapeUtil::MakeTokenShape();

  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(from, to));
  EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(from, &to));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(from, to));
}
207
// Opaque shapes trivially share a layout, so copying between them is a no-op
// that succeeds.
TEST_F(LayoutUtilTest, CopyOpaqueLayout) {
  const Shape from = ShapeUtil::MakeOpaqueShape();
  Shape to = ShapeUtil::MakeOpaqueShape();

  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(from, to));
  EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(from, &to));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(from, to));
}
218
// Tuple layout copies work when the tuples also contain token and opaque
// leaves alongside arrays.
TEST_F(LayoutUtilTest, CopyTupleLayoutWithTokenAndOpaque) {
  const Shape from = ShapeUtil::MakeTupleShape(
      {MakeShapeWithLayout(F32, {2, 3}, {0, 1}),
       MakeShapeWithLayout(F32, {42, 123}, {1, 0}),
       ShapeUtil::MakeTokenShape(),
       ShapeUtil::MakeTupleShape(
           {ShapeUtil::MakeOpaqueShape(), MakeShapeWithLayout(F32, {}, {}),
            MakeShapeWithLayout(F32, {1, 2, 3}, {0, 2, 1})})});
  Shape to = ShapeUtil::MakeTupleShape(
      {MakeShapeWithLayout(F32, {2, 3}, {1, 0}),
       MakeShapeWithLayout(F32, {42, 123}, {1, 0}),
       ShapeUtil::MakeTokenShape(),
       ShapeUtil::MakeTupleShape(
           {ShapeUtil::MakeOpaqueShape(), MakeShapeWithLayout(F32, {}, {}),
            MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})});

  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(from, to));
  EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(from, &to));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(from, to));
}
237
// ClearLayout on a tuple strips the layouts of all array leaves, including
// those inside nested tuples.
TEST_F(LayoutUtilTest, ClearLayoutTuple) {
  Shape tuple = ShapeUtil::MakeTupleShape(
      {MakeShapeWithLayout(F32, {2, 3}, {1, 0}),
       MakeShapeWithLayout(F32, {42, 123}, {1, 0}),
       ShapeUtil::MakeTupleShape(
           {MakeShapeWithLayout(F32, {}, {}),
            MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})});

  // Before: top-level and nested leaves all carry layouts.
  EXPECT_TRUE(LayoutUtil::HasLayout(tuple));
  EXPECT_TRUE(tuple.tuple_shapes(0).has_layout());
  EXPECT_TRUE(tuple.tuple_shapes(2).tuple_shapes(1).has_layout());

  LayoutUtil::ClearLayout(&tuple);

  // After: none of them do.
  EXPECT_FALSE(LayoutUtil::HasLayout(tuple));
  EXPECT_FALSE(tuple.tuple_shapes(0).has_layout());
  EXPECT_FALSE(tuple.tuple_shapes(2).tuple_shapes(1).has_layout());
}
255
// Opaque and token shapes report having a layout both before and after
// ClearLayout.
TEST_F(LayoutUtilTest, ClearLayoutOpaqueAndToken) {
  // Iterate by value: each iteration mutates its own copy.
  for (Shape trivial :
       {ShapeUtil::MakeOpaqueShape(), ShapeUtil::MakeTokenShape()}) {
    EXPECT_TRUE(LayoutUtil::HasLayout(trivial));
    LayoutUtil::ClearLayout(&trivial);
    EXPECT_TRUE(LayoutUtil::HasLayout(trivial));
  }
}
265
// SetToDefaultLayout on a tuple resets each leaf to its default layout, so two
// rank-3 leaves that started with different layouts end up identical.
TEST_F(LayoutUtilTest, SetToDefaultLayoutTuple) {
  Shape tuple = ShapeUtil::MakeTupleShape(
      {MakeShapeWithLayout(F32, {2, 3, 4}, {1, 0, 2}),
       MakeShapeWithLayout(F32, {42, 123, 7}, {1, 2, 0}),
       ShapeUtil::MakeTupleShape(
           {MakeShapeWithLayout(F32, {}, {}),
            MakeShapeWithLayout(F32, {1, 2, 3, 4}, {3, 1, 2, 0})})});

  EXPECT_FALSE(LayoutUtil::Equal(tuple.tuple_shapes(0).layout(),
                                 tuple.tuple_shapes(1).layout()));

  LayoutUtil::SetToDefaultLayout(&tuple);

  EXPECT_TRUE(LayoutUtil::Equal(tuple.tuple_shapes(0).layout(),
                                tuple.tuple_shapes(1).layout()));
  EXPECT_TRUE(LayoutUtil::Equal(
      LayoutUtil::GetDefaultLayoutForShape(tuple.tuple_shapes(0)),
      tuple.tuple_shapes(1).layout()));
}
282
// Default layouts are major-to-minor: minor_to_major lists dimensions in
// descending order.
TEST_F(LayoutUtilTest, DefaultLayoutGettersMajorToMinor) {
  EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
                                LayoutUtil::GetDefaultLayoutForR2()));
  EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({2, 1, 0}),
                                LayoutUtil::GetDefaultLayoutForR3()));
  EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({3, 2, 1, 0}),
                                LayoutUtil::GetDefaultLayoutForR4()));

  const Shape rank5 = ShapeUtil::MakeShape(F32, {10, 20, 30, 15, 25});
  EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({4, 3, 2, 1, 0}),
                                LayoutUtil::GetDefaultLayoutForShape(rank5)));
}
295
// MakeDescendingLayout(r) yields minor_to_major {r-1, ..., 0}, including the
// rank-1 and rank-0 edge cases.
TEST_F(LayoutUtilTest, MakeDescending) {
  const Layout rank5 = LayoutUtil::MakeDescendingLayout(5);
  EXPECT_TRUE(LayoutUtil::Equal(rank5, LayoutUtil::MakeLayout({4, 3, 2, 1, 0})));

  const Layout rank1 = LayoutUtil::MakeDescendingLayout(1);
  EXPECT_TRUE(LayoutUtil::Equal(rank1, LayoutUtil::MakeLayout({0})));

  const Layout rank0 = LayoutUtil::MakeDescendingLayout(0);
  EXPECT_TRUE(LayoutUtil::Equal(rank0, LayoutUtil::MakeLayout({})));
}
304
// MakeAscendingLayout(r) yields minor_to_major {0, ..., r-1}, including the
// rank-1 and rank-0 edge cases.
TEST_F(LayoutUtilTest, MakeAscending) {
  const Layout rank5 = LayoutUtil::MakeAscendingLayout(5);
  EXPECT_TRUE(LayoutUtil::Equal(rank5, LayoutUtil::MakeLayout({0, 1, 2, 3, 4})));

  const Layout rank1 = LayoutUtil::MakeAscendingLayout(1);
  EXPECT_TRUE(LayoutUtil::Equal(rank1, LayoutUtil::MakeLayout({0})));

  const Layout rank0 = LayoutUtil::MakeAscendingLayout(0);
  EXPECT_TRUE(LayoutUtil::Equal(rank0, LayoutUtil::MakeLayout({})));
}
313
// Checks how layouts carrying tiles and/or an explicit element size are
// rendered by ShapeUtil::HumanStringWithLayout. Each tile prints as "T(...)"
// after minor_to_major, Tile::kCombineDimension prints as "*", and an explicit
// element size prints as "E(bits)".
TEST_F(LayoutUtilTest, HumanStringWithTiling) {
  Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 3, 4}, {0, 1, 2});

  // No tiling.
  EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape), "f32[2,3,4]{0,1,2}");

  // 2D tile. `tile` is initialized at its first use and re-pointed at each
  // newly added tile below.
  Tile* tile = shape.mutable_layout()->add_tiles();
  tile->add_dimensions(512);
  tile->add_dimensions(1024);
  EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape),
            "f32[2,3,4]{0,1,2:T(512,1024)}");

  // 1D tile.
  shape.mutable_layout()->clear_tiles();
  tile = shape.mutable_layout()->add_tiles();
  tile->add_dimensions(512);
  EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape),
            "f32[2,3,4]{0,1,2:T(512)}");

  // 2 tiles.
  shape = ShapeUtil::MakeShapeWithLayout(BF16, {2, 3, 4}, {1, 2, 0});
  tile = shape.mutable_layout()->add_tiles();
  tile->add_dimensions(16);
  tile->add_dimensions(256);
  tile = shape.mutable_layout()->add_tiles();
  tile->add_dimensions(2);
  tile->add_dimensions(1);
  EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape),
            "bf16[2,3,4]{1,2,0:T(16,256)(2,1)}");

  // PRED with element size of 8 bits.
  shape = ShapeUtil::MakeShapeWithLayout(PRED, {8, 8, 8}, {0, 2, 1});
  tile = shape.mutable_layout()->add_tiles();
  tile->add_dimensions(8);
  tile->add_dimensions(128);
  EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape),
            "pred[8,8,8]{0,2,1:T(8,128)}");

  // PRED with element size of 32 bits.
  shape.mutable_layout()->clear_tiles();
  tile = shape.mutable_layout()->add_tiles();
  tile->add_dimensions(8);
  tile->add_dimensions(128);
  shape.mutable_layout()->set_element_size_in_bits(32);
  EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape),
            "pred[8,8,8]{0,2,1:T(8,128)E(32)}");

  // No tile. PRED with element size of 32 bits. The element size is already
  // 32 from the previous step; it is set again so this case stands on its own.
  shape.mutable_layout()->clear_tiles();
  shape.mutable_layout()->set_element_size_in_bits(32);
  EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape),
            "pred[8,8,8]{0,2,1:E(32)}");

  // Tile with negative dimension size for combining dimensions.
  shape = ShapeUtil::MakeShapeWithLayout(BF16, {2, 3, 1004}, {2, 1, 0});
  tile = shape.mutable_layout()->add_tiles();
  tile->add_dimensions(2);
  tile->add_dimensions(Tile::kCombineDimension);
  tile->add_dimensions(128);
  EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape),
            "bf16[2,3,1004]{2,1,0:T(2,*,128)}");

  // Tile with two negative dimensions.
  shape = ShapeUtil::MakeShapeWithLayout(BF16, {8, 2, 3, 1004}, {3, 2, 1, 0});
  tile = shape.mutable_layout()->add_tiles();
  tile->add_dimensions(2);
  tile->add_dimensions(Tile::kCombineDimension);
  tile->add_dimensions(Tile::kCombineDimension);
  tile->add_dimensions(128);
  EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape),
            "bf16[8,2,3,1004]{3,2,1,0:T(2,*,*,128)}");
}
388
// A well-formed array layout validates whether or not missing layouts are
// tolerated.
TEST_F(LayoutUtilTest, ValidateLayout_ValidArrayLayout) {
  const Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {0, 1});
  for (const bool allow_missing_layouts : {false, true}) {
    const auto status =
        LayoutUtil::ValidateLayoutInShape(shape, allow_missing_layouts);
    EXPECT_TRUE(status.ok());
  }
}
398
// A minor_to_major field longer than the shape's rank is rejected whether or
// not missing layouts are tolerated.
TEST_F(LayoutUtilTest, ValidateLayout_InvalidArrayLayout) {
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  *shape.mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2});
  for (const bool allow_missing_layouts : {false, true}) {
    const auto status =
        LayoutUtil::ValidateLayoutInShape(shape, allow_missing_layouts);
    EXPECT_FALSE(status.ok());
    EXPECT_THAT(status.error_message(),
                ::testing::HasSubstr("layout minor_to_major field contains 3 "
                                     "elements, but shape is rank 2"));
  }
}
415
// A dim_level_types field longer than the shape's rank is rejected whether or
// not missing layouts are tolerated.
TEST_F(LayoutUtilTest, ValidateLayout_InvalidDimLevelTypes) {
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  *shape.mutable_layout() = LayoutUtil::MakeLayout({0, 1});
  // Three dim-level types for a rank-2 shape.
  *shape.mutable_layout()->mutable_dim_level_types() = {DIM_DENSE, DIM_DENSE,
                                                        DIM_DENSE};
  for (const bool allow_missing_layouts : {false, true}) {
    const auto status =
        LayoutUtil::ValidateLayoutInShape(shape, allow_missing_layouts);
    EXPECT_FALSE(status.ok());
    EXPECT_THAT(status.error_message(),
                ::testing::HasSubstr("layout dim_level_types field contains 3 "
                                     "elements, but shape is rank 2"));
  }
}
434
// An array without a layout fails strict validation but passes when missing
// layouts are allowed.
TEST_F(LayoutUtilTest, ValidateLayout_MissingArrayLayout) {
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  LayoutUtil::ClearLayout(&shape);

  const auto strict =
      LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/false);
  EXPECT_FALSE(strict.ok());
  EXPECT_THAT(strict.error_message(),
              ::testing::HasSubstr("shape f32[2,3] does not have a layout"));

  const auto lenient =
      LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/true);
  EXPECT_TRUE(lenient.ok());
}
447
// Nested tuples with some layout-less leaves fail strict validation and pass
// lenient validation, but an invalid leaf layout fails even leniently.
TEST_F(LayoutUtilTest, ValidateLayout_TupleSubshapesWithMissingLayouts) {
  // Structure: ((( f32[1,2] ), f32[1,2] ), ( f32[9] )).
  // The innermost f32[1,2] is left as MakeShape built it; the other two
  // leaves have their layouts cleared.
  const Shape kept_leaf = ShapeUtil::MakeShape(F32, {1, 2});
  const Shape kept_wrapper = ShapeUtil::MakeTupleShape({kept_leaf});

  Shape cleared_leaf = ShapeUtil::MakeShape(F32, {1, 2});
  LayoutUtil::ClearLayout(&cleared_leaf);
  const Shape first = ShapeUtil::MakeTupleShape({kept_wrapper, cleared_leaf});

  Shape cleared_vector = ShapeUtil::MakeShape(F32, {9});
  LayoutUtil::ClearLayout(&cleared_vector);
  const Shape second = ShapeUtil::MakeTupleShape({cleared_vector});

  Shape shape = ShapeUtil::MakeTupleShape({first, second});

  auto status =
      LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/false);
  EXPECT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              ::testing::HasSubstr("shape f32[1,2] does not have a layout"));
  status =
      LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/true);
  EXPECT_TRUE(status.ok());

  // Give the rank-1 leaf a three-entry layout: invalid even when missing
  // layouts are tolerated.
  Shape* rank1_leaf =
      shape.mutable_tuple_shapes(1)->mutable_tuple_shapes(0);
  *rank1_leaf->mutable_layout() = LayoutUtil::MakeLayout({0, 2, 3});

  status =
      LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/true);
  EXPECT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              ::testing::HasSubstr("layout minor_to_major field "
                                   "contains 3 elements, but shape is rank 1"));
}
479
// MoveDimToMajor moves `dim` to the most-major position (the end of
// minor_to_major) and keeps the relative order of the other dimensions.
// Covers the already-major, interior, and most-minor cases.
TEST_F(LayoutUtilTest, MoveDimToMajor) {
  const Layout layout = LayoutUtil::MakeLayout({2, 1, 0});

  // Dimension 0 is already most-major: the layout is unchanged.
  Layout new_layout = LayoutUtil::MoveDimToMajor(layout, 0);
  EXPECT_EQ(new_layout, layout);

  // Interior dimension.
  new_layout = LayoutUtil::MoveDimToMajor(layout, 1);
  EXPECT_EQ(new_layout, LayoutUtil::MakeLayout({2, 0, 1}));

  // Most-minor dimension: it jumps across every other dimension.
  new_layout = LayoutUtil::MoveDimToMajor(layout, 2);
  EXPECT_EQ(new_layout, LayoutUtil::MakeLayout({1, 0, 2}));
}
488
489 } // namespace
490 } // namespace xla
491