xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/layout_util_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/layout_util.h"
17 
18 #include <sstream>
19 
20 #include "tensorflow/compiler/xla/shape_util.h"
21 #include "tensorflow/compiler/xla/test.h"
22 #include "tensorflow/compiler/xla/test_helpers.h"
23 
24 namespace xla {
25 namespace {
26 
27 class LayoutUtilTest : public ::testing::Test {
28  protected:
MakeShapeWithLayout(PrimitiveType element_type,absl::Span<const int64_t> dimensions,absl::Span<const int64_t> minor_to_major,absl::Span<const DimLevelType> dim_level_types={})29   Shape MakeShapeWithLayout(
30       PrimitiveType element_type, absl::Span<const int64_t> dimensions,
31       absl::Span<const int64_t> minor_to_major,
32       absl::Span<const DimLevelType> dim_level_types = {}) {
33     Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
34     *shape.mutable_layout() =
35         LayoutUtil::MakeLayout(minor_to_major, dim_level_types);
36     return shape;
37   }
38 };
39 
// Exercises LayoutUtil::LayoutsInShapesEqual on tuple shapes with zero, one
// and two elements.
TEST_F(LayoutUtilTest, TupleLayoutComparison) {
  // Two one-element tuples: same {0, 1} element layout, different dimensions.
  const Shape element =
      ShapeUtil::MakeTupleShape({MakeShapeWithLayout(F32, {2, 3}, {0, 1})});
  const Shape resized_element =
      ShapeUtil::MakeTupleShape({MakeShapeWithLayout(F32, {2, 2}, {0, 1})});

  const Shape empty = ShapeUtil::MakeTupleShape({});
  const Shape single = ShapeUtil::MakeTupleShape({element});
  const Shape pair = ShapeUtil::MakeTupleShape({element, element});

  // The empty tuple only matches itself.
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(empty, empty));
  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(empty, single));
  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(empty, pair));
  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(single, empty));
  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(pair, empty));

  // Tuples with a different number of elements are expected to differ.
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(single, single));
  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(single, pair));
  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(pair, single));

  // Same element count and layouts, but one element has other dimensions:
  // the comparison is still expected to succeed in both directions.
  const Shape mixed_pair = ShapeUtil::MakeTupleShape({element, resized_element});
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(pair, pair));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(pair, mixed_pair));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(mixed_pair, pair));
}
65 
// Copies a layout between two dense rank-2 arrays, including the cases where
// the destination or the source has had its layout cleared.
TEST_F(LayoutUtilTest, CopyLayoutDenseArray) {
  Shape source = MakeShapeWithLayout(F32, {2, 3}, {0, 1});
  Shape target = MakeShapeWithLayout(F32, {2, 3}, {1, 0});

  // Layouts start out different and match after the copy.
  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(source, target));
  EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(source, &target));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(source, target));

  // Copying into a destination whose layout was cleared also succeeds.
  target.clear_layout();
  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(source, target));
  EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(source, &target));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(source, target));

  // Copying from a source without a layout clears the destination's layout.
  source.clear_layout();
  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(source, target));
  EXPECT_TRUE(target.has_layout());
  EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(source, &target));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(source, target));
  EXPECT_FALSE(target.has_layout());
}
88 
// Copies a layout with {DIM_DENSE, DIM_COMPRESSED} level types onto a dense
// array and checks the sparse/CSR/CSC predicates along the way.
TEST_F(LayoutUtilTest, CopyLayoutCSRArray) {
  Shape source =
      MakeShapeWithLayout(F32, {2, 3}, {1, 0}, {DIM_DENSE, DIM_COMPRESSED});
  Shape target = MakeShapeWithLayout(F32, {2, 3}, {0, 1});

  // Only the source is sparse / CSR to begin with.
  EXPECT_TRUE(LayoutUtil::IsSparseArray(source));
  EXPECT_FALSE(LayoutUtil::IsSparseArray(target));

  EXPECT_TRUE(LayoutUtil::IsCSRArray(source));
  EXPECT_FALSE(LayoutUtil::IsCSRArray(target));

  // After the copy the destination is CSR as well.
  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(source, target));
  EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(source, &target));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(source, target));
  EXPECT_TRUE(LayoutUtil::IsCSRArray(target));

  // Copying into a destination whose layout was cleared also succeeds.
  target.clear_layout();
  EXPECT_FALSE(LayoutUtil::IsCSRArray(target));
  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(source, target));
  EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(source, &target));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(source, target));
  EXPECT_TRUE(LayoutUtil::IsCSRArray(target));

  // Making dimension 0 the minor one turns the destination into a CSC array.
  *target.mutable_layout()->mutable_minor_to_major() = {0, 1};
  EXPECT_TRUE(LayoutUtil::IsCSCArray(target));
  EXPECT_FALSE(LayoutUtil::IsCSRArray(target));

  // Copying from a source without a layout clears the destination's layout.
  source.clear_layout();
  EXPECT_FALSE(LayoutUtil::IsCSRArray(source));
  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(source, target));
  EXPECT_TRUE(target.has_layout());
  EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(source, &target));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(source, target));
  EXPECT_FALSE(target.has_layout());
  EXPECT_FALSE(LayoutUtil::IsCSRArray(target));
}
128 
// Copies layouts between two nested tuples that have the same structure but
// different layouts on some of their array elements.
TEST_F(LayoutUtilTest, CopyLayoutTuple) {
  const Shape source_inner = ShapeUtil::MakeTupleShape(
      {MakeShapeWithLayout(F32, {}, {}),
       MakeShapeWithLayout(F32, {1, 2, 3}, {0, 2, 1})});
  const Shape source = ShapeUtil::MakeTupleShape(
      {MakeShapeWithLayout(F32, {2, 3}, {0, 1}),
       MakeShapeWithLayout(F32, {42, 123}, {1, 0}), source_inner});

  const Shape target_inner = ShapeUtil::MakeTupleShape(
      {MakeShapeWithLayout(F32, {}, {}),
       MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})});
  Shape target = ShapeUtil::MakeTupleShape(
      {MakeShapeWithLayout(F32, {2, 3}, {1, 0}),
       MakeShapeWithLayout(F32, {42, 123}, {1, 0}), target_inner});

  // Layouts differ up front and agree once the copy has run.
  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(source, target));
  EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(source, &target));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(source, target));
}
147 
// Copying a layout between arrays of the same rank but different dimension
// sizes is expected to succeed.
TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleSameRank) {
  Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1});
  // Give the rank-3 destination a rank-3 minor_to_major permutation, so the
  // test does not start from a malformed two-element layout.
  Shape dst = MakeShapeWithLayout(F32, {2, 3, 5}, {1, 0, 2});

  // The layouts must differ beforehand for the copy to be observable.
  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
  ASSERT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
}
154 
// Copying a layout from a rank-3 array to a rank-2 array is expected to fail
// with a "cannot copy layout from shape" error.
TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleDifferentRank) {
  const Shape rank3_source = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1});
  Shape rank2_target = MakeShapeWithLayout(F32, {2, 3}, {1, 0});

  const auto copy_status =
      LayoutUtil::CopyLayoutBetweenShapes(rank3_source, &rank2_target);

  EXPECT_FALSE(copy_status.ok());
  EXPECT_THAT(copy_status.error_message(),
              ::testing::ContainsRegex("cannot copy layout from shape"));
}
163 
// Copying between tuples whose nested tuples have a different number of
// elements is expected to fail with a "cannot copy layout from shape" error.
TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleTuple) {
  // The source's nested tuple has a single element ...
  const Shape source_inner = ShapeUtil::MakeTupleShape(
      {MakeShapeWithLayout(F32, {1, 2, 3}, {0, 2, 1})});
  const Shape source = ShapeUtil::MakeTupleShape(
      {MakeShapeWithLayout(F32, {2, 3}, {0, 1}),
       MakeShapeWithLayout(F32, {42, 123}, {1, 0}), source_inner});

  // ... while the destination's nested tuple has two.
  const Shape target_inner = ShapeUtil::MakeTupleShape(
      {MakeShapeWithLayout(F32, {}, {}),
       MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})});
  Shape target = ShapeUtil::MakeTupleShape(
      {MakeShapeWithLayout(F32, {2, 3}, {1, 0}),
       MakeShapeWithLayout(F32, {42, 123}, {1, 0}), target_inner});

  const auto copy_status = LayoutUtil::CopyLayoutBetweenShapes(source, &target);

  EXPECT_FALSE(copy_status.ok());
  EXPECT_THAT(copy_status.error_message(),
              ::testing::ContainsRegex("cannot copy layout from shape"));
}
182 
// Copying from a source whose minor_to_major has more entries than the shape
// has dimensions is expected to fail with a rank-mismatch error.
TEST_F(LayoutUtilTest, CopyLayoutBogusLayout) {
  Shape source = ShapeUtil::MakeShape(F32, {2, 3});
  Shape target = ShapeUtil::MakeShape(F32, {2, 3});

  // Four minor_to_major entries on a two-dimensional shape: invalid.
  Layout* const source_layout = source.mutable_layout();
  *source_layout = LayoutUtil::MakeLayout({1, 2, 3, 4});

  const auto copy_status = LayoutUtil::CopyLayoutBetweenShapes(source, &target);

  EXPECT_FALSE(copy_status.ok());
  EXPECT_THAT(
      copy_status.error_message(),
      ::testing::ContainsRegex("layout minor_to_major field contains .* "
                               "elements, but shape is rank"));
}
196 
// Token shapes compare as layout-equal both before and after a layout copy,
// and the copy itself succeeds.
TEST_F(LayoutUtilTest, CopyTokenLayout) {
  const Shape source_token = ShapeUtil::MakeTokenShape();
  Shape target_token = ShapeUtil::MakeTokenShape();

  // Nothing to change: equal before, copy is OK, equal afterwards.
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(source_token, target_token));
  EXPECT_IS_OK(
      LayoutUtil::CopyLayoutBetweenShapes(source_token, &target_token));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(source_token, target_token));
}
207 
// Opaque shapes compare as layout-equal both before and after a layout copy,
// and the copy itself succeeds.
TEST_F(LayoutUtilTest, CopyOpaqueLayout) {
  const Shape source_opaque = ShapeUtil::MakeOpaqueShape();
  Shape target_opaque = ShapeUtil::MakeOpaqueShape();

  // Nothing to change: equal before, copy is OK, equal afterwards.
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(source_opaque, target_opaque));
  EXPECT_IS_OK(
      LayoutUtil::CopyLayoutBetweenShapes(source_opaque, &target_opaque));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(source_opaque, target_opaque));
}
218 
// Copies layouts between nested tuples that mix arrays with token and opaque
// elements.
TEST_F(LayoutUtilTest, CopyTupleLayoutWithTokenAndOpaque) {
  const Shape source_inner = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeOpaqueShape(), MakeShapeWithLayout(F32, {}, {}),
       MakeShapeWithLayout(F32, {1, 2, 3}, {0, 2, 1})});
  const Shape source = ShapeUtil::MakeTupleShape(
      {MakeShapeWithLayout(F32, {2, 3}, {0, 1}),
       MakeShapeWithLayout(F32, {42, 123}, {1, 0}), ShapeUtil::MakeTokenShape(),
       source_inner});

  const Shape target_inner = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeOpaqueShape(), MakeShapeWithLayout(F32, {}, {}),
       MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})});
  Shape target = ShapeUtil::MakeTupleShape(
      {MakeShapeWithLayout(F32, {2, 3}, {1, 0}),
       MakeShapeWithLayout(F32, {42, 123}, {1, 0}), ShapeUtil::MakeTokenShape(),
       target_inner});

  // Layouts differ up front and agree once the copy has run.
  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(source, target));
  EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(source, &target));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(source, target));
}
237 
// LayoutUtil::ClearLayout on a nested tuple removes the layouts of the array
// elements at every nesting level that is checked here.
TEST_F(LayoutUtilTest, ClearLayoutTuple) {
  const Shape inner = ShapeUtil::MakeTupleShape(
      {MakeShapeWithLayout(F32, {}, {}),
       MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})});
  Shape nested = ShapeUtil::MakeTupleShape(
      {MakeShapeWithLayout(F32, {2, 3}, {1, 0}),
       MakeShapeWithLayout(F32, {42, 123}, {1, 0}), inner});

  // Before: the tuple and the sampled elements all carry layouts.
  EXPECT_TRUE(LayoutUtil::HasLayout(nested));
  EXPECT_TRUE(nested.tuple_shapes(0).has_layout());
  EXPECT_TRUE(nested.tuple_shapes(2).tuple_shapes(1).has_layout());

  LayoutUtil::ClearLayout(&nested);

  // After: none of them do.
  EXPECT_FALSE(LayoutUtil::HasLayout(nested));
  EXPECT_FALSE(nested.tuple_shapes(0).has_layout());
  EXPECT_FALSE(nested.tuple_shapes(2).tuple_shapes(1).has_layout());
}
255 
// For opaque and token shapes, LayoutUtil::HasLayout reports true both before
// and after LayoutUtil::ClearLayout.
TEST_F(LayoutUtilTest, ClearLayoutOpaqueAndToken) {
  // Takes the shape by value because ClearLayout needs a mutable object.
  const auto expect_layout_survives_clear = [](Shape candidate) {
    EXPECT_TRUE(LayoutUtil::HasLayout(candidate));
    LayoutUtil::ClearLayout(&candidate);
    EXPECT_TRUE(LayoutUtil::HasLayout(candidate));
  };

  expect_layout_survives_clear(ShapeUtil::MakeOpaqueShape());
  expect_layout_survives_clear(ShapeUtil::MakeTokenShape());
}
265 
// LayoutUtil::SetToDefaultLayout on a tuple gives its two rank-3 elements the
// same layout, which equals the default layout for the first element's shape.
TEST_F(LayoutUtilTest, SetToDefaultLayoutTuple) {
  const Shape inner = ShapeUtil::MakeTupleShape(
      {MakeShapeWithLayout(F32, {}, {}),
       MakeShapeWithLayout(F32, {1, 2, 3, 4}, {3, 1, 2, 0})});
  Shape nested = ShapeUtil::MakeTupleShape(
      {MakeShapeWithLayout(F32, {2, 3, 4}, {1, 0, 2}),
       MakeShapeWithLayout(F32, {42, 123, 7}, {1, 2, 0}), inner});

  // The first two elements start with different layouts.
  EXPECT_FALSE(LayoutUtil::Equal(nested.tuple_shapes(0).layout(),
                                 nested.tuple_shapes(1).layout()));

  LayoutUtil::SetToDefaultLayout(&nested);

  const Shape& first = nested.tuple_shapes(0);
  const Shape& second = nested.tuple_shapes(1);
  EXPECT_TRUE(LayoutUtil::Equal(first.layout(), second.layout()));
  EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::GetDefaultLayoutForShape(first),
                                second.layout()));
}
282 
// The default-layout getters are expected to return descending
// minor_to_major orders for ranks 2 through 5.
TEST_F(LayoutUtilTest, DefaultLayoutGettersMajorToMinor) {
  const Layout expected_r2 = LayoutUtil::MakeLayout({1, 0});
  EXPECT_TRUE(
      LayoutUtil::Equal(expected_r2, LayoutUtil::GetDefaultLayoutForR2()));

  const Layout expected_r3 = LayoutUtil::MakeLayout({2, 1, 0});
  EXPECT_TRUE(
      LayoutUtil::Equal(expected_r3, LayoutUtil::GetDefaultLayoutForR3()));

  const Layout expected_r4 = LayoutUtil::MakeLayout({3, 2, 1, 0});
  EXPECT_TRUE(
      LayoutUtil::Equal(expected_r4, LayoutUtil::GetDefaultLayoutForR4()));

  // Rank 5 goes through the generic shape-based getter.
  const Layout expected_r5 = LayoutUtil::MakeLayout({4, 3, 2, 1, 0});
  const Shape rank5 = ShapeUtil::MakeShape(F32, {10, 20, 30, 15, 25});
  EXPECT_TRUE(LayoutUtil::Equal(expected_r5,
                                LayoutUtil::GetDefaultLayoutForShape(rank5)));
}
295 
// LayoutUtil::MakeDescendingLayout(rank) is compared against explicitly
// written descending minor_to_major lists for ranks 5, 1 and 0.
TEST_F(LayoutUtilTest, MakeDescending) {
  const Layout descending5 = LayoutUtil::MakeDescendingLayout(5);
  EXPECT_TRUE(
      LayoutUtil::Equal(descending5, LayoutUtil::MakeLayout({4, 3, 2, 1, 0})));

  const Layout descending1 = LayoutUtil::MakeDescendingLayout(1);
  EXPECT_TRUE(LayoutUtil::Equal(descending1, LayoutUtil::MakeLayout({0})));

  const Layout descending0 = LayoutUtil::MakeDescendingLayout(0);
  EXPECT_TRUE(LayoutUtil::Equal(descending0, LayoutUtil::MakeLayout({})));
}
304 
// LayoutUtil::MakeAscendingLayout(rank) is compared against explicitly
// written ascending minor_to_major lists for ranks 5, 1 and 0.
TEST_F(LayoutUtilTest, MakeAscending) {
  const Layout ascending5 = LayoutUtil::MakeAscendingLayout(5);
  EXPECT_TRUE(
      LayoutUtil::Equal(ascending5, LayoutUtil::MakeLayout({0, 1, 2, 3, 4})));

  const Layout ascending1 = LayoutUtil::MakeAscendingLayout(1);
  EXPECT_TRUE(LayoutUtil::Equal(ascending1, LayoutUtil::MakeLayout({0})));

  const Layout ascending0 = LayoutUtil::MakeAscendingLayout(0);
  EXPECT_TRUE(LayoutUtil::Equal(ascending0, LayoutUtil::MakeLayout({})));
}
313 
TEST_F(LayoutUtilTest,HumanStringWithTiling)314 TEST_F(LayoutUtilTest, HumanStringWithTiling) {
315   Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 3, 4}, {0, 1, 2});
316   Tile* tile;
317 
318   // No tiling.
319   EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape), "f32[2,3,4]{0,1,2}");
320 
321   // 2D tile.
322   tile = shape.mutable_layout()->add_tiles();
323   tile->add_dimensions(512);
324   tile->add_dimensions(1024);
325   EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape),
326             "f32[2,3,4]{0,1,2:T(512,1024)}");
327 
328   // 1D tile.
329   shape.mutable_layout()->clear_tiles();
330   tile = shape.mutable_layout()->add_tiles();
331   tile->add_dimensions(512);
332   EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape),
333             "f32[2,3,4]{0,1,2:T(512)}");
334 
335   // 2 tiles.
336   shape = ShapeUtil::MakeShapeWithLayout(BF16, {2, 3, 4}, {1, 2, 0});
337   tile = shape.mutable_layout()->add_tiles();
338   tile->add_dimensions(16);
339   tile->add_dimensions(256);
340   tile = shape.mutable_layout()->add_tiles();
341   tile->add_dimensions(2);
342   tile->add_dimensions(1);
343   EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape),
344             "bf16[2,3,4]{1,2,0:T(16,256)(2,1)}");
345 
346   // PRED with element size of 8 bits.
347   shape = ShapeUtil::MakeShapeWithLayout(PRED, {8, 8, 8}, {0, 2, 1});
348   tile = shape.mutable_layout()->add_tiles();
349   tile->add_dimensions(8);
350   tile->add_dimensions(128);
351   EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape),
352             "pred[8,8,8]{0,2,1:T(8,128)}");
353 
354   // PRED with element size of 32 bits.
355   shape.mutable_layout()->clear_tiles();
356   tile = shape.mutable_layout()->add_tiles();
357   tile->add_dimensions(8);
358   tile->add_dimensions(128);
359   shape.mutable_layout()->set_element_size_in_bits(32);
360   EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape),
361             "pred[8,8,8]{0,2,1:T(8,128)E(32)}");
362 
363   // No tile. PRED with element size of 32 bits.
364   shape.mutable_layout()->clear_tiles();
365   shape.mutable_layout()->set_element_size_in_bits(32);
366   EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape),
367             "pred[8,8,8]{0,2,1:E(32)}");
368 
369   // Tile with negative dimension size for combining dimensions.
370   shape = ShapeUtil::MakeShapeWithLayout(BF16, {2, 3, 1004}, {2, 1, 0});
371   tile = shape.mutable_layout()->add_tiles();
372   tile->add_dimensions(2);
373   tile->add_dimensions(Tile::kCombineDimension);
374   tile->add_dimensions(128);
375   EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape),
376             "bf16[2,3,1004]{2,1,0:T(2,*,128)}");
377 
378   // Tile with two negative dimensions.
379   shape = ShapeUtil::MakeShapeWithLayout(BF16, {8, 2, 3, 1004}, {3, 2, 1, 0});
380   tile = shape.mutable_layout()->add_tiles();
381   tile->add_dimensions(2);
382   tile->add_dimensions(Tile::kCombineDimension);
383   tile->add_dimensions(Tile::kCombineDimension);
384   tile->add_dimensions(128);
385   EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape),
386             "bf16[8,2,3,1004]{3,2,1,0:T(2,*,*,128)}");
387 }
388 
// A rank-2 array with a two-element minor_to_major layout validates cleanly,
// whether or not missing layouts are allowed.
TEST_F(LayoutUtilTest, ValidateLayout_ValidArrayLayout) {
  Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {0, 1});
  // EXPECT_IS_OK (already used by the copy tests in this file) reports the
  // status itself on failure, unlike EXPECT_TRUE(status.ok()).
  EXPECT_IS_OK(LayoutUtil::ValidateLayoutInShape(
      shape, /*allow_missing_layouts=*/false));
  EXPECT_IS_OK(LayoutUtil::ValidateLayoutInShape(
      shape, /*allow_missing_layouts=*/true));
}
398 
// A rank-2 array carrying a three-element minor_to_major is rejected with the
// same message regardless of the allow_missing_layouts flag.
TEST_F(LayoutUtilTest, ValidateLayout_InvalidArrayLayout) {
  Shape rank2 = ShapeUtil::MakeShape(F32, {2, 3});
  Layout* const bad_layout = rank2.mutable_layout();
  *bad_layout = LayoutUtil::MakeLayout({0, 1, 2});

  // Missing layouts disallowed.
  auto validation = LayoutUtil::ValidateLayoutInShape(
      rank2, /*allow_missing_layouts=*/false);
  EXPECT_FALSE(validation.ok());
  EXPECT_THAT(validation.error_message(),
              ::testing::HasSubstr("layout minor_to_major field "
                                   "contains 3 elements, but shape is rank 2"));

  // Missing layouts allowed: the present-but-invalid layout still fails.
  validation = LayoutUtil::ValidateLayoutInShape(
      rank2, /*allow_missing_layouts=*/true);
  EXPECT_FALSE(validation.ok());
  EXPECT_THAT(validation.error_message(),
              ::testing::HasSubstr("layout minor_to_major field "
                                   "contains 3 elements, but shape is rank 2"));
}
415 
// A rank-2 array carrying three dim_level_types entries is rejected with the
// same message regardless of the allow_missing_layouts flag.
TEST_F(LayoutUtilTest, ValidateLayout_InvalidDimLevelTypes) {
  Shape rank2 = ShapeUtil::MakeShape(F32, {2, 3});
  Layout* const layout = rank2.mutable_layout();
  *layout = LayoutUtil::MakeLayout({0, 1});
  // One level type too many for a two-dimensional shape.
  *rank2.mutable_layout()->mutable_dim_level_types() = {DIM_DENSE, DIM_DENSE,
                                                        DIM_DENSE};

  // Missing layouts disallowed.
  auto validation = LayoutUtil::ValidateLayoutInShape(
      rank2, /*allow_missing_layouts=*/false);
  EXPECT_FALSE(validation.ok());
  EXPECT_THAT(validation.error_message(),
              ::testing::HasSubstr("layout dim_level_types field "
                                   "contains 3 elements, but shape is rank 2"));

  // Missing layouts allowed: the invalid level types still fail.
  validation = LayoutUtil::ValidateLayoutInShape(
      rank2, /*allow_missing_layouts=*/true);
  EXPECT_FALSE(validation.ok());
  EXPECT_THAT(validation.error_message(),
              ::testing::HasSubstr("layout dim_level_types field "
                                   "contains 3 elements, but shape is rank 2"));
}
434 
// An array whose layout was cleared fails validation when missing layouts are
// disallowed and passes when they are allowed.
TEST_F(LayoutUtilTest, ValidateLayout_MissingArrayLayout) {
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  LayoutUtil::ClearLayout(&shape);

  auto status =
      LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/false);
  EXPECT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              ::testing::HasSubstr("shape f32[2,3] does not have a layout"));

  // EXPECT_IS_OK (already used by the copy tests in this file) reports the
  // status itself on failure, unlike EXPECT_TRUE(status.ok()).
  EXPECT_IS_OK(LayoutUtil::ValidateLayoutInShape(
      shape, /*allow_missing_layouts=*/true));
}
447 
// Validates a nested tuple in which some array sub-shapes have had their
// layouts cleared, then gives one sub-shape an invalid layout.
TEST_F(LayoutUtilTest, ValidateLayout_TupleSubshapesWithMissingLayouts) {
  // shape = ((( f32[1,2] ), f32[1,2] <cleared>), ( f32[9] <cleared> ))
  Shape sub_1_1_1 = ShapeUtil::MakeShape(F32, {1, 2});
  Shape sub_1_1 = ShapeUtil::MakeTupleShape({sub_1_1_1});
  Shape sub_1_2 = ShapeUtil::MakeShape(F32, {1, 2});
  LayoutUtil::ClearLayout(&sub_1_2);
  Shape sub_1 = ShapeUtil::MakeTupleShape({sub_1_1, sub_1_2});
  Shape sub_2_1 = ShapeUtil::MakeShape(F32, {9});
  LayoutUtil::ClearLayout(&sub_2_1);
  Shape sub_2 = ShapeUtil::MakeTupleShape({sub_2_1});
  Shape shape = ShapeUtil::MakeTupleShape({sub_1, sub_2});

  // With missing layouts disallowed, the cleared f32[1,2] is reported.
  auto status =
      LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/false);
  EXPECT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              ::testing::HasSubstr("shape f32[1,2] does not have a layout"));

  // With missing layouts allowed, validation succeeds. EXPECT_IS_OK (already
  // used by the copy tests in this file) reports the status itself on
  // failure, unlike EXPECT_TRUE(status.ok()).
  EXPECT_IS_OK(LayoutUtil::ValidateLayoutInShape(
      shape, /*allow_missing_layouts=*/true));

  // Add invalid layout on one of sub-shapes: three minor_to_major entries on
  // the one-dimensional f32[9].
  *shape.mutable_tuple_shapes(1)->mutable_tuple_shapes(0)->mutable_layout() =
      LayoutUtil::MakeLayout({0, 2, 3});

  // A layout that is present but invalid fails even when missing layouts are
  // allowed.
  status =
      LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/true);
  EXPECT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              ::testing::HasSubstr("layout minor_to_major field "
                                   "contains 3 elements, but shape is rank 1"));
}
479 
// LayoutUtil::MoveDimToMajor on a {2, 1, 0} layout: moving dimension 0 leaves
// the layout unchanged, moving dimension 1 yields {2, 0, 1}.
TEST_F(LayoutUtilTest, MoveDimToMajor) {
  const Layout original = LayoutUtil::MakeLayout({2, 1, 0});

  // Dimension 0 is the last minor_to_major entry of the original layout.
  const Layout dim0_moved = LayoutUtil::MoveDimToMajor(original, 0);
  EXPECT_EQ(dim0_moved, original);

  const Layout dim1_moved = LayoutUtil::MoveDimToMajor(original, 1);
  EXPECT_EQ(dim1_moved, LayoutUtil::MakeLayout({2, 0, 1}));
}
488 
489 }  // namespace
490 }  // namespace xla
491