1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/layout_assignment.h"
17 
18 #include <initializer_list>
19 #include <memory>
20 #include <utility>
21 #include <vector>
22 
23 #include "absl/types/span.h"
24 #include "tensorflow/compiler/xla/layout_util.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
27 #include "tensorflow/compiler/xla/service/computation_layout.h"
28 #include "tensorflow/compiler/xla/service/hlo_computation.h"
29 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
30 #include "tensorflow/compiler/xla/service/hlo_module.h"
31 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
32 #include "tensorflow/compiler/xla/service/hlo_parser.h"
33 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
34 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
35 #include "tensorflow/compiler/xla/shape_layout.h"
36 #include "tensorflow/compiler/xla/shape_util.h"
37 #include "tensorflow/compiler/xla/test.h"
38 #include "tensorflow/compiler/xla/test_helpers.h"
39 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
40 #include "tensorflow/compiler/xla/tests/test_utils.h"
41 #include "tensorflow/compiler/xla/util.h"
42 #include "tensorflow/compiler/xla/xla_data.pb.h"
43 #include "tensorflow/core/lib/core/status.h"
44 #include "tensorflow/core/lib/core/status_test_util.h"
45 
46 namespace xla {
47 namespace {
48 
49 namespace m = xla::match;
50 using ::testing::ElementsAre;
51 
52 class LayoutAssignmentTest : public HloTestBase {
53  protected:
54   void AssignLayouts(HloModule* m, ComputationLayout* entry_computation_layout,
55                      ChannelLayoutConstraints* channel_constraints = nullptr) {
56     LayoutAssignment layout_assignment(
57         entry_computation_layout,
58         /*channel_constraints=*/channel_constraints);
59     EXPECT_IS_OK(layout_assignment.Run(m).status());
60   }
61 
62   std::vector<int64_t> LayoutOf(HloModule* m, absl::string_view name) {
63     HloInstruction* instr = FindInstruction(m, name);
64     CHECK(instr != nullptr) << name;
65     auto minor_to_major = instr->shape().layout().minor_to_major();
66     return std::vector<int64_t>(minor_to_major.begin(), minor_to_major.end());
67   }
68 
69   void ExpectLayoutIs(const Shape& shape,
70                       absl::Span<const int64_t> minor_to_major) {
71     const Layout expected = LayoutUtil::MakeLayout(minor_to_major);
72     EXPECT_TRUE(LayoutUtil::Equal(shape.layout(), expected))
73         << "Expected layout " << expected << ", actual " << shape.layout();
74   }
75 
76   void ExpectTupleLayoutIs(
77       const Shape& shape,
78       std::initializer_list<absl::Span<const int64_t>> minor_to_majors) {
79     int i = 0;
80     for (const absl::Span<const int64_t> minor_to_major : minor_to_majors) {
81       const Layout expected = LayoutUtil::MakeLayout(minor_to_major);
82       const Layout& actual = ShapeUtil::GetTupleElementShape(shape, i).layout();
83       EXPECT_TRUE(LayoutUtil::Equal(actual, expected))
84           << "Expected tuple element " << i << " layout " << expected
85           << ", actual " << actual;
86       ++i;
87     }
88   }
89 };
90 
91 TEST_F(LayoutAssignmentTest, ComputationLayout) {
92   // Verify the layouts of the root and parameter instructions of a computation
93   // match the ComputationLayout for two different layouts.
94   std::vector<std::vector<int64_t>> minor_to_majors = {{0, 1}, {1, 0}};
95   for (auto& minor_to_major : minor_to_majors) {
96     auto builder = HloComputation::Builder(TestName());
97     Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
98     auto param0 = builder.AddInstruction(
99         HloInstruction::CreateParameter(0, ashape, "param0"));
100     auto param1 = builder.AddInstruction(
101         HloInstruction::CreateParameter(1, ashape, "param1"));
102     auto add = builder.AddInstruction(
103         HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1));
104     auto m = CreateNewVerifiedModule();
105     HloComputation* computation = m->AddEntryComputation(builder.Build());
106 
107     Layout layout = LayoutUtil::MakeLayout(minor_to_major);
108     Shape shape(ashape);
109     *shape.mutable_layout() = layout;
110     const ShapeLayout shape_layout(shape);
111 
112     ComputationLayout computation_layout(computation->ComputeProgramShape());
113     *computation_layout.mutable_parameter_layout(0) = shape_layout;
114     *computation_layout.mutable_parameter_layout(1) = shape_layout;
115     *computation_layout.mutable_result_layout() = shape_layout;
116     AssignLayouts(m.get(), &computation_layout);
117     EXPECT_TRUE(LayoutUtil::Equal(layout, param0->shape().layout()));
118     EXPECT_TRUE(LayoutUtil::Equal(layout, param1->shape().layout()));
119     EXPECT_TRUE(LayoutUtil::Equal(layout, add->shape().layout()));
120   }
121 }
122 
123 TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) {
124   // Verify the layouts of the root and parameter instructions of a computation
125   // match the ComputationLayout, which has mixed layouts.
126   auto builder = HloComputation::Builder(TestName());
127   Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
128   auto param0 = builder.AddInstruction(
129       HloInstruction::CreateParameter(0, ashape, "param0"));
130   auto param1 = builder.AddInstruction(
131       HloInstruction::CreateParameter(1, ashape, "param1"));
132   builder.AddInstruction(
133       HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1));
134   auto m = CreateNewVerifiedModule();
135   HloComputation* computation = m->AddEntryComputation(builder.Build());
136 
137   Layout col_major_layout = LayoutUtil::MakeLayout({1, 0});
138   Shape col_major_shape(ashape);
139   *col_major_shape.mutable_layout() = col_major_layout;
140   const ShapeLayout col_major(col_major_shape);
141 
142   Layout row_major_layout = LayoutUtil::MakeLayout({0, 1});
143   Shape row_major_shape(ashape);
144   *row_major_shape.mutable_layout() = row_major_layout;
145   const ShapeLayout row_major(row_major_shape);
146 
147   ComputationLayout computation_layout(computation->ComputeProgramShape());
148   *computation_layout.mutable_parameter_layout(0) = col_major;
149   *computation_layout.mutable_parameter_layout(1) = row_major;
150   *computation_layout.mutable_result_layout() = col_major;
151 
152   AssignLayouts(m.get(), &computation_layout);
153   EXPECT_TRUE(LayoutUtil::Equal(col_major_layout, param0->shape().layout()));
154   EXPECT_TRUE(LayoutUtil::Equal(row_major_layout, param1->shape().layout()));
155   EXPECT_TRUE(LayoutUtil::Equal(
156       col_major_layout, computation->root_instruction()->shape().layout()));
157 }
158 
159 TEST_F(LayoutAssignmentTest, FusionInstruction) {
160   // Verify that the layouts of the fused parameters in a fusion instruction
161   // match those of the fusion operands. Other fused instructions should have no
162   // layout.
163   std::vector<std::vector<int64_t>> minor_to_majors = {{0, 1}, {1, 0}};
164   for (auto& minor_to_major : minor_to_majors) {
165     auto builder = HloComputation::Builder(TestName());
166     auto constant_literal1 = LiteralUtil::CreateR2WithLayout<float>(
167         {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major));
168     auto constant_literal2 = LiteralUtil::CreateR2WithLayout<float>(
169         {{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major));
170     Shape ashape = constant_literal1.shape();
171 
172     auto constant1 = builder.AddInstruction(
173         HloInstruction::CreateConstant(std::move(constant_literal1)));
174     auto constant2 = builder.AddInstruction(
175         HloInstruction::CreateConstant(std::move(constant_literal2)));
176     auto add = builder.AddInstruction(HloInstruction::CreateBinary(
177         ashape, HloOpcode::kAdd, constant1, constant2));
178     auto negate1 = builder.AddInstruction(
179         HloInstruction::CreateUnary(ashape, HloOpcode::kNegate, add));
180     auto negate2 = builder.AddInstruction(
181         HloInstruction::CreateUnary(ashape, HloOpcode::kNegate, negate1));
182 
183     auto m = CreateNewVerifiedModule();
184     HloComputation* computation = m->AddEntryComputation(builder.Build());
185 
186     auto fusion = computation->CreateFusionInstruction(
187         {negate2, negate1, add}, HloInstruction::FusionKind::kLoop);
188 
189     Layout layout = LayoutUtil::MakeLayout(minor_to_major);
190     Shape shape(ashape);
191     *shape.mutable_layout() = layout;
192     const ShapeLayout shape_layout(shape);
193 
194     ComputationLayout computation_layout(computation->ComputeProgramShape());
195     *computation_layout.mutable_result_layout() = shape_layout;
196 
197     AssignLayouts(m.get(), &computation_layout);
198 
199     EXPECT_TRUE(LayoutUtil::Equal(
200         layout, fusion->fused_parameter(0)->shape().layout()));
201     EXPECT_TRUE(LayoutUtil::Equal(
202         layout, fusion->fused_parameter(1)->shape().layout()));
203     EXPECT_TRUE(LayoutUtil::Equal(
204         layout, fusion->fused_expression_root()->shape().layout()));
205 
206     // Inner fused node should not have layout.
207     EXPECT_FALSE(LayoutUtil::HasLayout(
208         fusion->fused_expression_root()->operand(0)->shape()));
209   }
210 }
211 
212 TEST_F(LayoutAssignmentTest, TupleLayout) {
213   // Verify the layouts of a tuple are assigned properly (the element layouts
214   // match their source).
215   auto builder = HloComputation::Builder(TestName());
216   auto constant0 = builder.AddInstruction(
217       HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
218           {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
219   auto constant1 = builder.AddInstruction(
220       HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
221           {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
222   auto tuple = builder.AddInstruction(
223       HloInstruction::CreateTuple({constant0, constant1}));
224 
225   // To avoid having to construct a tuple layout in the ComputationLayout below,
226   // make the result of the computation an array.
227   auto get_element0 = builder.AddInstruction(
228       HloInstruction::CreateGetTupleElement(constant0->shape(), tuple, 0));
229   auto negate = builder.AddInstruction(HloInstruction::CreateUnary(
230       constant0->shape(), HloOpcode::kNegate, get_element0));
231 
232   auto m = CreateNewVerifiedModule();
233   m->AddEntryComputation(builder.Build());
234 
235   ComputationLayout computation_layout(
236       m->entry_computation()->ComputeProgramShape());
237 
238   AssignLayouts(m.get(), &computation_layout);
239 
240   EXPECT_TRUE(
241       LayoutUtil::LayoutsInShapesEqual(constant0->shape(), constant1->shape()));
242 
243   EXPECT_TRUE(LayoutUtil::HasLayout(tuple->shape()));
244   EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(
245       negate->shape(), computation_layout.result_layout().shape()));
246   EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(
247       ShapeUtil::GetTupleElementShape(tuple->shape(), 1), constant1->shape()));
248 }
249 
250 TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
251   // Construct the following computation, which has conflicting layouts for two
252   // elements of a tuple that share the same source logical buffer:
253   //
254   // %constant = Constant(...)
255   // %inner_tuple = Tuple(%constant)
256   // %nested_tuple = Tuple(%inner_tuple, %inner_tuple)
257   //
258   // The result layout is col-major for the first element and row-major for the
259   // second. This results in a conflict where the element of the inner_tuple
260   // needs to be both col and row major. This is resolved by deep-copying the
261   // tuple and assigning the layouts of the copied arrays as needed.
262   auto builder = HloComputation::Builder(TestName());
263   auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
264       LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
265   auto inner_tuple =
266       builder.AddInstruction(HloInstruction::CreateTuple({constant}));
267   auto nested_tuple = builder.AddInstruction(
268       HloInstruction::CreateTuple({inner_tuple, inner_tuple}));
269 
270   auto m = CreateNewVerifiedModule();
271   m->AddEntryComputation(builder.Build());
272 
273   ComputationLayout computation_layout(
274       m->entry_computation()->ComputeProgramShape());
275   Shape result_shape = nested_tuple->shape();
276   *ShapeUtil::GetMutableSubshape(&result_shape, /*index=*/{0, 0}) =
277       ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0});
278   *ShapeUtil::GetMutableSubshape(&result_shape, /*index=*/{1, 0}) =
279       ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1});
280   TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape(
281       result_shape));
282 
283   LayoutAssignment layout_assignment(&computation_layout);
284   AssignLayouts(m.get(), &computation_layout);
285 
286   // Layout assignment should have deep copied the result of the computation to
287   // address the layout conflict. This results in several Tuple() and
288   // GetTupleElement() instructions. Running algebraic simplification should
289   // clean up the code to something like:
290   //
291   //  %constant = Constant(...) layout={1,0}
292   //  %tuple.0 = Tuple(%constant) layout=({1,0})
293   //  %copy = Copy(%constant) layout={0,1}  # layout transposed
294   //  %tuple.1 = Tuple(%copy) layout=({0,1})
295   //  %tuple.2 = Tuple(%tuple.0, %tuple.1) layout=(({1,0}), ({0,1}))
296   //
297   AlgebraicSimplifierOptions options(
298       [](const Shape&, const Shape&) { return false; });
299   options.set_is_layout_sensitive(true);
300   EXPECT_TRUE(AlgebraicSimplifier(options).Run(m.get()).ValueOrDie());
301   HloInstruction* root = m->entry_computation()->root_instruction();
302   // Verify layout of the root and the root's operands.
303   EXPECT_TRUE(ShapeUtil::Equal(result_shape, root->shape()));
304   EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {0}),
305                                root->operand(0)->shape()));
306   EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {1}),
307                                root->operand(1)->shape()));
308 
309   // Verify the structure of the HLO graph.
310   EXPECT_THAT(root,
311               GmockMatch(m::Tuple(m::Tuple(m::Op().Is(constant)),
312                                   m::Tuple(m::Copy(m::Op().Is(constant))))));
313 }
314 
315 TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) {
316   // param -> log -> reshape -> tanh
317   auto builder = HloComputation::Builder(TestName());
318   Shape ashape = ShapeUtil::MakeShape(F32, {1, 2, 3, 1});
319   Shape bshape = ShapeUtil::MakeShape(F32, {3, 1, 2});
320   auto param = builder.AddInstruction(
321       HloInstruction::CreateParameter(0, ashape, "param"));
322   auto log = builder.AddInstruction(
323       HloInstruction::CreateUnary(ashape, HloOpcode::kLog, param));
324   auto reshape =
325       builder.AddInstruction(HloInstruction::CreateReshape(bshape, log));
326   auto tanh = builder.AddInstruction(
327       HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, reshape));
328 
329   auto m = CreateNewVerifiedModule();
330   HloComputation* computation = m->AddEntryComputation(builder.Build(tanh));
331 
332   Shape ashape_with_layout(ashape);
333   Shape bshape_with_layout(bshape);
334   *ashape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 2, 1, 3});
335   *bshape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0});
336 
337   ComputationLayout computation_layout(computation->ComputeProgramShape());
338   *computation_layout.mutable_parameter_layout(0) =
339       ShapeLayout(ashape_with_layout);
340   *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
341   AssignLayouts(m.get(), &computation_layout);
342 
343   auto log_minor_to_major = log->shape().layout().minor_to_major();
344   EXPECT_GT(PositionInContainer(log_minor_to_major, 1),
345             PositionInContainer(log_minor_to_major, 2));
346 
347   auto reshape_minor_to_major = reshape->shape().layout().minor_to_major();
348   EXPECT_GT(PositionInContainer(reshape_minor_to_major, 0),
349             PositionInContainer(reshape_minor_to_major, 2));
350 }
351 
352 // Test whether LayoutAssignment assigns layouts to elementwise operations to
353 // keep linear indices valid across them, and to transpositions to make them
354 // bitcasts.
355 TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) {
356   // param -> log -> transpose -> tanh
357   auto builder = HloComputation::Builder(TestName());
358   Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
359   Shape bshape = ShapeUtil::MakeShape(F32, {12, 42});
360   auto param = builder.AddInstruction(
361       HloInstruction::CreateParameter(0, ashape, "param"));
362   auto log = builder.AddInstruction(
363       HloInstruction::CreateUnary(ashape, HloOpcode::kLog, param));
364   auto transpose = builder.AddInstruction(
365       HloInstruction::CreateTranspose(bshape, log, {1, 0}));
366   auto tanh = builder.AddInstruction(
367       HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, transpose));
368   auto m = CreateNewVerifiedModule();
369   auto computation = m->AddEntryComputation(builder.Build(tanh));
370 
371   Shape ashape_with_layout(ashape);
372   Shape bshape_with_layout(bshape);
373   *ashape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
374   *bshape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1});
375 
376   ComputationLayout computation_layout(computation->ComputeProgramShape());
377   *computation_layout.mutable_parameter_layout(0) =
378       ShapeLayout(ashape_with_layout);
379   *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
380   AssignLayouts(m.get(), &computation_layout);
381 
382   EXPECT_TRUE(
383       LayoutUtil::Equal(ashape_with_layout.layout(), log->shape().layout()));
384   EXPECT_TRUE(LayoutUtil::Equal(bshape_with_layout.layout(),
385                                 transpose->shape().layout()));
386   EXPECT_TRUE(
387       LayoutUtil::Equal(bshape_with_layout.layout(), tanh->shape().layout()));
388 }
389 
390 // Test whether LayoutAssignment assigns layouts to transpositions to make them
391 // bitcasts.
392 TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) {
393   // param -> broadcast -> transpose
394   auto builder = HloComputation::Builder(TestName());
395   Shape ashape = ShapeUtil::MakeShape(F32, {3, 4});
396   Shape bshape = ShapeUtil::MakeShape(F32, {2, 3, 4});
397   Shape cshape = ShapeUtil::MakeShape(F32, {4, 3, 2});
398   auto param = builder.AddInstruction(
399       HloInstruction::CreateParameter(0, ashape, "param"));
400   auto broadcast = builder.AddInstruction(
401       HloInstruction::CreateBroadcast(bshape, param, {1, 2}));
402   auto transpose = builder.AddInstruction(
403       HloInstruction::CreateTranspose(cshape, broadcast, {2, 1, 0}));
404   auto m = CreateNewVerifiedModule();
405   HloComputation* computation =
406       m->AddEntryComputation(builder.Build(transpose));
407 
408   Shape input_shape_with_layout(ashape);
409   Shape output_shape_with_layout(cshape);
410   *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
411   *output_shape_with_layout.mutable_layout() =
412       LayoutUtil::MakeLayout({2, 1, 0});
413 
414   ComputationLayout computation_layout(computation->ComputeProgramShape());
415   *computation_layout.mutable_parameter_layout(0) =
416       ShapeLayout(input_shape_with_layout);
417   *computation_layout.mutable_result_layout() =
418       ShapeLayout(output_shape_with_layout);
419   AssignLayouts(m.get(), &computation_layout);
420 
421   EXPECT_THAT(broadcast->shape().layout().minor_to_major(),
422               ElementsAre(0, 1, 2));
423 }
424 
425 TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) {
426   // param[4] -> broadcast[3x4] ------> transpose[4x3]-------- -------> tuple
427   //                            \                                     /
428   //                             \-> tanh[3x4] -> broadcast2[2x3x4] -/
429   //
430   // The layout of `transpose` is set to {1,0} because it provides a buffer to
431   // the computation result, which has a fixed layout. Therefore, `broadcast`
432   // (the operand of transpose) is expected to have layout {0,1} so that the
433   // transpose is a bitcast. Furthermore, `tanh` is expected to have the same
434   // layout as `broadcast` (i.e. {0,1}) because `tanh` is elementwise.
435   Shape f32_4 = ShapeUtil::MakeShape(F32, {4});
436   Shape f32_34 = ShapeUtil::MakeShape(F32, {3, 4});
437   Shape f32_43 = ShapeUtil::MakeShape(F32, {4, 3});
438   Shape f32_234 = ShapeUtil::MakeShape(F32, {2, 3, 4});
439 
440   auto builder = HloComputation::Builder(TestName());
441   auto param = builder.AddInstruction(
442       HloInstruction::CreateParameter(0, f32_4, "param"));
443   auto broadcast = builder.AddInstruction(
444       HloInstruction::CreateBroadcast(f32_34, param, {1}));
445   auto transpose = builder.AddInstruction(
446       HloInstruction::CreateTranspose(f32_43, broadcast, {1, 0}));
447   auto tanh = builder.AddInstruction(
448       HloInstruction::CreateUnary(f32_34, HloOpcode::kTanh, broadcast));
449   auto broadcast2 = builder.AddInstruction(
450       HloInstruction::CreateBroadcast(f32_234, tanh, {1, 2}));
451   auto tuple = builder.AddInstruction(
452       HloInstruction::CreateTuple({transpose, broadcast2}));
453   auto m = CreateNewVerifiedModule();
454   HloComputation* computation = m->AddEntryComputation(builder.Build(tuple));
455 
456   ComputationLayout computation_layout(computation->ComputeProgramShape());
457   Shape param_shape_with_layout(f32_4);
458   Shape transpose_shape_with_layout(f32_43);
459   Shape broadcast2_shape_with_layout(f32_234);
460   *param_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0});
461   *transpose_shape_with_layout.mutable_layout() =
462       LayoutUtil::MakeLayout({1, 0});
463   *broadcast2_shape_with_layout.mutable_layout() =
464       LayoutUtil::MakeLayout({2, 1, 0});
465 
466   *computation_layout.mutable_parameter_layout(0) =
467       ShapeLayout(param_shape_with_layout);
468   *computation_layout.mutable_result_layout() =
469       ShapeLayout(ShapeUtil::MakeTupleShape(
470           {transpose_shape_with_layout, broadcast2_shape_with_layout}));
471   AssignLayouts(m.get(), &computation_layout);
472   EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1));
473   EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(1, 0));
474   EXPECT_THAT(tanh->shape().layout().minor_to_major(), ElementsAre(0, 1));
475 }
476 
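// A LayoutAssignment subclass that, when propagating a buffer constraint,
// forces every same-rank operand of an instruction to take the instruction's
// output layout. Used below to exercise operand layout propagation.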
477 class OperandsMustBeTheSameLayoutAssignment : public LayoutAssignment {
478  public:
479   explicit OperandsMustBeTheSameLayoutAssignment(
480       ComputationLayout* entry_computation_layout)
481       : LayoutAssignment(entry_computation_layout) {}
482 
483  protected:
484   Status PropagateBufferConstraint(
485       const BufferLayoutConstraint& buffer_constraint,
486       LayoutConstraints* constraints) override {
487     const LogicalBuffer& buffer = buffer_constraint.buffer();
488     const HloInstruction* instruction = buffer.instruction();
489 
490     // Force the operands' layout to the output layout.
491     for (int64_t operand_no = 0; operand_no < instruction->operand_count();
492          ++operand_no) {
493       const HloInstruction* operand = instruction->operand(operand_no);
494       if (instruction->shape().rank() != operand->shape().rank()) {
495         continue;
496       }
497       TF_RETURN_IF_ERROR(SetArrayOperandLayout(buffer_constraint.layout(),
498                                                instruction, operand_no,
499                                                /*mandatory=*/true));
500     }
501     return PropagateBufferConstraintToUses(buffer_constraint, constraints);
502   }
503 };
504 
505 TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) {
506   // param0 -> concatenate -> reshape
507   // param1   -^
508   auto builder = HloComputation::Builder(TestName());
509   Shape ashape = ShapeUtil::MakeShape(F32, {50, 1});
510   Shape bshape = ShapeUtil::MakeShape(F32, {50, 2});
511   Shape cshape = ShapeUtil::MakeShape(F32, {100});
512   auto param0 = builder.AddInstruction(
513       HloInstruction::CreateParameter(0, ashape, "param"));
514   auto param1 = builder.AddInstruction(
515       HloInstruction::CreateParameter(1, ashape, "param"));
516   auto concatenate = builder.AddInstruction(
517       HloInstruction::CreateConcatenate(bshape, {param0, param1}, 1));
518   auto reshape = builder.AddInstruction(
519       HloInstruction::CreateReshape(cshape, concatenate));
520   auto m = CreateNewVerifiedModule();
521   HloComputation* computation = m->AddEntryComputation(builder.Build(reshape));
522 
523   Shape param0_shape_with_layout(ashape);
524   Shape param1_shape_with_layout(ashape);
525   *param0_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1});
526   *param1_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
527 
528   ComputationLayout computation_layout(computation->ComputeProgramShape());
529   *computation_layout.mutable_parameter_layout(0) =
530       ShapeLayout(param0_shape_with_layout);
531   *computation_layout.mutable_parameter_layout(1) =
532       ShapeLayout(param1_shape_with_layout);
533   OperandsMustBeTheSameLayoutAssignment layout_assignment(&computation_layout);
534   EXPECT_IS_OK(layout_assignment.Run(m.get()).status());
535 
536   EXPECT_EQ(concatenate->operand(0)->shape().layout().minor_to_major(),
537             concatenate->operand(1)->shape().layout().minor_to_major());
538   EXPECT_EQ(concatenate->shape().layout().minor_to_major(),
539             concatenate->operand(1)->shape().layout().minor_to_major());
540 }
541 
542 // Test layout assignment of a transpose into a bitcast based on its operand.
543 TEST_F(LayoutAssignmentTest, TransposeToBitcastFromOperand) {
544   auto builder = HloComputation::Builder(TestName());
545   Shape input_shape_with_layout =
546       ShapeUtil::MakeShapeWithLayout(F32, {3, 5, 6, 7}, {2, 0, 3, 1});
547   auto param = builder.AddInstruction(
548       HloInstruction::CreateParameter(0, input_shape_with_layout, "param"));
549   auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose(
550       ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), param, {2, 3, 0, 1}));
551   auto m = CreateNewVerifiedModule();
552   HloComputation* computation =
553       m->AddEntryComputation(builder.Build(transpose));
554   ComputationLayout computation_layout(computation->ComputeProgramShape());
555   AssignLayouts(m.get(), &computation_layout);
556   EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
557                                             transpose->shape(), {2, 3, 0, 1}));
558 }
559 // Test layout assignment of a transpose into a bitcast based on its user.
560 TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) {
561   auto builder = HloComputation::Builder(TestName());
562   Shape input_shape = ShapeUtil::MakeShape(F32, {3, 5, 6, 7});
563   auto constant = builder.AddInstruction(
564       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
565   auto broadcast = builder.AddInstruction(
566       HloInstruction::CreateBroadcast(input_shape, constant, {}));
567   auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose(
568       ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), broadcast, {2, 3, 0, 1}));
569   auto m = CreateNewVerifiedModule();
570   HloComputation* computation =
571       m->AddEntryComputation(builder.Build(transpose));
572   ComputationLayout computation_layout(computation->ComputeProgramShape());
573   AssignLayouts(m.get(), &computation_layout);
574   EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
575                                             transpose->shape(), {2, 3, 0, 1}));
576 }
577 
578 // TransposeIsBitcast shouldn't be called without layout information.
579 TEST_F(LayoutAssignmentTest, TransposeIsBitcastFail) {
580   auto builder = HloComputation::Builder(TestName());
581   Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});
582   Shape input_shape_with_layout(input_shape);
583   *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0});
584   auto param = builder.AddInstruction(
585       HloInstruction::CreateParameter(0, input_shape_with_layout, "param"));
586   auto hlo = builder.AddInstruction(
587       HloInstruction::CreateTranspose(input_shape, param, {0, 2, 1}));
588   // Clear the default layout assigned to the instruction.
589   LayoutUtil::ClearLayout(hlo->mutable_shape());
590   EXPECT_DEATH(ShapeUtil::TransposeIsBitcast(hlo->operand(0)->shape(),
591                                              hlo->shape(), hlo->dimensions()),
592                "has_layout");
593 }
594 
595 // ReshapeIsBitcast shouldn't be called without layout information.
596 TEST_F(LayoutAssignmentTest, ReshapeIsBitcastFail) {
597   auto builder = HloComputation::Builder(TestName());
598   Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});
599   Shape input_shape_with_layout(input_shape);
600   *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0});
601   auto param = builder.AddInstruction(
602       HloInstruction::CreateParameter(0, input_shape_with_layout, "param"));
603   auto hlo =
604       builder.AddInstruction(HloInstruction::CreateReshape(input_shape, param));
605   // Clear the default layout assigned to the instruction.
606   LayoutUtil::ClearLayout(hlo->mutable_shape());
607   EXPECT_DEATH(
608       ShapeUtil::ReshapeIsBitcast(hlo->operand(0)->shape(), hlo->shape()),
609       "has_layout");
610 }
611 
612 // Check that the computation below doesn't crash the compiler.
613 //
614 // Within a fusion computation, only the parameters and result get assigned a
615 // layout.  When we run the algebraic simplifier on this computation post layout
616 // assignment, it should not call TransposeIsBitcast on the `transpose` node
617 // inside the fusion computation as TransposeIsBitcast checks both input_shape
618 // and output_shape have layouts.
619 TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) {
620   const char* module_str = R"(
621     HloModule test_module
622 
623     fused_computation {
624       param_1 = f32[2,2,2]{2,1,0} parameter(1)
625       transpose = f32[2,2,2]{2,1,0} transpose(param_1), dimensions={0,2,1}
626       reduce_1 = f32[] parameter(0)
627       broadcast_1 = f32[2,2,2]{2,1,0} broadcast(reduce_1), dimensions={}
628       ROOT divide_1 = f32[2,2,2]{2,1,0} divide(transpose, broadcast_1)
629     }
630 
631     ENTRY entry_computation {
632       fusion.1 = f32[2,2,2]{2,1,0} parameter(1)
633       reduce.1 = f32[] parameter(0)
634       fusion.2 = f32[2,2,2]{2,1,0} fusion(reduce.1, fusion.1), kind=kLoop, calls=fused_computation
635      ROOT tuple.1 = (f32[2,2,2]{2,1,0}) tuple(fusion.2)
636     }
637   )";
638 
639   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
640                           ParseAndReturnVerifiedModule(module_str));
641   std::unique_ptr<HloModule> compiled_module =
642       backend()
643           .compiler()
644           ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
645                          /*device_allocator=*/nullptr)
646           .value();
647 
648   EXPECT_EQ(OkStatus(), backend()
649                             .compiler()
650                             ->RunBackend(std::move(compiled_module),
651                                          backend().default_stream_executor(),
652                                          /*device_allocator=*/nullptr)
653                             .status());
654 }
655 
656 // A GTE inside of a fusion node inherits the layout of its operand (which
657 // should, if we keep following operands, eventually be a parameter).
658 TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) {
659   const char* module_str = R"(
660     HloModule test_module
661 
662     fused_computation {
663       fparam = (f32[2,2,2], (f32[2,2,2], f32[2,2,2])) parameter(0)
664       gte0 = f32[2,2,2] get-tuple-element(fparam), index=0
665       gte1 = (f32[2,2,2], f32[2,2,2]) get-tuple-element(fparam), index=1
666       gte1a = f32[2,2,2] get-tuple-element(gte1), index=0
667       gte1b = f32[2,2,2] get-tuple-element(gte1), index=1
668       add = f32[2,2,2] add(gte1a, gte1b)
669       ROOT fresult = f32[2,2,2] add(gte0, add)
670     }
671 
672     ENTRY entry_computation {
673       param = (f32[2,2,2], (f32[2,2,2], f32[2,2,2])) parameter(0)
674       ROOT fusion =
675         f32[2,2,2] fusion(param), kind=kLoop, calls=fused_computation
676     }
677   )";
678 
679   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
680                           ParseAndReturnVerifiedModule(module_str));
681   ComputationLayout computation_layout(
682       m->entry_computation()->ComputeProgramShape());
683   Shape param_shape = ShapeUtil::MakeTupleShape(
684       {ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2}),
685        ShapeUtil::MakeTupleShape({
686            ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {1, 2, 0}),
687            ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {2, 0, 1}),
688        })});
689   TF_ASSERT_OK(
690       computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
691           param_shape));
692   computation_layout.mutable_result_layout()->ResetLayout(
693       LayoutUtil::MakeLayout({2, 1, 0}));
694   AssignLayouts(m.get(), &computation_layout);
695 
696   EXPECT_THAT(LayoutOf(m.get(), "gte0"), ElementsAre(0, 1, 2));
697   EXPECT_THAT(LayoutOf(m.get(), "gte1a"), ElementsAre(1, 2, 0));
698   EXPECT_THAT(LayoutOf(m.get(), "gte1b"), ElementsAre(2, 0, 1));
699   EXPECT_THAT(LayoutOf(m.get(), "fresult"), ElementsAre(2, 1, 0));
700   EXPECT_THAT(FindInstruction(m.get(), "gte1")
701                   ->shape()
702                   .tuple_shapes(0)
703                   .layout()
704                   .minor_to_major(),
705               ElementsAre(1, 2, 0));
706   EXPECT_THAT(FindInstruction(m.get(), "gte1")
707                   ->shape()
708                   .tuple_shapes(1)
709                   .layout()
710                   .minor_to_major(),
711               ElementsAre(2, 0, 1));
712 }
713 
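// The two branches of a conditional produce results with different layouts:
// the false branch returns data from an infeed with a fixed {0,1} layout.
// Layout assignment must reconcile the branch results; the test checks that
// both branch roots end up with matching layouts and that a copy is inserted
// in the false branch.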
714 TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) {
715   auto builder = HloComputation::Builder(TestName());
716   auto m = CreateNewVerifiedModule();
717   Shape shape = ShapeUtil::MakeShape(F32, {128, 8});
718   Shape tshape = ShapeUtil::MakeTupleShape({shape, shape});
719   Shape result_tshape = ShapeUtil::MakeTupleShape({shape});
720 
721   auto param0 = builder.AddInstruction(
722       HloInstruction::CreateParameter(0, shape, "param0"));
723   auto param1 = builder.AddInstruction(
724       HloInstruction::CreateParameter(1, shape, "param1"));
725   auto pred = builder.AddInstruction(HloInstruction::CreateParameter(
726       2, ShapeUtil::MakeShape(PRED, {}), "param2"));
727   auto tuple =
728       builder.AddInstruction(HloInstruction::CreateTuple({param0, param1}));
729 
730   auto true_builder = HloComputation::Builder(TestName() + "_TrueBranch");
731   {
732     auto param = true_builder.AddInstruction(
733         HloInstruction::CreateParameter(0, tshape, "param"));
734     auto gte0 = true_builder.AddInstruction(
735         HloInstruction::CreateGetTupleElement(shape, param, 0));
736     auto gte1 = true_builder.AddInstruction(
737         HloInstruction::CreateGetTupleElement(shape, param, 1));
738     auto add = true_builder.AddInstruction(
739         HloInstruction::CreateBinary(shape, HloOpcode::kAdd, gte0, gte1));
740     true_builder.AddInstruction(HloInstruction::CreateTuple({add}));
741   }
742   HloComputation* true_computation =
743       m->AddEmbeddedComputation(true_builder.Build());
744 
745   auto false_builder = HloComputation::Builder(TestName() + "_FalseBranch");
746   {
747     Shape xshape = ShapeUtil::MakeShapeWithLayout(F32, {128, 8}, {0, 1});
748     false_builder.AddInstruction(
749         HloInstruction::CreateParameter(0, tshape, "param"));
750     // Use an infeed, since layout assignment does not mess with it.
751     auto token = false_builder.AddInstruction(HloInstruction::CreateToken());
752     auto infeed = false_builder.AddInstruction(
753         HloInstruction::CreateInfeed(xshape, token, ""));
754     auto infeed_data = false_builder.AddInstruction(
755         HloInstruction::CreateGetTupleElement(xshape, infeed, 0));
756     false_builder.AddInstruction(HloInstruction::CreateTuple({infeed_data}));
757   }
758   HloComputation* false_computation =
759       m->AddEmbeddedComputation(false_builder.Build());
760   builder.AddInstruction(HloInstruction::CreateConditional(
761       result_tshape, pred, tuple, true_computation, tuple, false_computation));
762 
763   HloComputation* computation = m->AddEntryComputation(builder.Build());
764   ComputationLayout computation_layout(computation->ComputeProgramShape());
765 
766   AssignLayouts(m.get(), &computation_layout);
767 
768   const HloInstruction* true_root = true_computation->root_instruction();
769   const HloInstruction* false_root = false_computation->root_instruction();
770   EXPECT_THAT(true_root->opcode(), HloOpcode::kTuple);
771   EXPECT_THAT(false_root->opcode(), HloOpcode::kTuple);
772 
773   const HloInstruction* true_result = true_root->operand(0);
774   const HloInstruction* false_result = false_root->operand(0);
775   EXPECT_TRUE(LayoutUtil::Equal(true_result->shape().layout(),
776                                 false_result->shape().layout()));
777   EXPECT_THAT(false_result->opcode(), HloOpcode::kCopy);
778 }
779 
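// Layout assignment should succeed on a conditional whose branches take
// differently shaped operands (a tuple for the true branch, an array for the
// false branch).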
780 TEST_F(LayoutAssignmentTest, LayoutAssignmentToTupleSiblingOperand) {
781   const char* const hlo_string = R"(
782   HloModule Module
783 
784   true_branch {
785     tparam = (f64[2,3], f64[2,3]) parameter(0)
786     ROOT tgte = f64[2,3] get-tuple-element(tparam), index=1
787   }
788 
789   false_branch {
790     ROOT Arg = f64[2,3] parameter(0)
791   }
792 
793   ENTRY entry {
794     p0 = (f64[2,3], f64[2,3])  parameter(0)
795     p1 = f64[2,3] parameter(1)
796     constant = pred[] constant(true)
797     ROOT conditional = f64[2,3] conditional(constant, p0, p1),
798       true_computation=true_branch, false_computation=false_branch
799   }
800   )";
801   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
802 
803   ComputationLayout computation_layout(
804       m->entry_computation()->ComputeProgramShape());
805   LayoutAssignment layout_assignment(&computation_layout);
806   Status error_status = layout_assignment.Run(m.get()).status();
807   EXPECT_TRUE(error_status.ok());
808 }
809 
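// A bitcast appearing in the input HLO should make layout assignment fail,
// since bitcasts are not expected before layouts have been assigned.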
810 TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) {
811   auto builder = HloComputation::Builder(TestName());
812   auto constant0 = builder.AddInstruction(
813       HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
814           {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
815   builder.AddInstruction(
816       HloInstruction::CreateBitcast(constant0->shape(), constant0));
817   auto m = CreateNewVerifiedModule();
818   m->AddEntryComputation(builder.Build());
819 
820   ComputationLayout computation_layout(
821       m->entry_computation()->ComputeProgramShape());
822   LayoutAssignment layout_assignment(&computation_layout);
823   Status error_status = layout_assignment.Run(m.get()).status();
824   EXPECT_FALSE(error_status.ok());
825   EXPECT_THAT(
826       error_status.error_message(),
827       ::testing::HasSubstr(
828           "Unexpected bitcast operation seen during layout assignment"));
829 }
830 
831 TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
832   // Pin non-matching layouts to the parameter and root.
833   const char* module_str = R"(
834     HloModule test_module
835 
836     ENTRY entry_computation {
837       param = (f32[2,2]) parameter(0)
838       gte = f32[2,2] get-tuple-element(param), index=0
839       token0 = token[] after-all()
840       recv = (f32[2,2], u32[], token[]) recv(token0), channel_id=1, sharding={maximal device=1}
841       recv-done = (f32[2,2], token[]) recv-done(recv), channel_id=1,
842         sharding={maximal device=1}
843       ROOT root = f32[2,2] get-tuple-element(recv-done), index=0
844       send = (f32[2,2], u32[], token[]) send(gte, token0), channel_id=1,
845         sharding={maximal device=0}
846       send-done = token[] send-done(send), channel_id=1, sharding={maximal device=0}
847     }
848   )";
849 
850   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
851                           ParseAndReturnVerifiedModule(module_str));
852   ComputationLayout computation_layout(
853       m->entry_computation()->ComputeProgramShape());
854   Shape param_shape = ShapeUtil::MakeTupleShape(
855       {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
856   TF_ASSERT_OK(
857       computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
858           param_shape));
859   computation_layout.mutable_result_layout()->ResetLayout(
860       LayoutUtil::MakeLayout({1, 0}));
861 
862   ChannelLayoutConstraints channel_constraints;
863   AssignLayouts(m.get(), &computation_layout, &channel_constraints);
864 
865   EXPECT_TRUE(ShapeUtil::Equal(FindInstruction(m.get(), "send")->shape(),
866                                FindInstruction(m.get(), "recv")->shape()));
867 }
868 
869 TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) {
870   // Pin non-matching layouts to the parameter and root.
871   const char* module_str = R"(
872     HloModule test_module
873 
874     add {
875       lhs = f32[] parameter(0)
876       rhs = f32[] parameter(1)
877       ROOT add = f32[] add(lhs, rhs)
878     }
879 
880     ENTRY entry_computation {
881       param = (f32[2,2]) parameter(0)
882       gte = f32[2,2] get-tuple-element(param), index=0
883       ar.0 = f32[2,2] all-reduce(gte),
884         channel_id=1, replica_groups={{0}}, to_apply=add,
885         sharding={maximal device=0}
886       const = f32[2,2] constant({{0,1},{2,3}})
887       ROOT ar.1 = f32[2,2] all-reduce(const),
888         channel_id=1, replica_groups={{0}}, to_apply=add,
889         sharding={maximal device=1}
890     })";
891   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
892                           ParseAndReturnVerifiedModule(module_str));
893   ComputationLayout computation_layout(
894       m->entry_computation()->ComputeProgramShape());
895   Shape param_shape = ShapeUtil::MakeTupleShape(
896       {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
897   TF_ASSERT_OK(
898       computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
899           param_shape));
900   computation_layout.mutable_result_layout()->ResetLayout(
901       LayoutUtil::MakeLayout({1, 0}));
902 
903   ChannelLayoutConstraints channel_constraints;
904   AssignLayouts(m.get(), &computation_layout, &channel_constraints);
905 
906   EXPECT_THAT(LayoutOf(m.get(), "gte"), ElementsAre(0, 1));
907   EXPECT_THAT(LayoutOf(m.get(), "ar.0"), ElementsAre(0, 1));
908   EXPECT_THAT(LayoutOf(m.get(), "ar.1"), ElementsAre(0, 1));
909   const HloInstruction* root = m->entry_computation()->root_instruction();
910   EXPECT_THAT(root->shape().layout().minor_to_major(), ElementsAre(1, 0));
911 }
912 
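// par1 has layout {0,1}, but slice0 feeds an add whose operands and result use
// layout {1,0}. A copy of par1 with layout {1,0} should be inserted in front
// of the slice so that the slice does not change the layout implicitly.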
913 TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) {
914   const char* module_str = R"(
915     HloModule CopySliceOperandToAvoidImplicitLayoutChange
916 
917     ENTRY CopySliceOperandToAvoidImplicitLayoutChange {
918       par0 = f32[3,4]{1,0} parameter(0)
919       par1 = f32[4,5]{0,1} parameter(1)
920       slice0 = f32[3,4] slice(par1), slice={[1:4],[1:5]}
921       ROOT add0 = f32[3,4]{1,0} add(par0,slice0)
922     }
923   )";
924 
925   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
926                           ParseAndReturnVerifiedModule(module_str));
927   auto compiled_module =
928       backend()
929           .compiler()
930           ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
931                          /*device_allocator=*/nullptr)
932           .value();
933   HloInstruction* root =
934       compiled_module->entry_computation()->root_instruction();
935   Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0});
936   EXPECT_THAT(
937       root,
938       GmockMatch(m::Add(
939           m::Parameter(),
940           m::Slice(m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy)))));
941 }
942 
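// Same as above, but with a dynamic-slice: the operand with the mismatched
// layout should be copied to layout {1,0} before the dynamic-slice.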
943 TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) {
944   const char* module_str = R"(
945     HloModule CopyDSliceOperandToAvoidImplicitLayoutChange
946 
947     ENTRY CopyDSliceOperandToAvoidImplicitLayoutChange {
948       par0 = f32[3,4]{1,0} parameter(0)
949       par1 = f32[4,5]{0,1} parameter(1)
950       par2 = s32[] parameter(2)
951       par3 = s32[] parameter(3)
952       dslice0 = f32[3,4] dynamic-slice(par1, par2, par3), dynamic_slice_sizes={3,4}
953       ROOT add0 = f32[3,4]{1,0} add(par0,dslice0)
954     }
955   )";
956 
957   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
958                           ParseAndReturnVerifiedModule(module_str));
959   auto compiled_module =
960       backend()
961           .compiler()
962           ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
963                          /*device_allocator=*/nullptr)
964           .value();
965   HloInstruction* root =
966       compiled_module->entry_computation()->root_instruction();
967   Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0});
968   EXPECT_THAT(root,
969               GmockMatch(m::Add(
970                   m::Parameter(),
971                   m::DynamicSlice(
972                       m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy),
973                       m::Parameter(2), m::Parameter(3)))));
974 }
975 
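// Same as above, but with a concatenate: only the operand with the mismatched
// layout (par1) should be copied; par2 already has the expected layout.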
976 TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) {
977   const char* module_str = R"(
978     HloModule CopyConcatOperandToAvoidImplicitLayoutChange
979 
980     ENTRY CopyConcatOperandToAvoidImplicitLayoutChange {
981       par0 = f32[3,8]{1,0} parameter(0)
982       par1 = f32[3,5]{0,1} parameter(1)
983       par2 = f32[3,3]{1,0} parameter(2)
984       concat0 = f32[3,8] concatenate(f32[3,5] par1, f32[3,3] par2),
985         dimensions={1}
986       ROOT add0 = f32[3,8]{1,0} add(par0,concat0)
987     }
988   )";
989 
990   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
991                           ParseAndReturnVerifiedModule(module_str));
992   auto compiled_module =
993       backend()
994           .compiler()
995           ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
996                          /*device_allocator=*/nullptr)
997           .value();
998   HloInstruction* root =
999       compiled_module->entry_computation()->root_instruction();
1000   Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {3, 5}, {1, 0});
1001   EXPECT_THAT(
1002       root,
1003       GmockMatch(m::Add(
1004           m::Parameter(),
1005           m::Concatenate(m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy),
1006                          m::Parameter(2)))));
1007 }
1008 
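// Unlike the slice and concatenate cases above, no copies are expected for the
// convolution operands even though their layouts are non-default.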
1009 TEST_F(LayoutAssignmentTest,
1010        ConvolutionOperandWithImplicitLayoutChangeNotCopied) {
1011   const char* module_str = R"(
1012     HloModule ConvolutionOperandWithImplicitLayoutChangeNotCopied
1013 
1014     ENTRY ConvolutionOperandWithImplicitLayoutChangeNotCopied {
1015       par0 = f32[128,3,230,230]{2,3,1,0} parameter(0)
1016       par1 = f32[7,7,3,64]{3,2,0,1} parameter(1)
1017       ROOT convolution0 = f32[128,64,112,112]{3,2,1,0} convolution(par0, par1),
1018         window={size=7x7 stride=2x2}, dim_labels=bf01_01io->bf01,
1019         feature_group_count=1
1020     }
1021   )";
1022 
1023   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
1024                           ParseAndReturnVerifiedModule(module_str));
1025   auto compiled_module =
1026       backend()
1027           .compiler()
1028           ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
1029                          /*device_allocator=*/nullptr)
1030           .value();
1031   HloInstruction* root =
1032       compiled_module->entry_computation()->root_instruction();
1033   EXPECT_THAT(root,
1034               GmockMatch(m::Convolution(m::Parameter(0), m::Parameter(1))));
1035 }
1036 
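// The {0,1} layout required for the slice result is propagated back to its
// operand: par0 should be copied to layout {0,1} before the slice.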
1037 TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) {
1038   const char* module_str = R"(
1039     HloModule PropagatingLayoutFromResultToOperand
1040 
1041     ENTRY PropagatingLayoutFromResultToOperand {
1042       par0 = f32[4,5]{1,0} parameter(0)
1043       ROOT slice0 = f32[3,4]{0,1} slice(par0), slice={[1:4],[1:5]}
1044     }
1045   )";
1046 
1047   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
1048                           ParseAndReturnVerifiedModule(module_str));
1049   auto compiled_module =
1050       backend()
1051           .compiler()
1052           ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
1053                          /*device_allocator=*/nullptr)
1054           .value();
1055   HloInstruction* root =
1056       compiled_module->entry_computation()->root_instruction();
1057   Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {0, 1});
1058   EXPECT_THAT(root,
1059               GmockMatch(m::Slice(
1060                   m::Copy(m::Parameter(0)).WithShapeEqualTo(&shape_copy))));
1061 }
1062 
1063 TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) {
1064   // The first infeed uses layout {0,1}, while the second uses layout {1,0}.
1065   // The mismatch forces a copy of the tuple.  The tuple contains a token, so
1066   // layout assignment will fail if it tries to copy the whole tuple.
1067   const char* module_str = R"(
1068     HloModule TupleCopyOnLayoutMismatch
1069 
1070     condition.1 (tup: (s32[], token[], f32[512,1024]{0,1})) -> pred[] {
1071       tup.1 = (s32[], token[], f32[512,1024]{0,1}) parameter(0)
1072       counter.1 = s32[] get-tuple-element(tup.1), index=0
1073       five = s32[] constant(5)
1074       ROOT lt = pred[] compare(counter.1, five), direction=LT
1075     }
1076 
1077     body.2 (tup: (s32[], token[], f32[512,1024]{0,1})) -> (s32[], token[], f32[512,1024]{0,1}) {
1078       tup.2 = (s32[], token[], f32[512,1024]{0,1}) parameter(0)
1079       counter.2 = s32[] get-tuple-element(tup.2), index=0
1080       tok.2 = token[] get-tuple-element(tup.2), index=1
1081 
1082       ifeed.2 = (f32[512,1024]{1,0}, token[]) infeed(tok.2)
1083       next_tok = token[] get-tuple-element(ifeed.2), index=1
1084       next_buf = f32[512,1024]{1,0} get-tuple-element(ifeed.2), index=0
1085 
1086       one = s32[] constant(1)
1087       next_counter = s32[] add(counter.2, one)
1088       ROOT tup = (s32[], token[], f32[512,1024]{0,1}) tuple(next_counter, next_tok, next_buf)
1089     }
1090 
1091     ENTRY main () -> f32[512,1024]{0,1} {
1092       start_tok = token[] after-all()
1093 
1094       ifeed.3 = (f32[512,1024]{0,1}, token[]) infeed(start_tok)
1095       itok = token[] get-tuple-element(ifeed.3), index=1
1096       ibuf = f32[512,1024]{0,1} get-tuple-element(ifeed.3), index=0
1097 
1098       zero = s32[] constant(0)
1099       itup = (s32[], token[], f32[512,1024]{0,1}) tuple(zero, itok, ibuf)
1100 
1101       loop = (s32[], token[], f32[512,1024]{0,1}) while(itup), condition=condition.1, body=body.2
1102       ROOT result = f32[512,1024]{0,1} get-tuple-element(loop), index=2
1103     }
1104   )";
1105 
1106   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
1107                           ParseAndReturnVerifiedModule(module_str));
1108   ComputationLayout computation_layout(
1109       m->entry_computation()->ComputeProgramShape());
1110 
1111   // Sanity check to verify that there's a layout mismatch.
1112   EXPECT_THAT(LayoutOf(m.get(), "ibuf"), ElementsAre(0, 1));
1113   EXPECT_THAT(LayoutOf(m.get(), "next_buf"), ElementsAre(1, 0));
1114 
1115   AssignLayouts(m.get(), &computation_layout);
1116   SCOPED_TRACE(m->ToString());
1117 
1118   // Make sure that layout assignment did not magically eliminate the mismatch,
1119   // in which case the test didn't prove anything.
1120   Layout layout01 = LayoutUtil::MakeLayout({0, 1});
1121   const HloInstruction* loop = nullptr;
1122   ASSERT_THAT(m->entry_computation()->root_instruction(),
1123               GmockMatch(m::GetTupleElement(
1124                   m::Op(&loop)
1125                       .WithOpcode(HloOpcode::kWhile)
1126                       .WithOperand(0, m::Tuple(m::Op(), m::Op(),
1127                                                m::Copy(m::Op().WithShape(
1128                                                    m::Shape().WithLayoutEqualTo(
1129                                                        &layout01))))))));
1130 
1131   Layout layout10 = LayoutUtil::MakeLayout({1, 0});
1132   EXPECT_THAT(loop->while_body()->root_instruction(),
1133               GmockMatch(m::Tuple(
1134                   m::Op(), m::Op(),
1135                   m::Op().WithShape(m::Shape().WithLayoutEqualTo(&layout10)))));
1136 }
1137 
1138 TEST_F(LayoutAssignmentTest, CustomCallNotLayoutConstrained) {
1139   const char* module_str = R"(
1140 HloModule CustomCallNotLayoutConstrained
1141 
1142 ENTRY %CustomCallWithNotLayoutConstrained (p: f32[42,2,3]) -> f32[1,2,3,4] {
1143   %p = f32[42,2,3] parameter(0)
1144   ROOT %custom-call = f32[1,2,3,4] custom-call(f32[42,2,3] %p), custom_call_target="baz"
1145 }
1146 )";
1147   // Try with a couple of different layouts. In each case the custom call's
1148   // operand and result layouts should match those of the entry computation.
1149   {
1150     TF_ASSERT_OK_AND_ASSIGN(
1151         std::unique_ptr<VerifiedHloModule> m,
1152         ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
1153     ComputationLayout computation_layout = m->entry_computation_layout();
1154     *computation_layout.mutable_parameter_layout(0) =
1155         ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 2, 1}));
1156     *computation_layout.mutable_result_layout() = ShapeLayout(
1157         ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {3, 2, 0, 1}));
1158     AssignLayouts(m.get(), &computation_layout);
1159 
1160     HloInstruction* root = m->entry_computation()->root_instruction();
1161     ASSERT_THAT(root, GmockMatch(m::CustomCall(m::Parameter())));
1162     ExpectLayoutIs(root->shape(), {3, 2, 0, 1});
1163     ExpectLayoutIs(root->operand(0)->shape(), {0, 2, 1});
1164   }
1165   {
1166     TF_ASSERT_OK_AND_ASSIGN(
1167         std::unique_ptr<VerifiedHloModule> m,
1168         ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
1169     ComputationLayout computation_layout = m->entry_computation_layout();
1170     *computation_layout.mutable_parameter_layout(0) =
1171         ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 1, 2}));
1172     *computation_layout.mutable_result_layout() = ShapeLayout(
1173         ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {0, 2, 3, 1}));
1174     AssignLayouts(m.get(), &computation_layout);
1175 
1176     HloInstruction* root = m->entry_computation()->root_instruction();
1177     ASSERT_THAT(root, GmockMatch(m::CustomCall(m::Parameter())));
1178     ExpectLayoutIs(root->shape(), {0, 2, 3, 1});
1179     ExpectLayoutIs(root->operand(0)->shape(), {0, 1, 2});
1180   }
1181 }
1182 
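// Verifies that a custom call with operand_layout_constraints and a fixed
// result layout keeps those layouts; copies are inserted where they disagree
// with the entry computation layout.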
1183 TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrained) {
1184   const char* module_str = R"(
1185 HloModule CustomCallLayoutConstrained
1186 
1187 ENTRY %CustomCallWithLayoutConstraints (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] {
1188   %p0 = f32[4,4] parameter(0)
1189   %p1 = f32[2,3] parameter(1)
1190   ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(f32[4,4] %p0, f32[2,3] %p1), custom_call_target="baz", operand_layout_constraints={f32[4,4]{0,1}, f32[2,3]{1,0}}
1191 }
1192 )";
1193   TF_ASSERT_OK_AND_ASSIGN(
1194       std::unique_ptr<VerifiedHloModule> m,
1195       ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
1196   ComputationLayout computation_layout = m->entry_computation_layout();
1197   *computation_layout.mutable_parameter_layout(0) =
1198       ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
1199   *computation_layout.mutable_parameter_layout(1) =
1200       ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}));
1201   *computation_layout.mutable_result_layout() = ShapeLayout(
1202       ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
1203   AssignLayouts(m.get(), &computation_layout);
1204 
1205   // Copies should be inserted around the custom call wherever its constrained
1206   // operand and result layouts disagree with the entry computation layout.
1207   ASSERT_THAT(m->entry_computation()->root_instruction(),
1208               GmockMatch(m::Copy(m::CustomCall(m::Copy(), m::Parameter()))));
1209 
1210   const HloInstruction* custom_call =
1211       m->entry_computation()->root_instruction()->operand(0);
1212   ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
1213   ExpectLayoutIs(custom_call->operand(0)->shape(), {0, 1});
1214   ExpectLayoutIs(custom_call->operand(1)->shape(), {1, 0});
1215 }
1216 
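// Verifies that a layout-constrained custom call whose output is aliased to
// operand 0 (output_to_operand_aliasing) keeps the constrained {1,0} layout
// on both the aliased operand and the result.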
1217 TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedAliasedOutput) {
1218   const char* module_str = R"(
1219 HloModule customcall.4
1220 
1221 ENTRY %customcall.4 (parameter.1: f32[8,128], parameter.2: f32[8,128]) -> f32[8,128] {
1222   %parameter.1 = f32[8,128]{1,0} parameter(0)
1223   %parameter.2 = f32[8,128]{1,0} parameter(1)
1224   ROOT %custom-call.3 = f32[8,128]{1,0} custom-call(f32[8,128]{1,0} %parameter.1, f32[8,128]{1,0} %parameter.2), custom_call_target="gpu_example_custom_call", operand_layout_constraints={f32[8,128]{1,0}, f32[8,128]{1,0}}, custom_call_has_side_effect=true, output_to_operand_aliasing={{}: (0, {})}
1225 })";
1226   TF_ASSERT_OK_AND_ASSIGN(
1227       std::unique_ptr<VerifiedHloModule> m,
1228       ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
1229   ComputationLayout computation_layout = m->entry_computation_layout();
1230   *computation_layout.mutable_parameter_layout(0) =
1231       ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {8, 128}, {1, 0}));
1232   *computation_layout.mutable_parameter_layout(1) =
1233       ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {8, 128}, {1, 0}));
1234   *computation_layout.mutable_result_layout() =
1235       ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {8, 128}, {1, 0}));
1236   AssignLayouts(m.get(), &computation_layout);
1237 
1238   const HloInstruction* custom_call =
1239       m->entry_computation()->root_instruction();
1240   ExpectLayoutIs(custom_call->shape(), {1, 0});
1241   ExpectLayoutIs(custom_call->operand(0)->shape(), {1, 0});
1242   ExpectLayoutIs(custom_call->operand(1)->shape(), {1, 0});
1243 }
1244 
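// Verifies that a layout-constrained custom call with zero operands keeps its
// fixed result layout, with a copy producing the requested entry result layout.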
1245 TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedZeroOperands) {
1246   const char* module_str = R"(
1247 HloModule CustomCallLayoutConstrainedZeroOperands
1248 
1249 ENTRY %CustomCallLayoutConstrainedZeroOperands () -> f32[1,2,3,4] {
1250   ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(), custom_call_target="baz", operand_layout_constraints={}
1251 }
1252 )";
1253   TF_ASSERT_OK_AND_ASSIGN(
1254       std::unique_ptr<VerifiedHloModule> m,
1255       ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
1256   ComputationLayout computation_layout = m->entry_computation_layout();
1257   *computation_layout.mutable_result_layout() = ShapeLayout(
1258       ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
1259   AssignLayouts(m.get(), &computation_layout);
1260 
1261   ASSERT_THAT(m->entry_computation()->root_instruction(),
1262               GmockMatch(m::Copy(m::CustomCall())));
1263 
1264   const HloInstruction* custom_call =
1265       m->entry_computation()->root_instruction()->operand(0);
1266   ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
1267 }
1268 
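// Verifies that per-element layout constraints on a tuple-shaped custom call
// operand are honored.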
1269 TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleOperand) {
1270   const char* module_str = R"(
1271 HloModule CustomCallLayoutConstrainedTupleOperand
1272 
1273 ENTRY %CustomCallLayoutConstrainedTupleOperand (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] {
1274   %p0 = f32[4,4] parameter(0)
1275   %p1 = f32[2,3] parameter(1)
1276   %tuple = (f32[4,4], f32[2,3]) tuple(%p0, %p1)
1277   ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(%tuple), custom_call_target="baz", operand_layout_constraints={(f32[4,4]{1,0}, f32[2,3]{0,1})}
1278 }
1279 )";
1280   TF_ASSERT_OK_AND_ASSIGN(
1281       std::unique_ptr<VerifiedHloModule> m,
1282       ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
1283   ComputationLayout computation_layout = m->entry_computation_layout();
1284   *computation_layout.mutable_parameter_layout(0) =
1285       ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
1286   *computation_layout.mutable_parameter_layout(1) =
1287       ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}));
1288   *computation_layout.mutable_result_layout() = ShapeLayout(
1289       ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
1290   AssignLayouts(m.get(), &computation_layout);
1291 
1292   HloInstruction* root = m->entry_computation()->root_instruction();
1293   ExpectLayoutIs(root->shape(), {2, 1, 0, 3});
1294 
1295   ASSERT_THAT(m->entry_computation()->root_instruction(),
1296               GmockMatch(m::Copy(m::CustomCall(m::Tuple()))));
1297 
1298   const HloInstruction* custom_call =
1299       m->entry_computation()->root_instruction()->operand(0);
1300   ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
1301   ExpectTupleLayoutIs(custom_call->operand(0)->shape(), {{1, 0}, {0, 1}});
1302 }
1303 
1304 TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleResult) {
1305   const char* module_str = R"(
1306 HloModule CustomCallLayoutConstrainedTupleResult
1307 
1308 ENTRY %CustomCallLayoutConstrainedTupleResult (p0: f32[4,4]) -> (f32[4,4]{1,0}, f32[2,3]{0,1}) {
1309   %p0 = f32[4,4] parameter(0)
1310   ROOT %custom-call = (f32[4,4]{1,0}, f32[2,3]{0,1}) custom-call(%p0), custom_call_target="baz", operand_layout_constraints={f32[4,4]{1,0}}
1311 }
1312 )";
1313   // The custom call's result element layouts are fixed by its shape, so layout
1314   // assignment must insert a copy to produce the requested entry result layout.
1315   TF_ASSERT_OK_AND_ASSIGN(
1316       std::unique_ptr<VerifiedHloModule> m,
1317       ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
1318   ComputationLayout computation_layout = m->entry_computation_layout();
1319   *computation_layout.mutable_parameter_layout(0) =
1320       ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
1321   *computation_layout.mutable_result_layout() =
1322       ShapeLayout(ShapeUtil::MakeTupleShape(
1323           {ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}),
1324            ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})}));
1325   AssignLayouts(m.get(), &computation_layout);
1326 
1327   ExpectTupleLayoutIs(m->result_shape(), {{1, 0}, {1, 0}});
1328 
1329   const HloInstruction* custom_call = FindInstruction(m.get(), "custom-call");
1330   ExpectTupleLayoutIs(custom_call->shape(), {{1, 0}, {0, 1}});
1331 }
1332 
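// Runs layout assignment using the module's own entry computation layout,
// defaulting the result layout if it has not been set, and returns the pass
// status.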
1333 Status AssignLayoutsToComputation(
1334     HloModule* m, ChannelLayoutConstraints* channel_constraints = nullptr) {
1335   if (!m->entry_computation_layout().result_layout().LayoutIsSet()) {
1336     m->mutable_entry_computation_layout()
1337         ->mutable_result_layout()
1338         ->SetToDefaultLayout();
1339   }
1340   LayoutAssignment layout_assignment(m->mutable_entry_computation_layout(),
1341                                      channel_constraints);
1342   return layout_assignment.Run(m).status();
1343 }
1344 
1345 TEST_F(LayoutAssignmentTest, OverwriteDiamondShapedConstraintsX) {
1346   // Check that we handle a diamond-shaped graph correctly.
1347   //      transpose
1348   //       /    \
1349   //     add    |
1350   //       \    /
1351   //        tuple
1352 
1353   auto b = HloComputation::Builder(TestName());
1354   Shape ashape = ShapeUtil::MakeShape(F32, {12, 8});
1355   Shape bshape = ShapeUtil::MakeShape(F32, {8, 12});
1356   auto param0 =
1357       b.AddInstruction(HloInstruction::CreateParameter(0, bshape, "input"));
1358   auto param1 =
1359       b.AddInstruction(HloInstruction::CreateParameter(1, ashape, "input"));
1360   auto transpose =
1361       b.AddInstruction(HloInstruction::CreateTranspose(ashape, param0, {1, 0}));
1362   auto add = b.AddInstruction(
1363       HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, transpose, param1));
1364   b.AddInstruction(HloInstruction::CreateTuple({add, transpose}));
1365   auto m = CreateNewVerifiedModule();
1366   m->AddEntryComputation(b.Build());
1367   Shape ashape_major = ShapeUtil::MakeShapeWithLayout(F32, {12, 8}, {1, 0});
1368   Shape ashape_minor = ShapeUtil::MakeShapeWithLayout(F32, {12, 8}, {0, 1});
1369   *m->mutable_entry_computation_layout()->mutable_result_layout() =
1370       ShapeLayout(ShapeUtil::MakeTupleShape({ashape_major, ashape_minor}));
1371   const Layout r2_dim0major = LayoutUtil::MakeLayout({1, 0});
1372   ForceParameterLayout(m.get(), 0, r2_dim0major);
1373   ForceParameterLayout(m.get(), 1, r2_dim0major);
1374   TF_ASSERT_OK(AssignLayoutsToComputation(m.get()));
1375 
1376   EXPECT_THAT(add->shape().layout().minor_to_major(), ElementsAre(1, 0));
1377   EXPECT_THAT(add->operand(0)->shape().layout().minor_to_major(),
1378               ElementsAre(1, 0));
1379   EXPECT_THAT(add->operand(1)->shape().layout().minor_to_major(),
1380               ElementsAre(1, 0));
1381 
1382   EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(0, 1));
1383 }
1384 
1385 // Tests that the layout assignment supports layout-constrained all-reduce with
1386 // different operand layouts (b/146056839).
1387 TEST_F(LayoutAssignmentTest, LayoutConstrainedAllReduce) {
1388   const char* module_str = R"(
1389 HloModule test_module
1390 
1391 add {
1392   lhs = f32[] parameter(0)
1393   rhs = f32[] parameter(1)
1394   ROOT add = f32[] add(lhs, rhs)
1395 }
1396 
1397 ENTRY entry_computation {
1398   param = (f32[8,4]{0,1}, f32[16,2]{0,1}) parameter(0)
1399   gte0 = f32[8,4] get-tuple-element(param), index=0
1400   gte1 = f32[16,2] get-tuple-element(param), index=1
1401   crs = (f32[8,4]{0,1}, f32[16,2]{1,0}) all-reduce(gte0, gte1),
1402     replica_groups={}, constrain_layout=true, to_apply=add
1403   gte2 = f32[8,4] get-tuple-element(crs), index=0
1404   gte3 = f32[16,2] get-tuple-element(crs), index=1
1405   ROOT result = (f32[8,4]{1,0}, f32[16,2]{1,0}) tuple(gte2, gte3)
1406 }
1407 )";
1408 
1409   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
1410                           ParseAndReturnVerifiedModule(module_str));
1411   ComputationLayout computation_layout(
1412       m->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false);
1413 
1414   ChannelLayoutConstraints channel_constraints;
1415   AssignLayouts(m.get(), &computation_layout, &channel_constraints);
1416 
1417   const HloInstruction* crs = FindInstruction(m.get(), "crs");
1418   ExpectTupleLayoutIs(crs->shape(), {{0, 1}, {1, 0}});
1419   ExpectLayoutIs(crs->operand(0)->shape(), {0, 1});
1420   ExpectLayoutIs(crs->operand(1)->shape(), {1, 0});
1421 }
1422 
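// Tests that layout assignment supports a layout-constrained all-to-all whose
// operands arrive with different layouts.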
1423 TEST_F(LayoutAssignmentTest, LayoutConstrainedAllToAll) {
1424   const char* module_str = R"(
1425 HloModule test_module
1426 
1427 add {
1428   lhs = f32[] parameter(0)
1429   rhs = f32[] parameter(1)
1430   ROOT add = f32[] add(lhs, rhs)
1431 }
1432 
1433 ENTRY entry_computation {
1434   param = (f32[16,4]{0,1}, f32[16,4]{1,0}) parameter(0)
1435   gte0 = f32[16,4] get-tuple-element(param), index=0
1436   gte1 = f32[16,4] get-tuple-element(param), index=1
1437   alltoall = (f32[16,4]{1,0}, f32[16,4]{1,0}) all-to-all(gte0, gte1),
1438     replica_groups={{0,1}}, constrain_layout=true
1439   gte2 = f32[16,4] get-tuple-element(alltoall), index=0
1440   gte3 = f32[16,4] get-tuple-element(alltoall), index=1
1441   ROOT concat = f32[16,8]{0,1} concatenate(gte2, gte3), dimensions={1}
1442 }
1443 )";
1444 
1445   TF_ASSERT_OK_AND_ASSIGN(
1446       std::unique_ptr<HloModule> m,
1447       ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
1448   ComputationLayout computation_layout(
1449       m->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false);
1450 
1451   ChannelLayoutConstraints channel_constraints;
1452   AssignLayouts(m.get(), &computation_layout, &channel_constraints);
1453 
1454   const HloInstruction* alltoall = FindInstruction(m.get(), "alltoall");
1455   ExpectTupleLayoutIs(alltoall->shape(), {{1, 0}, {1, 0}});
1456   ExpectLayoutIs(alltoall->operand(0)->shape(), {1, 0});
1457   ExpectLayoutIs(alltoall->operand(1)->shape(), {1, 0});
1458 }
1459 
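// Tests that a dynamic (bounded) dimension on the root survives layout
// assignment when the entry result layout has its dynamic shape cleared.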
1460 TEST_F(LayoutAssignmentTest, DynamicRoot) {
1461   const char* module_str = R"(
1462 HloModule test_module
1463 
1464 ENTRY entry_computation {
1465   param = f32[1,<=16]{0,1} parameter(0)
1466   ROOT abs = f32[1,<=16]{0,1} abs(param)
1467 }
1468 )";
1469 
1470   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
1471                           ParseAndReturnVerifiedModule(module_str));
1472   ComputationLayout computation_layout(
1473       m->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false);
1474   computation_layout.mutable_result_layout()->ClearDynamicShape();
1475 
1476   AssignLayouts(m.get(), &computation_layout);
1477 
1478   const HloInstruction* abs = FindInstruction(m.get(), "abs");
1479   ExpectLayoutIs(abs->operand(0)->shape(), {0, 1});
1480   ExpectLayoutIs(abs->shape(), {0, 1});
1481   EXPECT_TRUE(abs->shape().is_dynamic_dimension(1));
1482 }
1483 
1484 // Test the ability to avoid copying across computations by reversing
1485 // computation traversal order.
1486 TEST_F(LayoutAssignmentTest, ReverseComputationOrderAvoidCopy) {
1487   const char* module_str = R"(
1488 HloModule ComputationLayoutAvoidCopy
1489 
1490 call_1 {
1491   %arg_tuple.1 = (f32[93184,4]) parameter(0)
1492   %get-tuple-element.1 = f32[93184,4] get-tuple-element(%arg_tuple.1), index=0
1493   ROOT %reshape.8494 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0}%get-tuple-element.1)
1494 }
1495 
1496 on_true {
1497   %arg_tuple.1 = (f32[93184,4]) parameter(0)
1498   %get-tuple-element.1 = f32[93184,4] get-tuple-element(%arg_tuple.1), index=0
1499   ROOT %reshape.8493 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0}%get-tuple-element.1)
1500 }
1501 
1502 on_false {
1503   %arg_tuple.2 = (f32[93184,4]) parameter(0)
1504   %get-tuple-element.3 = f32[93184,4] get-tuple-element(%arg_tuple.2), index=0
1505   %reshape.9717 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0}%get-tuple-element.3)
1506   ROOT %add = f32[2,512,364] add(%reshape.9717, %reshape.9717)
1507 }
1508 
1509 ENTRY main {
1510   pred.1 = pred[] parameter(0)
1511   arg.2 = f32[93184,4]{1,0} parameter(1)
1512   arg_tuple.11 = (f32[93184,4]{1,0}) tuple(arg.2)
1513   call.1 = f32[2,512,364] call(arg_tuple.11), to_apply=call_1
1514   conditional = f32[2,512,364] conditional(pred.1, arg_tuple.11, arg_tuple.11),
1515                      true_computation=on_true, false_computation=on_false
1516   ROOT add = f32[2,512,364] add(call.1, conditional)
1517 }
1518   )";
1519 
1520   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
1521                           ParseAndReturnVerifiedModule(module_str));
1522   ComputationLayout computation_layout(
1523       m->entry_computation()->ComputeProgramShape());
1524   *computation_layout.mutable_parameter_layout(0) =
1525       ShapeLayout(ShapeUtil::MakeShape(PRED, {}));
1526   *computation_layout.mutable_parameter_layout(1) =
1527       ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {93184, 4}, {0, 1}));
1528   *computation_layout.mutable_result_layout() = ShapeLayout(
1529       ShapeUtil::MakeShapeWithLayout(F32, {2, 512, 364}, {0, 1, 2}));
1530   ChannelLayoutConstraints channel_constraints;
1531   LayoutAssignment layout_assignment(
1532       &computation_layout,
1533       /*channel_constraints=*/&channel_constraints,
1534       /* reverse_computation_order = */ true);
1535   EXPECT_IS_OK(layout_assignment.Run(m.get()).status());
1536   const HloInstruction* call_1 = FindInstruction(m.get(), "reshape.8494");
1537   ExpectLayoutIs(call_1->shape(), {0, 1, 2});
1538   const HloInstruction* on_true = FindInstruction(m.get(), "reshape.8493");
1539   ExpectLayoutIs(on_true->shape(), {0, 1, 2});
1540   const HloInstruction* on_false = FindInstruction(m.get(), "reshape.9717");
1541   ExpectLayoutIs(on_false->shape(), {0, 1, 2});
1542 }
1543 
1544 // Test the ability to propagate operand constraints across multiple operations.
1545 TEST_F(LayoutAssignmentTest, PropagateOperandLayout) {
1546   const char* module_str = R"(
1547 HloModule ComputationPropagateOperandLayout
1548 
1549 %scalar_add_computation.1 (scalar_lhs.1: f32[], scalar_rhs.1: f32[]) -> f32[] {
1550   %scalar_lhs.1 = f32[]{:T(256)} parameter(0)
1551   %scalar_rhs.1 = f32[]{:T(256)} parameter(1)
1552   ROOT %add.25 = f32[]{:T(256)} add(f32[]{:T(256)} %scalar_lhs.1, f32[]{:T(256)} %scalar_rhs.1)
1553 }
1554 
1555 ENTRY main {
1556   %convolution-base-dilated = f32[64,243,243,10]{0,3,2,1:T(8,128)} parameter(0)
1557   %reduce.13 = f32[64,243,243]{2,1,0:T(8,128)} parameter(1)
1558   %divide.2 = f32[64,243,243,10]{2,3,1,0:T(8,128)} parameter(2)
1559   %reshape.32 = f32[384,10,1,1]{0,1,3,2:T(8,128)} parameter(3)
1560   %subtract = f32[64,243,243,384]{3,0,2,1:T(8,128)} parameter(4)
1561   %constant = f32[]{:T(256)} constant(3779136)
1562   %broadcast.71 = f32[64,243,243,384]{3,0,2,1:T(8,128)} parameter(5)
1563   %broadcast.46 = f32[64,243,243,10] broadcast(f32[64,243,243] %reduce.13), dimensions={0,1,2}
1564   %subtract.14 = f32[64,243,243,10] subtract(f32[64,243,243,10] %convolution-base-dilated, f32[64,243,243,10] %broadcast.46)
1565   %multiply.22 = f32[64,243,243,10] multiply(f32[64,243,243,10] %subtract.14, f32[64,243,243,10]{2,3,1,0:T(8,128)} %divide.2)
1566   %convolution.9 = f32[64,243,243,384] convolution(f32[64,243,243,10] %multiply.22, f32[384,10,1,1] %reshape.32), window={size=1x1}, dim_labels=01bf_oi01->01bf
1567   %reduce.14 = f32[384] reduce(f32[64,243,243,384] %convolution.9, f32[]{:T(256)} %constant), dimensions={0,1,2}, to_apply=%scalar_add_computation.1
1568   %multiply.24 = f32[64,243,243,384] multiply(f32[64,243,243,384] %convolution.9, f32[64,243,243,384] %subtract)
1569   %reduce.15 = f32[384] reduce(f32[64,243,243,384] %multiply.24, f32[] %constant), dimensions={0,1,2}, to_apply=%scalar_add_computation.1
1570   %broadcast.47 = f32[64,243,243,384] broadcast(f32[] %constant), dimensions={}
1571   %multiply.23 = f32[64,243,243,384] multiply(f32[64,243,243,384] %convolution.9, f32[64,243,243,384] %broadcast.47)
1572   %broadcast.48 = f32[64,243,243,384]{3,2,1,0:T(8,128)} broadcast(f32[384]{0:T(512)} %reduce.14), dimensions={3}
1573   %subtract.15 = f32[64,243,243,384] subtract(f32[64,243,243,384] %multiply.23, f32[64,243,243,384] %broadcast.48)
1574   %broadcast.50 = f32[64,243,243,384] broadcast(f32[384] %reduce.15), dimensions={3}
1575   %multiply.25 = f32[64,243,243,384] multiply(f32[64,243,243,384] %broadcast.50, f32[64,243,243,384] %subtract)
1576   %divide.7 = f32[64,243,243,384]{3,0,2,1:T(8,128)} divide(f32[64,243,243,384] %multiply.25, f32[64,243,243,384] %broadcast.71)
1577   ROOT %subtract.17 = f32[64,243,243,384]{3,0,2,1:T(8,128)} subtract(f32[64,243,243,384] %subtract.15, f32[64,243,243,384] %divide.7)
1578 }
1579   )";
1580 
1581   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
1582                           ParseAndReturnVerifiedModule(module_str));
1583   ComputationLayout computation_layout(
1584       m->entry_computation()->ComputeProgramShape());
1585   *computation_layout.mutable_parameter_layout(0) = ShapeLayout(
1586       ShapeUtil::MakeShapeWithLayout(F32, {64, 243, 243, 10}, {0, 3, 2, 1}));
1587   *computation_layout.mutable_parameter_layout(1) = ShapeLayout(
1588       ShapeUtil::MakeShapeWithLayout(F32, {64, 243, 243}, {2, 1, 0}));
1589   *computation_layout.mutable_parameter_layout(2) = ShapeLayout(
1590       ShapeUtil::MakeShapeWithLayout(F32, {64, 243, 243, 10}, {2, 3, 1, 0}));
1591   *computation_layout.mutable_parameter_layout(3) = ShapeLayout(
1592       ShapeUtil::MakeShapeWithLayout(F32, {384, 10, 1, 1}, {0, 1, 3, 2}));
1593   *computation_layout.mutable_parameter_layout(4) = ShapeLayout(
1594       ShapeUtil::MakeShapeWithLayout(F32, {64, 243, 243, 384}, {3, 0, 2, 1}));
1595   *computation_layout.mutable_result_layout() = ShapeLayout(
1596       ShapeUtil::MakeShapeWithLayout(F32, {64, 243, 243, 384}, {3, 0, 2, 1}));
1597   ChannelLayoutConstraints channel_constraints;
1598   LayoutAssignment layout_assignment(
1599       &computation_layout,
1600       /*channel_constraints=*/&channel_constraints);
1601   EXPECT_IS_OK(layout_assignment.Run(m.get()).status());
1602   const HloInstruction* subtract_15 = FindInstruction(m.get(), "subtract.15");
1603   ExpectLayoutIs(subtract_15->shape(), {3, 0, 2, 1});
1604   const HloInstruction* broadcast_46 = FindInstruction(m.get(), "broadcast.46");
1605   ExpectLayoutIs(broadcast_46->shape(), {2, 3, 1, 0});
1606   const HloInstruction* subtract_14 = FindInstruction(m.get(), "subtract.14");
1607   ExpectLayoutIs(subtract_14->shape(), {2, 3, 1, 0});
1608 }
1609 
1610 }  // namespace
1611 }  // namespace xla
1612