/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/layout_assignment.h"

#include <initializer_list>
#include <memory>
#include <utility>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace xla {
namespace {

namespace m = xla::match;
using ::testing::ElementsAre;

class LayoutAssignmentTest : public HloTestBase {
 protected:
  void AssignLayouts(HloModule* m, ComputationLayout* entry_computation_layout,
                     ChannelLayoutConstraints* channel_constraints = nullptr) {
    LayoutAssignment layout_assignment(
        entry_computation_layout,
        /*channel_constraints=*/channel_constraints);
    EXPECT_IS_OK(layout_assignment.Run(m).status());
  }

  std::vector<int64_t> LayoutOf(HloModule* m, absl::string_view name) {
    HloInstruction* instr = FindInstruction(m, name);
    CHECK(instr != nullptr) << name;
    auto minor_to_major = instr->shape().layout().minor_to_major();
    return std::vector<int64_t>(minor_to_major.begin(), minor_to_major.end());
  }

  void ExpectLayoutIs(const Shape& shape,
                      absl::Span<const int64_t> minor_to_major) {
    const Layout expected = LayoutUtil::MakeLayout(minor_to_major);
    EXPECT_TRUE(LayoutUtil::Equal(shape.layout(), expected))
        << "Expected layout " << expected << ", actual " << shape.layout();
  }

  void ExpectTupleLayoutIs(
      const Shape& shape,
      std::initializer_list<absl::Span<const int64_t>> minor_to_majors) {
    int i = 0;
    for (const absl::Span<const int64_t> minor_to_major : minor_to_majors) {
      const Layout expected = LayoutUtil::MakeLayout(minor_to_major);
      const Layout& actual = ShapeUtil::GetTupleElementShape(shape, i).layout();
      EXPECT_TRUE(LayoutUtil::Equal(actual, expected))
          << "Expected tuple element " << i << " layout " << expected
          << ", actual " << actual;
      ++i;
    }
  }
};

TEST_F(LayoutAssignmentTest, ComputationLayout) {
  // Verify the layouts of the root and parameter instructions of a computation
  // match the ComputationLayout for two different layouts.
  std::vector<std::vector<int64_t>> minor_to_majors = {{0, 1}, {1, 0}};
  for (auto& minor_to_major : minor_to_majors) {
    auto builder = HloComputation::Builder(TestName());
    Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
    auto param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, ashape, "param0"));
    auto param1 = builder.AddInstruction(
        HloInstruction::CreateParameter(1, ashape, "param1"));
    auto add = builder.AddInstruction(
        HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1));
    auto m = CreateNewVerifiedModule();
    HloComputation* computation = m->AddEntryComputation(builder.Build());

    Layout layout = LayoutUtil::MakeLayout(minor_to_major);
    Shape shape(ashape);
    *shape.mutable_layout() = layout;
    const ShapeLayout shape_layout(shape);

    ComputationLayout computation_layout(computation->ComputeProgramShape());
    *computation_layout.mutable_parameter_layout(0) = shape_layout;
    *computation_layout.mutable_parameter_layout(1) = shape_layout;
    *computation_layout.mutable_result_layout() = shape_layout;
    AssignLayouts(m.get(), &computation_layout);
    EXPECT_TRUE(LayoutUtil::Equal(layout, param0->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(layout, param1->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(layout, add->shape().layout()));
  }
}

TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) {
  // Verify the layouts of the root and parameter instructions of a computation
  // match the ComputationLayout, which has mixed layouts.
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param0"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, ashape, "param1"));
  builder.AddInstruction(
      HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation = m->AddEntryComputation(builder.Build());

  Layout col_major_layout = LayoutUtil::MakeLayout({1, 0});
  Shape col_major_shape(ashape);
  *col_major_shape.mutable_layout() = col_major_layout;
  const ShapeLayout col_major(col_major_shape);

  Layout row_major_layout = LayoutUtil::MakeLayout({0, 1});
  Shape row_major_shape(ashape);
  *row_major_shape.mutable_layout() = row_major_layout;
  const ShapeLayout row_major(row_major_shape);

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) = col_major;
  *computation_layout.mutable_parameter_layout(1) = row_major;
  *computation_layout.mutable_result_layout() = col_major;

  AssignLayouts(m.get(), &computation_layout);
  EXPECT_TRUE(LayoutUtil::Equal(col_major_layout, param0->shape().layout()));
  EXPECT_TRUE(LayoutUtil::Equal(row_major_layout, param1->shape().layout()));
  EXPECT_TRUE(LayoutUtil::Equal(
      col_major_layout, computation->root_instruction()->shape().layout()));
}

TEST_F(LayoutAssignmentTest, FusionInstruction) {
  // Verify that the layouts of the fused parameters in a fusion instruction
  // match those of the fusion operands. Other fused instructions should have
  // no layout.
  std::vector<std::vector<int64_t>> minor_to_majors = {{0, 1}, {1, 0}};
  for (auto& minor_to_major : minor_to_majors) {
    auto builder = HloComputation::Builder(TestName());
    auto constant_literal1 = LiteralUtil::CreateR2WithLayout<float>(
        {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major));
    auto constant_literal2 = LiteralUtil::CreateR2WithLayout<float>(
        {{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major));
    Shape ashape = constant_literal1.shape();

    auto constant1 = builder.AddInstruction(
        HloInstruction::CreateConstant(std::move(constant_literal1)));
    auto constant2 = builder.AddInstruction(
        HloInstruction::CreateConstant(std::move(constant_literal2)));
    auto add = builder.AddInstruction(HloInstruction::CreateBinary(
        ashape, HloOpcode::kAdd, constant1, constant2));
    auto negate1 = builder.AddInstruction(
        HloInstruction::CreateUnary(ashape, HloOpcode::kNegate, add));
    auto negate2 = builder.AddInstruction(
        HloInstruction::CreateUnary(ashape, HloOpcode::kNegate, negate1));

    auto m = CreateNewVerifiedModule();
    HloComputation* computation = m->AddEntryComputation(builder.Build());

    auto fusion = computation->CreateFusionInstruction(
        {negate2, negate1, add}, HloInstruction::FusionKind::kLoop);

    Layout layout = LayoutUtil::MakeLayout(minor_to_major);
    Shape shape(ashape);
    *shape.mutable_layout() = layout;
    const ShapeLayout shape_layout(shape);

    ComputationLayout computation_layout(computation->ComputeProgramShape());
    *computation_layout.mutable_result_layout() = shape_layout;

    AssignLayouts(m.get(), &computation_layout);

    EXPECT_TRUE(LayoutUtil::Equal(
        layout, fusion->fused_parameter(0)->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(
        layout, fusion->fused_parameter(1)->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(
        layout, fusion->fused_expression_root()->shape().layout()));

    // Inner fused node should not have a layout.
    EXPECT_FALSE(LayoutUtil::HasLayout(
        fusion->fused_expression_root()->operand(0)->shape()));
  }
}

TEST_F(LayoutAssignmentTest, TupleLayout) {
  // Verify the layouts of a tuple are assigned properly (the element layouts
  // match their source).
  auto builder = HloComputation::Builder(TestName());
  auto constant0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
  auto tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({constant0, constant1}));

  // To avoid having to construct a tuple layout in the ComputationLayout
  // below, make the result of the computation an array.
  auto get_element0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(constant0->shape(), tuple, 0));
  auto negate = builder.AddInstruction(HloInstruction::CreateUnary(
      constant0->shape(), HloOpcode::kNegate, get_element0));

  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());

  AssignLayouts(m.get(), &computation_layout);

  EXPECT_TRUE(
      LayoutUtil::LayoutsInShapesEqual(constant0->shape(), constant1->shape()));

  EXPECT_TRUE(LayoutUtil::HasLayout(tuple->shape()));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(
      negate->shape(), computation_layout.result_layout().shape()));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(
      ShapeUtil::GetTupleElementShape(tuple->shape(), 1), constant1->shape()));
}

TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
  // Construct the following computation, which has conflicting layouts for two
  // elements of a tuple that share the same source logical buffer:
  //
  //   %constant = Constant(...)
  //   %inner_tuple = Tuple(%constant)
  //   %nested_tuple = Tuple(%inner_tuple, %inner_tuple)
  //
  // The result layout is col-major for the first element and row-major for the
  // second. This results in a conflict where the element of the inner_tuple
  // needs to be both col-major and row-major. This is resolved by deep-copying
  // the tuple and assigning the layouts of the copied arrays as needed.
  auto builder = HloComputation::Builder(TestName());
  auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
  auto inner_tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({constant}));
  auto nested_tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({inner_tuple, inner_tuple}));

  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape result_shape = nested_tuple->shape();
  *ShapeUtil::GetMutableSubshape(&result_shape, /*index=*/{0, 0}) =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0});
  *ShapeUtil::GetMutableSubshape(&result_shape, /*index=*/{1, 0}) =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1});
  TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape(
      result_shape));

  LayoutAssignment layout_assignment(&computation_layout);
  AssignLayouts(m.get(), &computation_layout);

  // Layout assignment should have deep-copied the result of the computation to
  // address the layout conflict. This results in several Tuple() and
  // GetTupleElement() instructions. Running algebraic simplification should
  // clean up the code to something like:
  //
  //   %constant = Constant(...) layout={1,0}
  //   %tuple.0 = Tuple(%constant) layout=({1,0})
  //   %copy = Copy(%constant) layout={0,1}  # layout transposed
  //   %tuple.1 = Tuple(%copy) layout=({0,1})
  //   %tuple.2 = Tuple(%tuple.0, %tuple.1) layout=(({1,0}), ({0,1}))
  //
  AlgebraicSimplifierOptions options(
      [](const Shape&, const Shape&) { return false; });
  options.set_is_layout_sensitive(true);
  EXPECT_TRUE(AlgebraicSimplifier(options).Run(m.get()).ValueOrDie());
  HloInstruction* root = m->entry_computation()->root_instruction();
  // Verify layout of the root and the root's operands.
  EXPECT_TRUE(ShapeUtil::Equal(result_shape, root->shape()));
  EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {0}),
                               root->operand(0)->shape()));
  EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {1}),
                               root->operand(1)->shape()));

  // Verify the structure of the HLO graph.
  EXPECT_THAT(root,
              GmockMatch(m::Tuple(m::Tuple(m::Op().Is(constant)),
                                  m::Tuple(m::Copy(m::Op().Is(constant))))));
}

TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) {
  // param -> log -> reshape -> tanh
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {1, 2, 3, 1});
  Shape bshape = ShapeUtil::MakeShape(F32, {3, 1, 2});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param"));
  auto log = builder.AddInstruction(
      HloInstruction::CreateUnary(ashape, HloOpcode::kLog, param));
  auto reshape =
      builder.AddInstruction(HloInstruction::CreateReshape(bshape, log));
  auto tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, reshape));

  auto m = CreateNewVerifiedModule();
  HloComputation* computation = m->AddEntryComputation(builder.Build(tanh));

  Shape ashape_with_layout(ashape);
  Shape bshape_with_layout(bshape);
  *ashape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 2, 1, 3});
  *bshape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0});

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ashape_with_layout);
  *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
  AssignLayouts(m.get(), &computation_layout);

  auto log_minor_to_major = log->shape().layout().minor_to_major();
  EXPECT_GT(PositionInContainer(log_minor_to_major, 1),
            PositionInContainer(log_minor_to_major, 2));

  auto reshape_minor_to_major = reshape->shape().layout().minor_to_major();
  EXPECT_GT(PositionInContainer(reshape_minor_to_major, 0),
            PositionInContainer(reshape_minor_to_major, 2));
}

// Test whether LayoutAssignment assigns layouts to elementwise operations to
// keep linear indices valid across them, and to transpositions to make them
// bitcasts.
TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) {
  // param -> log -> transpose -> tanh
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
  Shape bshape = ShapeUtil::MakeShape(F32, {12, 42});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param"));
  auto log = builder.AddInstruction(
      HloInstruction::CreateUnary(ashape, HloOpcode::kLog, param));
  auto transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(bshape, log, {1, 0}));
  auto tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, transpose));
  auto m = CreateNewVerifiedModule();
  auto computation = m->AddEntryComputation(builder.Build(tanh));

  Shape ashape_with_layout(ashape);
  Shape bshape_with_layout(bshape);
  *ashape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
  *bshape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1});

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ashape_with_layout);
  *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
  AssignLayouts(m.get(), &computation_layout);

  EXPECT_TRUE(
      LayoutUtil::Equal(ashape_with_layout.layout(), log->shape().layout()));
  EXPECT_TRUE(LayoutUtil::Equal(bshape_with_layout.layout(),
                                transpose->shape().layout()));
  EXPECT_TRUE(
      LayoutUtil::Equal(bshape_with_layout.layout(), tanh->shape().layout()));
}

// Test whether LayoutAssignment assigns layouts to transpositions to make them
// bitcasts.
TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) {
  // param -> broadcast -> transpose
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {3, 4});
  Shape bshape = ShapeUtil::MakeShape(F32, {2, 3, 4});
  Shape cshape = ShapeUtil::MakeShape(F32, {4, 3, 2});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(bshape, param, {1, 2}));
  auto transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(cshape, broadcast, {2, 1, 0}));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation =
      m->AddEntryComputation(builder.Build(transpose));

  Shape input_shape_with_layout(ashape);
  Shape output_shape_with_layout(cshape);
  *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
  *output_shape_with_layout.mutable_layout() =
      LayoutUtil::MakeLayout({2, 1, 0});

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(input_shape_with_layout);
  *computation_layout.mutable_result_layout() =
      ShapeLayout(output_shape_with_layout);
  AssignLayouts(m.get(), &computation_layout);

  EXPECT_THAT(broadcast->shape().layout().minor_to_major(),
              ElementsAre(0, 1, 2));
}

TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) {
  // param[4] -> broadcast[3x4] ------> transpose[4x3]-------- -------> tuple
  //                            \                                     /
  //                             \-> tanh[3x4] -> broadcast2[2x3x4] -/
  //
  // The layout of `transpose` is set to {1,0} because it provides a buffer to
  // the computation result, which has a fixed layout. Therefore, `broadcast`
  // (the operand of the transpose) is expected to have layout {0,1} so that
  // the transpose is a bitcast. Furthermore, `tanh` is expected to have the
  // same layout as `broadcast` (i.e. {0,1}) because `tanh` is elementwise.
  Shape f32_4 = ShapeUtil::MakeShape(F32, {4});
  Shape f32_34 = ShapeUtil::MakeShape(F32, {3, 4});
  Shape f32_43 = ShapeUtil::MakeShape(F32, {4, 3});
  Shape f32_234 = ShapeUtil::MakeShape(F32, {2, 3, 4});

  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_4, "param"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32_34, param, {1}));
  auto transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(f32_43, broadcast, {1, 0}));
  auto tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(f32_34, HloOpcode::kTanh, broadcast));
  auto broadcast2 = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32_234, tanh, {1, 2}));
  auto tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({transpose, broadcast2}));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation = m->AddEntryComputation(builder.Build(tuple));

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  Shape param_shape_with_layout(f32_4);
  Shape transpose_shape_with_layout(f32_43);
  Shape broadcast2_shape_with_layout(f32_234);
  *param_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0});
  *transpose_shape_with_layout.mutable_layout() =
      LayoutUtil::MakeLayout({1, 0});
  *broadcast2_shape_with_layout.mutable_layout() =
      LayoutUtil::MakeLayout({2, 1, 0});

  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(param_shape_with_layout);
  *computation_layout.mutable_result_layout() =
      ShapeLayout(ShapeUtil::MakeTupleShape(
          {transpose_shape_with_layout, broadcast2_shape_with_layout}));
  AssignLayouts(m.get(), &computation_layout);
  EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1));
  EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(1, 0));
  EXPECT_THAT(tanh->shape().layout().minor_to_major(), ElementsAre(0, 1));
}

class OperandsMustBeTheSameLayoutAssignment : public LayoutAssignment {
 public:
  explicit OperandsMustBeTheSameLayoutAssignment(
      ComputationLayout* entry_computation_layout)
      : LayoutAssignment(entry_computation_layout) {}

 protected:
  Status PropagateBufferConstraint(
      const BufferLayoutConstraint& buffer_constraint,
      LayoutConstraints* constraints) override {
    const LogicalBuffer& buffer = buffer_constraint.buffer();
    const HloInstruction* instruction = buffer.instruction();

    // Force the operands' layout to the output layout.
    for (int64_t operand_no = 0; operand_no < instruction->operand_count();
         ++operand_no) {
      const HloInstruction* operand = instruction->operand(operand_no);
      if (instruction->shape().rank() != operand->shape().rank()) {
        continue;
      }
      TF_RETURN_IF_ERROR(SetArrayOperandLayout(buffer_constraint.layout(),
                                               instruction, operand_no,
                                               /*mandatory=*/true));
    }
    return PropagateBufferConstraintToUses(buffer_constraint, constraints);
  }
};

TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) {
  // param0 -> concatenate -> reshape
  // param1 -^
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {50, 1});
  Shape bshape = ShapeUtil::MakeShape(F32, {50, 2});
  Shape cshape = ShapeUtil::MakeShape(F32, {100});
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, ashape, "param"));
  auto concatenate = builder.AddInstruction(
      HloInstruction::CreateConcatenate(bshape, {param0, param1}, 1));
  auto reshape = builder.AddInstruction(
      HloInstruction::CreateReshape(cshape, concatenate));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation = m->AddEntryComputation(builder.Build(reshape));

  Shape param0_shape_with_layout(ashape);
  Shape param1_shape_with_layout(ashape);
  *param0_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1});
  *param1_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(param0_shape_with_layout);
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(param1_shape_with_layout);
  OperandsMustBeTheSameLayoutAssignment layout_assignment(&computation_layout);
  EXPECT_IS_OK(layout_assignment.Run(m.get()).status());

  EXPECT_EQ(concatenate->operand(0)->shape().layout().minor_to_major(),
            concatenate->operand(1)->shape().layout().minor_to_major());
  EXPECT_EQ(concatenate->shape().layout().minor_to_major(),
            concatenate->operand(1)->shape().layout().minor_to_major());
}

// Test layout assignment of a transpose into a bitcast based on its operand.
TEST_F(LayoutAssignmentTest, TransposeToBitcastFromOperand) {
  auto builder = HloComputation::Builder(TestName());
  Shape input_shape_with_layout =
      ShapeUtil::MakeShapeWithLayout(F32, {3, 5, 6, 7}, {2, 0, 3, 1});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape_with_layout, "param"));
  auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), param, {2, 3, 0, 1}));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation =
      m->AddEntryComputation(builder.Build(transpose));
  ComputationLayout computation_layout(computation->ComputeProgramShape());
  AssignLayouts(m.get(), &computation_layout);
  EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
                                            transpose->shape(), {2, 3, 0, 1}));
}

// Test layout assignment of a transpose into a bitcast based on its user.
TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) {
  auto builder = HloComputation::Builder(TestName());
  Shape input_shape = ShapeUtil::MakeShape(F32, {3, 5, 6, 7});
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(input_shape, constant, {}));
  auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), broadcast, {2, 3, 0, 1}));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation =
      m->AddEntryComputation(builder.Build(transpose));
  ComputationLayout computation_layout(computation->ComputeProgramShape());
  AssignLayouts(m.get(), &computation_layout);
  EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
                                            transpose->shape(), {2, 3, 0, 1}));
}

// TransposeIsBitcast shouldn't be called without layout information.
TEST_F(LayoutAssignmentTest, TransposeIsBitcastFail) {
  auto builder = HloComputation::Builder(TestName());
  Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});
  Shape input_shape_with_layout(input_shape);
  *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape_with_layout, "param"));
  auto hlo = builder.AddInstruction(
      HloInstruction::CreateTranspose(input_shape, param, {0, 2, 1}));
  // Clear the default layout assigned to the instruction.
  LayoutUtil::ClearLayout(hlo->mutable_shape());
  EXPECT_DEATH(ShapeUtil::TransposeIsBitcast(hlo->operand(0)->shape(),
                                             hlo->shape(), hlo->dimensions()),
               "has_layout");
}

// ReshapeIsBitcast shouldn't be called without layout information.
TEST_F(LayoutAssignmentTest, ReshapeIsBitcastFail) {
  auto builder = HloComputation::Builder(TestName());
  Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});
  Shape input_shape_with_layout(input_shape);
  *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape_with_layout, "param"));
  auto hlo =
      builder.AddInstruction(HloInstruction::CreateReshape(input_shape, param));
  // Clear the default layout assigned to the instruction.
  LayoutUtil::ClearLayout(hlo->mutable_shape());
  EXPECT_DEATH(
      ShapeUtil::ReshapeIsBitcast(hlo->operand(0)->shape(), hlo->shape()),
      "has_layout");
}

// Check that the computation below doesn't crash the compiler.
//
// Within a fusion computation, only the parameters and result get assigned a
// layout. When we run the algebraic simplifier on this computation post layout
// assignment, it should not call TransposeIsBitcast on the `transpose` node
// inside the fusion computation as TransposeIsBitcast checks both input_shape
// and output_shape have layouts.
TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) {
  const char* module_str = R"(
    HloModule test_module

    fused_computation {
      param_1 = f32[2,2,2]{2,1,0} parameter(1)
      transpose = f32[2,2,2]{2,1,0} transpose(param_1), dimensions={0,2,1}
      reduce_1 = f32[] parameter(0)
      broadcast_1 = f32[2,2,2]{2,1,0} broadcast(reduce_1), dimensions={}
      ROOT divide_1 = f32[2,2,2]{2,1,0} divide(transpose, broadcast_1)
    }

    ENTRY entry_computation {
      fusion.1 = f32[2,2,2]{2,1,0} parameter(1)
      reduce.1 = f32[] parameter(0)
      fusion.2 = f32[2,2,2]{2,1,0} fusion(reduce.1, fusion.1), kind=kLoop, calls=fused_computation
      ROOT tuple.1 = (f32[2,2,2]{2,1,0}) tuple(fusion.2)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  std::unique_ptr<HloModule> compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .value();

  EXPECT_EQ(OkStatus(), backend()
                            .compiler()
                            ->RunBackend(std::move(compiled_module),
                                         backend().default_stream_executor(),
                                         /*device_allocator=*/nullptr)
                            .status());
}

// A GTE inside of a fusion node inherits the layout of its operand (which
// should, if we keep following operands, eventually be a parameter).
TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) {
  const char* module_str = R"(
    HloModule test_module

    fused_computation {
      fparam = (f32[2,2,2], (f32[2,2,2], f32[2,2,2])) parameter(0)
      gte0 = f32[2,2,2] get-tuple-element(fparam), index=0
      gte1 = (f32[2,2,2], f32[2,2,2]) get-tuple-element(fparam), index=1
      gte1a = f32[2,2,2] get-tuple-element(gte1), index=0
      gte1b = f32[2,2,2] get-tuple-element(gte1), index=1
      add = f32[2,2,2] add(gte1a, gte1b)
      ROOT fresult = f32[2,2,2] add(gte0, add)
    }

    ENTRY entry_computation {
      param = (f32[2,2,2], (f32[2,2,2], f32[2,2,2])) parameter(0)
      ROOT fusion =
        f32[2,2,2] fusion(param), kind=kLoop, calls=fused_computation
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2}),
       ShapeUtil::MakeTupleShape({
           ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {1, 2, 0}),
           ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {2, 0, 1}),
       })});
  TF_ASSERT_OK(
      computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
          param_shape));
  computation_layout.mutable_result_layout()->ResetLayout(
      LayoutUtil::MakeLayout({2, 1, 0}));
  AssignLayouts(m.get(), &computation_layout);

  EXPECT_THAT(LayoutOf(m.get(), "gte0"), ElementsAre(0, 1, 2));
  EXPECT_THAT(LayoutOf(m.get(), "gte1a"), ElementsAre(1, 2, 0));
  EXPECT_THAT(LayoutOf(m.get(), "gte1b"), ElementsAre(2, 0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "fresult"), ElementsAre(2, 1, 0));
  EXPECT_THAT(FindInstruction(m.get(), "gte1")
                  ->shape()
                  .tuple_shapes(0)
                  .layout()
                  .minor_to_major(),
              ElementsAre(1, 2, 0));
  EXPECT_THAT(FindInstruction(m.get(), "gte1")
                  ->shape()
                  .tuple_shapes(1)
                  .layout()
                  .minor_to_major(),
              ElementsAre(2, 0, 1));
}

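// Verify that the branch computations of a conditional end up with matching
// result layouts. The false branch produces its result from an infeed with a
// different layout, so a copy is expected in that branch.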
TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) {
  auto builder = HloComputation::Builder(TestName());
  auto m = CreateNewVerifiedModule();
  Shape shape = ShapeUtil::MakeShape(F32, {128, 8});
  Shape tshape = ShapeUtil::MakeTupleShape({shape, shape});
  Shape result_tshape = ShapeUtil::MakeTupleShape({shape});

  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param0"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, shape, "param1"));
  auto pred = builder.AddInstruction(HloInstruction::CreateParameter(
      2, ShapeUtil::MakeShape(PRED, {}), "param2"));
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({param0, param1}));

  auto true_builder = HloComputation::Builder(TestName() + "_TrueBranch");
  {
    auto param = true_builder.AddInstruction(
        HloInstruction::CreateParameter(0, tshape, "param"));
    auto gte0 = true_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(shape, param, 0));
    auto gte1 = true_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(shape, param, 1));
    auto add = true_builder.AddInstruction(
        HloInstruction::CreateBinary(shape, HloOpcode::kAdd, gte0, gte1));
    true_builder.AddInstruction(HloInstruction::CreateTuple({add}));
  }
  HloComputation* true_computation =
      m->AddEmbeddedComputation(true_builder.Build());

  auto false_builder = HloComputation::Builder(TestName() + "_FalseBranch");
  {
    Shape xshape = ShapeUtil::MakeShapeWithLayout(F32, {128, 8}, {0, 1});
    false_builder.AddInstruction(
        HloInstruction::CreateParameter(0, tshape, "param"));
    // Use an infeed, as layout assignment does not mess with it.
    auto token = false_builder.AddInstruction(HloInstruction::CreateToken());
    auto infeed = false_builder.AddInstruction(
        HloInstruction::CreateInfeed(xshape, token, ""));
    auto infeed_data = false_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(xshape, infeed, 0));
    false_builder.AddInstruction(HloInstruction::CreateTuple({infeed_data}));
  }
  HloComputation* false_computation =
      m->AddEmbeddedComputation(false_builder.Build());
  builder.AddInstruction(HloInstruction::CreateConditional(
      result_tshape, pred, tuple, true_computation, tuple, false_computation));

  HloComputation* computation = m->AddEntryComputation(builder.Build());
  ComputationLayout computation_layout(computation->ComputeProgramShape());

  AssignLayouts(m.get(), &computation_layout);

  const HloInstruction* true_root = true_computation->root_instruction();
  const HloInstruction* false_root = false_computation->root_instruction();
  EXPECT_THAT(true_root->opcode(), HloOpcode::kTuple);
  EXPECT_THAT(false_root->opcode(), HloOpcode::kTuple);

  const HloInstruction* true_result = true_root->operand(0);
  const HloInstruction* false_result = false_root->operand(0);
  EXPECT_TRUE(LayoutUtil::Equal(true_result->shape().layout(),
                                false_result->shape().layout()));
  EXPECT_THAT(false_result->opcode(), HloOpcode::kCopy);
}

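// Layout assignment should succeed when the branch computations of a
// conditional take differently shaped operands (a tuple in one branch and an
// array in the other).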
TEST_F(LayoutAssignmentTest, LayoutAssignmentToTupleSiblingOperand) {
  const char* const hlo_string = R"(
  HloModule Module

  true_branch {
    tparam = (f64[2,3], f64[2,3]) parameter(0)
    ROOT tgte = f64[2,3] get-tuple-element(tparam), index=1
  }

  false_branch {
    ROOT Arg = f64[2,3] parameter(0)
  }

  ENTRY entry {
    p0 = (f64[2,3], f64[2,3]) parameter(0)
    p1 = f64[2,3] parameter(1)
    constant = pred[] constant(true)
    ROOT conditional = f64[2,3] conditional(constant, p0, p1),
      true_computation=true_branch, false_computation=false_branch
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));

  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  LayoutAssignment layout_assignment(&computation_layout);
  Status error_status = layout_assignment.Run(m.get()).status();
  EXPECT_TRUE(error_status.ok());
}

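// Bitcasts must not appear in a module before layout assignment runs;
// encountering one should produce an error rather than a crash.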
TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) {
  auto builder = HloComputation::Builder(TestName());
  auto constant0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
  builder.AddInstruction(
      HloInstruction::CreateBitcast(constant0->shape(), constant0));
  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  LayoutAssignment layout_assignment(&computation_layout);
  Status error_status = layout_assignment.Run(m.get()).status();
  EXPECT_FALSE(error_status.ok());
  EXPECT_THAT(
      error_status.error_message(),
      ::testing::HasSubstr(
          "Unexpected bitcast operation seen during layout assignment"));
}

TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
  // Pin non-matching layouts to the parameter and root.
  const char* module_str = R"(
    HloModule test_module

    ENTRY entry_computation {
      param = (f32[2,2]) parameter(0)
      gte = f32[2,2] get-tuple-element(param), index=0
      token0 = token[] after-all()
      recv = (f32[2,2], u32[], token[]) recv(token0), channel_id=1, sharding={maximal device=1}
      recv-done = (f32[2,2], token[]) recv-done(recv), channel_id=1,
        sharding={maximal device=1}
      ROOT root = f32[2,2] get-tuple-element(recv-done), index=0
      send = (f32[2,2], u32[], token[]) send(gte, token0), channel_id=1,
        sharding={maximal device=0}
      send-done = token[] send-done(send), channel_id=1, sharding={maximal device=0}
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
  TF_ASSERT_OK(
      computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
          param_shape));
  computation_layout.mutable_result_layout()->ResetLayout(
      LayoutUtil::MakeLayout({1, 0}));

  ChannelLayoutConstraints channel_constraints;
  AssignLayouts(m.get(), &computation_layout, &channel_constraints);

  EXPECT_TRUE(ShapeUtil::Equal(FindInstruction(m.get(), "send")->shape(),
                               FindInstruction(m.get(), "recv")->shape()));
}

TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) {
  // Pin non-matching layouts to the parameter and root.
  const char* module_str = R"(
    HloModule test_module

    add {
      lhs = f32[] parameter(0)
      rhs = f32[] parameter(1)
      ROOT add = f32[] add(lhs, rhs)
    }

    ENTRY entry_computation {
      param = (f32[2,2]) parameter(0)
      gte = f32[2,2] get-tuple-element(param), index=0
      ar.0 = f32[2,2] all-reduce(gte),
        channel_id=1, replica_groups={{0}}, to_apply=add,
        sharding={maximal device=0}
      const = f32[2,2] constant({{0,1},{2,3}})
      ROOT ar.1 = f32[2,2] all-reduce(const),
        channel_id=1, replica_groups={{0}}, to_apply=add,
        sharding={maximal device=1}
    })";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
  TF_ASSERT_OK(
      computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
          param_shape));
  computation_layout.mutable_result_layout()->ResetLayout(
      LayoutUtil::MakeLayout({1, 0}));

  ChannelLayoutConstraints channel_constraints;
  AssignLayouts(m.get(), &computation_layout, &channel_constraints);

  EXPECT_THAT(LayoutOf(m.get(), "gte"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "ar.0"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "ar.1"), ElementsAre(0, 1));
  const HloInstruction* root = m->entry_computation()->root_instruction();
  EXPECT_THAT(root->shape().layout().minor_to_major(), ElementsAre(1, 0));
}

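// A slice whose operand layout differs from the layout chosen for its result
// should not change layouts implicitly; the operand is copied into the
// expected layout and the slice reads from that copy.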
TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) {
  const char* module_str = R"(
    HloModule CopySliceOperandToAvoidImplicitLayoutChange

    ENTRY CopySliceOperandToAvoidImplicitLayoutChange {
      par0 = f32[3,4]{1,0} parameter(0)
      par1 = f32[4,5]{0,1} parameter(1)
      slice0 = f32[3,4] slice(par1), slice={[1:4],[1:5]}
      ROOT add0 = f32[3,4]{1,0} add(par0,slice0)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .value();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0});
  EXPECT_THAT(
      root,
      GmockMatch(m::Add(
          m::Parameter(),
          m::Slice(m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy)))));
}

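// Same as above, but for dynamic-slice: the operand with the non-default
// layout is copied before the dynamic-slice.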
TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) {
  const char* module_str = R"(
    HloModule CopyDSliceOperandToAvoidImplicitLayoutChange

    ENTRY CopyDSliceOperandToAvoidImplicitLayoutChange {
      par0 = f32[3,4]{1,0} parameter(0)
      par1 = f32[4,5]{0,1} parameter(1)
      par2 = s32[] parameter(2)
      par3 = s32[] parameter(3)
      dslice0 = f32[3,4] dynamic-slice(par1, par2, par3), dynamic_slice_sizes={3,4}
      ROOT add0 = f32[3,4]{1,0} add(par0,dslice0)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .value();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0});
  EXPECT_THAT(root,
              GmockMatch(m::Add(
                  m::Parameter(),
                  m::DynamicSlice(
                      m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy),
                      m::Parameter(2), m::Parameter(3)))));
}

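// Same pattern for concatenate: the operand whose layout differs from the
// concatenate's layout gets copied first.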
TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) {
  const char* module_str = R"(
    HloModule CopyConcatOperandToAvoidImplicitLayoutChange

    ENTRY CopyConcatOperandToAvoidImplicitLayoutChange {
      par0 = f32[3,8]{1,0} parameter(0)
      par1 = f32[3,5]{0,1} parameter(1)
      par2 = f32[3,3]{1,0} parameter(2)
      concat0 = f32[3,8] concatenate(f32[3,5] par1, f32[3,3] par2),
        dimensions={1}
      ROOT add0 = f32[3,8]{1,0} add(par0,concat0)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .value();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {3, 5}, {1, 0});
  EXPECT_THAT(
      root,
      GmockMatch(m::Add(
          m::Parameter(),
          m::Concatenate(m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy),
                         m::Parameter(2)))));
}

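// A convolution operand that would undergo an implicit layout change is not
// copied; the convolution consumes its operands in their given layouts.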
TEST_F(LayoutAssignmentTest,
       ConvolutionOperandWithImplicitLayoutChangeNotCopied) {
  const char* module_str = R"(
    HloModule ConvolutionOperandWithImplicitLayoutChangeNotCopied

    ENTRY ConvolutionOperandWithImplicitLayoutChangeNotCopied {
      par0 = f32[128,3,230,230]{2,3,1,0} parameter(0)
      par1 = f32[7,7,3,64]{3,2,0,1} parameter(1)
      ROOT convolution0 = f32[128,64,112,112]{3,2,1,0} convolution(par0, par1),
        window={size=7x7 stride=2x2}, dim_labels=bf01_01io->bf01,
        feature_group_count=1
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .value();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Convolution(m::Parameter(0), m::Parameter(1))));
}

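// The non-default {0,1} layout requested for the slice result propagates back
// to the slice operand: the parameter is copied into {0,1} and the slice reads
// from that copy.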
TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) {
  const char* module_str = R"(
    HloModule PropagatingLayoutFromResultToOperand

    ENTRY PropagatingLayoutFromResultToOperand {
      par0 = f32[4,5]{1,0} parameter(0)
      ROOT slice0 = f32[3,4]{0,1} slice(par0), slice={[1:4],[1:5]}
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .value();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {0, 1});
  EXPECT_THAT(root,
              GmockMatch(m::Slice(
                  m::Copy(m::Parameter(0)).WithShapeEqualTo(&shape_copy))));
}

TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) {
  // The first infeed uses layout {0,1}, while the second uses layout {1,0}.
  // The mismatch forces a copy of the tuple. The tuple contains a token, so
  // layout assignment will fail if it tries to copy the whole tuple.
  const char* module_str = R"(
    HloModule TupleCopyOnLayoutMismatch

    condition.1 (tup: (s32[], token[], f32[512,1024]{0,1})) -> pred[] {
      tup.1 = (s32[], token[], f32[512,1024]{0,1}) parameter(0)
      counter.1 = s32[] get-tuple-element(tup.1), index=0
      five = s32[] constant(5)
      ROOT lt = pred[] compare(counter.1, five), direction=LT
    }

    body.2 (tup: (s32[], token[], f32[512,1024]{0,1})) -> (s32[], token[], f32[512,1024]{0,1}) {
      tup.2 = (s32[], token[], f32[512,1024]{0,1}) parameter(0)
      counter.2 = s32[] get-tuple-element(tup.2), index=0
      tok.2 = token[] get-tuple-element(tup.2), index=1

      ifeed.2 = (f32[512,1024]{1,0}, token[]) infeed(tok.2)
      next_tok = token[] get-tuple-element(ifeed.2), index=1
      next_buf = f32[512,1024]{1,0} get-tuple-element(ifeed.2), index=0

      one = s32[] constant(1)
      next_counter = s32[] add(counter.2, one)
      ROOT tup = (s32[], token[], f32[512,1024]{0,1}) tuple(next_counter, next_tok, next_buf)
    }

    ENTRY main () -> f32[512,1024]{0,1} {
      start_tok = token[] after-all()

      ifeed.3 = (f32[512,1024]{0,1}, token[]) infeed(start_tok)
      itok = token[] get-tuple-element(ifeed.3), index=1
      ibuf = f32[512,1024]{0,1} get-tuple-element(ifeed.3), index=0

      zero = s32[] constant(0)
      itup = (s32[], token[], f32[512,1024]{0,1}) tuple(zero, itok, ibuf)

      loop = (s32[], token[], f32[512,1024]{0,1}) while(itup), condition=condition.1, body=body.2
      ROOT result = f32[512,1024]{0,1} get-tuple-element(loop), index=2
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());

  // Sanity check to verify that there's a layout mismatch.
  EXPECT_THAT(LayoutOf(m.get(), "ibuf"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "next_buf"), ElementsAre(1, 0));

  AssignLayouts(m.get(), &computation_layout);
  SCOPED_TRACE(m->ToString());

  // Make sure that layout assignment did not magically eliminate the mismatch,
  // in which case the test didn't prove anything.
  Layout layout01 = LayoutUtil::MakeLayout({0, 1});
  const HloInstruction* loop = nullptr;
  ASSERT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::GetTupleElement(
                  m::Op(&loop)
                      .WithOpcode(HloOpcode::kWhile)
                      .WithOperand(0, m::Tuple(m::Op(), m::Op(),
                                               m::Copy(m::Op().WithShape(
                                                   m::Shape().WithLayoutEqualTo(
                                                       &layout01))))))));

  Layout layout10 = LayoutUtil::MakeLayout({1, 0});
  EXPECT_THAT(loop->while_body()->root_instruction(),
              GmockMatch(m::Tuple(
                  m::Op(), m::Op(),
                  m::Op().WithShape(m::Shape().WithLayoutEqualTo(&layout10)))));
}

TEST_F(LayoutAssignmentTest, CustomCallNotLayoutConstrained) {
  const char* module_str = R"(
    HloModule CustomCallNotLayoutConstrained

    ENTRY %CustomCallWithNotLayoutConstrained (p: f32[42,2,3]) -> f32[1,2,3,4] {
      %p = f32[42,2,3] parameter(0)
      ROOT %custom-call = f32[1,2,3,4] custom-call(f32[42,2,3] %p), custom_call_target="baz"
    }
  )";
  // Try with a couple of different layouts. In each case the custom call's
  // operand and result layouts should match those of the computation.
  {
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<VerifiedHloModule> m,
        ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
    ComputationLayout computation_layout = m->entry_computation_layout();
    *computation_layout.mutable_parameter_layout(0) =
        ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 2, 1}));
    *computation_layout.mutable_result_layout() = ShapeLayout(
        ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {3, 2, 0, 1}));
    AssignLayouts(m.get(), &computation_layout);

    HloInstruction* root = m->entry_computation()->root_instruction();
    ASSERT_THAT(root, GmockMatch(m::CustomCall(m::Parameter())));
    ExpectLayoutIs(root->shape(), {3, 2, 0, 1});
    ExpectLayoutIs(root->operand(0)->shape(), {0, 2, 1});
  }
  {
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<VerifiedHloModule> m,
        ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
    ComputationLayout computation_layout = m->entry_computation_layout();
    *computation_layout.mutable_parameter_layout(0) =
        ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 1, 2}));
    *computation_layout.mutable_result_layout() = ShapeLayout(
        ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {0, 2, 3, 1}));
    AssignLayouts(m.get(), &computation_layout);

    HloInstruction* root = m->entry_computation()->root_instruction();
    ASSERT_THAT(root, GmockMatch(m::CustomCall(m::Parameter())));
    ExpectLayoutIs(root->shape(), {0, 2, 3, 1});
    ExpectLayoutIs(root->operand(0)->shape(), {0, 1, 2});
  }
}

TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrained) {
  const char* module_str = R"(
    HloModule CustomCallLayoutConstrained

    ENTRY %CustomCallWithLayoutConstraints (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] {
      %p0 = f32[4,4] parameter(0)
      %p1 = f32[2,3] parameter(1)
      ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(f32[4,4] %p0, f32[2,3] %p1), custom_call_target="baz", operand_layout_constraints={f32[4,4]{0,1}, f32[2,3]{1,0}}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> m,
      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
  ComputationLayout computation_layout = m->entry_computation_layout();
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}));
  *computation_layout.mutable_result_layout() = ShapeLayout(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
  AssignLayouts(m.get(), &computation_layout);

  // The custom call should be partially encapsulated in kCopy instructions
  // because of the layout mismatches.
  ASSERT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Copy(m::CustomCall(m::Copy(), m::Parameter()))));

  const HloInstruction* custom_call =
      m->entry_computation()->root_instruction()->operand(0);
  ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
  ExpectLayoutIs(custom_call->operand(0)->shape(), {0, 1});
  ExpectLayoutIs(custom_call->operand(1)->shape(), {1, 0});
}

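// The custom call's output aliases operand 0, so the aliased operand and the
// result should be assigned the same layout.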
TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedAliasedOutput) {
  const char* module_str = R"(
    HloModule customcall.4

    ENTRY %customcall.4 (parameter.1: f32[8,128], parameter.2: f32[8,128]) -> f32[8,128] {
      %parameter.1 = f32[8,128]{1,0} parameter(0)
      %parameter.2 = f32[8,128]{1,0} parameter(1)
      ROOT %custom-call.3 = f32[8,128]{1,0} custom-call(f32[8,128]{1,0} %parameter.1, f32[8,128]{1,0} %parameter.2), custom_call_target="gpu_example_custom_call", operand_layout_constraints={f32[8,128]{1,0}, f32[8,128]{1,0}}, custom_call_has_side_effect=true, output_to_operand_aliasing={{}: (0, {})}
    })";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> m,
      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
  ComputationLayout computation_layout = m->entry_computation_layout();
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {8, 128}, {1, 0}));
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {8, 128}, {1, 0}));
  *computation_layout.mutable_result_layout() =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {8, 128}, {1, 0}));
  AssignLayouts(m.get(), &computation_layout);

  const HloInstruction* custom_call =
      m->entry_computation()->root_instruction();
  ExpectLayoutIs(custom_call->shape(), {1, 0});
  ExpectLayoutIs(custom_call->operand(0)->shape(), {1, 0});
  ExpectLayoutIs(custom_call->operand(1)->shape(), {1, 0});
}

1245 TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedZeroOperands) {
1246 const char* module_str = R"(
1247 HloModule CustomCallLayoutConstrainedZeroOperands
1248
1249 ENTRY %CustomCallLayoutConstrainedZeroOperands () -> f32[1,2,3,4] {
1250 ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(), custom_call_target="baz", operand_layout_constraints={}
1251 }
1252 )";
1253 TF_ASSERT_OK_AND_ASSIGN(
1254 std::unique_ptr<VerifiedHloModule> m,
1255 ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
1256 ComputationLayout computation_layout = m->entry_computation_layout();
1257 *computation_layout.mutable_result_layout() = ShapeLayout(
1258 ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
1259 AssignLayouts(m.get(), &computation_layout);
1260
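// With no operands there is nothing to copy on the input side; only the
// constrained result layout {3,2,0,1} differs from the requested {2,1,0,3},
// so a single copy is expected at the root.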
1261 ASSERT_THAT(m->entry_computation()->root_instruction(),
1262 GmockMatch(m::Copy(m::CustomCall())));
1263
1264 const HloInstruction* custom_call =
1265 m->entry_computation()->root_instruction()->operand(0);
1266 ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
1267 }
1268
1269 TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleOperand) {
1270 const char* module_str = R"(
1271 HloModule CustomCallLayoutConstrainedTupleOperand
1272
1273 ENTRY %CustomCallLayoutConstrainedTupleOperand (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] {
1274 %p0 = f32[4,4] parameter(0)
1275 %p1 = f32[2,3] parameter(1)
1276 %tuple = (f32[4,4], f32[2,3]) tuple(%p0, %p1)
1277 ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(%tuple), custom_call_target="baz", operand_layout_constraints={(f32[4,4]{1,0}, f32[2,3]{0,1})}
1278 }
1279 )";
1280 TF_ASSERT_OK_AND_ASSIGN(
1281 std::unique_ptr<VerifiedHloModule> m,
1282 ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
1283 ComputationLayout computation_layout = m->entry_computation_layout();
1284 *computation_layout.mutable_parameter_layout(0) =
1285 ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
1286 *computation_layout.mutable_parameter_layout(1) =
1287 ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}));
1288 *computation_layout.mutable_result_layout() = ShapeLayout(
1289 ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
1290 AssignLayouts(m.get(), &computation_layout);
1291
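// The tuple operand should take the constrained per-element layouts
// {{1,0}, {0,1}} directly, while the constrained result layout {3,2,0,1}
// differs from the requested {2,1,0,3}, so a copy is expected at the root.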
1292 HloInstruction* root = m->entry_computation()->root_instruction();
1293 ExpectLayoutIs(root->shape(), {2, 1, 0, 3});
1294
1295 ASSERT_THAT(m->entry_computation()->root_instruction(),
1296 GmockMatch(m::Copy(m::CustomCall(m::Tuple()))));
1297
1298 const HloInstruction* custom_call =
1299 m->entry_computation()->root_instruction()->operand(0);
1300 ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
1301 ExpectTupleLayoutIs(custom_call->operand(0)->shape(), {{1, 0}, {0, 1}});
1302 }
1303
1304 TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleResult) {
1305 const char* module_str = R"(
1306 HloModule CustomCallLayoutConstrainedTupleResult
1307
1308 ENTRY %CustomCallLayoutConstrainedTupleResult (p0: f32[4,4]) -> (f32[4,4]{1,0}, f32[2,3]{0,1}) {
1309 %p0 = f32[4,4] parameter(0)
1310 ROOT %custom-call = (f32[4,4]{1,0}, f32[2,3]{0,1}) custom-call(%p0), custom_call_target="baz", operand_layout_constraints={f32[4,4]{1,0}}
1311 }
1312 )";
1313 // The custom call's operand layout constraint matches the parameter layout,
1314 // but its constrained tuple result layout differs from the entry result layout requested below.
1315 TF_ASSERT_OK_AND_ASSIGN(
1316 std::unique_ptr<VerifiedHloModule> m,
1317 ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
1318 ComputationLayout computation_layout = m->entry_computation_layout();
1319 *computation_layout.mutable_parameter_layout(0) =
1320 ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
1321 *computation_layout.mutable_result_layout() =
1322 ShapeLayout(ShapeUtil::MakeTupleShape(
1323 {ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}),
1324 ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})}));
1325 AssignLayouts(m.get(), &computation_layout);
1326
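// The entry result should be produced with the requested ({1,0}, {1,0}) layout
// while the custom call keeps its constrained ({1,0}, {0,1}) result layout.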
1327 ExpectTupleLayoutIs(m->result_shape(), {{1, 0}, {1, 0}});
1328
1329 const HloInstruction* custom_call = FindInstruction(m.get(), "custom-call");
1330 ExpectTupleLayoutIs(custom_call->shape(), {{1, 0}, {0, 1}});
1331 }
1332
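// Helper that runs LayoutAssignment on `m` against its entry computation
// layout, defaulting the result layout if it has not been set, and returns the
// pass status.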
1333 Status AssignLayoutsToComputation(
1334 HloModule* m, ChannelLayoutConstraints* channel_constraints = nullptr) {
1335 if (!m->entry_computation_layout().result_layout().LayoutIsSet()) {
1336 m->mutable_entry_computation_layout()
1337 ->mutable_result_layout()
1338 ->SetToDefaultLayout();
1339 }
1340 LayoutAssignment layout_assignment(m->mutable_entry_computation_layout(),
1341 channel_constraints);
1342 return layout_assignment.Run(m).status();
1343 }
1344
1345 TEST_F(LayoutAssignmentTest, OverwriteDiamondShapedConstraintsX) {
1346 // Check that we handle a diamond-shaped graph correctly.
1347 // transpose
1348 // / \
1349 // add |
1350 // \ /
1351 // tuple
1352
1353 auto b = HloComputation::Builder(TestName());
1354 Shape ashape = ShapeUtil::MakeShape(F32, {12, 8});
1355 Shape bshape = ShapeUtil::MakeShape(F32, {8, 12});
1356 auto param0 =
1357 b.AddInstruction(HloInstruction::CreateParameter(0, bshape, "input"));
1358 auto param1 =
1359 b.AddInstruction(HloInstruction::CreateParameter(1, ashape, "input"));
1360 auto transpose =
1361 b.AddInstruction(HloInstruction::CreateTranspose(ashape, param0, {1, 0}));
1362 auto add = b.AddInstruction(
1363 HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, transpose, param1));
1364 b.AddInstruction(HloInstruction::CreateTuple({add, transpose}));
1365 auto m = CreateNewVerifiedModule();
1366 m->AddEntryComputation(b.Build());
1367 Shape ashape_major = ShapeUtil::MakeShapeWithLayout(F32, {12, 8}, {1, 0});
1368 Shape ashape_minor = ShapeUtil::MakeShapeWithLayout(F32, {12, 8}, {0, 1});
1369 *m->mutable_entry_computation_layout()->mutable_result_layout() =
1370 ShapeLayout(ShapeUtil::MakeTupleShape({ashape_major, ashape_minor}));
1371 const Layout r2_dim0major = LayoutUtil::MakeLayout({1, 0});
1372 ForceParameterLayout(m.get(), 0, r2_dim0major);
1373 ForceParameterLayout(m.get(), 1, r2_dim0major);
1374 TF_ASSERT_OK(AssignLayoutsToComputation(m.get()));
1375
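// The result-layout constraint fixes tuple element 0 (add) to {1,0} and
// element 1 (transpose) to {0,1}; the constraint flowing through add must not
// overwrite the {0,1} layout required for transpose on the other side of the
// diamond.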
1376 EXPECT_THAT(add->shape().layout().minor_to_major(), ElementsAre(1, 0));
1377 EXPECT_THAT(add->operand(0)->shape().layout().minor_to_major(),
1378 ElementsAre(1, 0));
1379 EXPECT_THAT(add->operand(1)->shape().layout().minor_to_major(),
1380 ElementsAre(1, 0));
1381
1382 EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(0, 1));
1383 }
1384
1385 // Tests that the layout assignment supports layout-constrained all-reduce with
1386 // different operand layouts (b/146056839).
1387 TEST_F(LayoutAssignmentTest, LayoutConstrainedAllReduce) {
1388 const char* module_str = R"(
1389 HloModule test_module
1390
1391 add {
1392 lhs = f32[] parameter(0)
1393 rhs = f32[] parameter(1)
1394 ROOT add = f32[] add(lhs, rhs)
1395 }
1396
1397 ENTRY entry_computation {
1398 param = (f32[8,4]{0,1}, f32[16,2]{0,1}) parameter(0)
1399 gte0 = f32[8,4] get-tuple-element(param), index=0
1400 gte1 = f32[16,2] get-tuple-element(param), index=1
1401 crs = (f32[8,4]{0,1}, f32[16,2]{1,0}) all-reduce(gte0, gte1),
1402 replica_groups={}, constrain_layout=true, to_apply=add
1403 gte2 = f32[8,4] get-tuple-element(crs), index=0
1404 gte3 = f32[16,2] get-tuple-element(crs), index=1
1405 ROOT result = (f32[8,4]{1,0}, f32[16,2]{1,0}) tuple(gte2, gte3)
1406 }
1407 )";
1408
1409 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
1410 ParseAndReturnVerifiedModule(module_str));
1411 ComputationLayout computation_layout(
1412 m->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false);
1413
1414 ChannelLayoutConstraints channel_constraints;
1415 AssignLayouts(m.get(), &computation_layout, &channel_constraints);
1416
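// The constrained all-reduce should keep its per-operand layouts ({0,1} and
// {1,0}) even though the entry result requests {1,0} for both tuple elements.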
1417 const HloInstruction* crs = FindInstruction(m.get(), "crs");
1418 ExpectTupleLayoutIs(crs->shape(), {{0, 1}, {1, 0}});
1419 ExpectLayoutIs(crs->operand(0)->shape(), {0, 1});
1420 ExpectLayoutIs(crs->operand(1)->shape(), {1, 0});
1421 }
1422
1423 TEST_F(LayoutAssignmentTest, LayoutConstrainedAllToAll) {
1424 const char* module_str = R"(
1425 HloModule test_module
1426
1427 add {
1428 lhs = f32[] parameter(0)
1429 rhs = f32[] parameter(1)
1430 ROOT add = f32[] add(lhs, rhs)
1431 }
1432
1433 ENTRY entry_computation {
1434 param = (f32[16,4]{0,1}, f32[16,4]{1,0}) parameter(0)
1435 gte0 = f32[16,4] get-tuple-element(param), index=0
1436 gte1 = f32[16,4] get-tuple-element(param), index=1
1437 alltoall = (f32[16,4]{1,0}, f32[16,4]{1,0}) all-reduce(gte0, gte1),
1438 replica_groups={{0,1}}, constrain_layout=true, to_apply=add
1439 gte2 = f32[16,4] get-tuple-element(alltoall), index=0
1440 gte3 = f32[16,4] get-tuple-element(alltoall), index=1
1441 ROOT concat = f32[16,8]{0,1} concatenate(gte2, gte3), dimensions={1}
1442 }
1443 )";
1444
1445 TF_ASSERT_OK_AND_ASSIGN(
1446 std::unique_ptr<HloModule> m,
1447 ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
1448 ComputationLayout computation_layout(
1449 m->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false);
1450
1451 ChannelLayoutConstraints channel_constraints;
1452 AssignLayouts(m.get(), &computation_layout, &channel_constraints);
1453
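// Both results and both operands of the layout-constrained collective should
// end up with the {1,0} layout it requires.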
1454 const HloInstruction* alltoall = FindInstruction(m.get(), "alltoall");
1455 ExpectTupleLayoutIs(alltoall->shape(), {{1, 0}, {1, 0}});
1456 ExpectLayoutIs(alltoall->operand(0)->shape(), {1, 0});
1457 ExpectLayoutIs(alltoall->operand(1)->shape(), {1, 0});
1458 }
1459
1460 TEST_F(LayoutAssignmentTest, DynamicRoot) {
1461 const char* module_str = R"(
1462 HloModule test_module
1463
1464 ENTRY entry_computation {
1465 param = f32[1,<=16]{0,1} parameter(0)
1466 ROOT abs = f32[1,<=16]{0,1} abs(param)
1467 }
1468 )";
1469
1470 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
1471 ParseAndReturnVerifiedModule(module_str));
1472 ComputationLayout computation_layout(
1473 m->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false);
1474 computation_layout.mutable_result_layout()->ClearDynamicShape();
1475
1476 AssignLayouts(m.get(), &computation_layout);
1477
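// Even though the dynamic-shape information was cleared from the result
// layout, layout assignment should keep dimension 1 of the root dynamic and
// honor the {0,1} layout.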
1478 const HloInstruction* abs = FindInstruction(m.get(), "abs");
1479 ExpectLayoutIs(abs->operand(0)->shape(), {0, 1});
1480 ExpectLayoutIs(abs->shape(), {0, 1});
1481 EXPECT_TRUE(abs->shape().is_dynamic_dimension(1));
1482 }
1483
1484 // Test the ability to avoid copying across computations by reversing
1485 // computation traversal order.
1486 TEST_F(LayoutAssignmentTest, ReverseComputationOrderAvoidCopy) {
1487 const char* module_str = R"(
1488 HloModule ComputationLayoutAvoidCopy
1489
1490 call_1 {
1491 %arg_tuple.1 = (f32[93184,4]) parameter(0)
1492 %get-tuple-element.1 = f32[93184,4] get-tuple-element(%arg_tuple.1), index=0
1493 ROOT %reshape.8494 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.1)
1494 }
1495
1496 on_true {
1497 %arg_tuple.1 = (f32[93184,4]) parameter(0)
1498 %get-tuple-element.1 = f32[93184,4] get-tuple-element(%arg_tuple.1), index=0
1499 ROOT %reshape.8493 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.1)
1500 }
1501
1502 on_false {
1503 %arg_tuple.2 = (f32[93184,4]) parameter(0)
1504 %get-tuple-element.3 = f32[93184,4] get-tuple-element(%arg_tuple.2), index=0
1505 %reshape.9717 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.3)
1506 ROOT %add = f32[2,512,364] add(%reshape.9717, %reshape.9717)
1507 }
1508
1509 ENTRY main {
1510 pred.1 = pred[] parameter(0)
1511 arg.2 = f32[93184,4]{1,0} parameter(1)
1512 arg_tuple.11 = (f32[93184,4]{1,0}) tuple(arg.2)
1513 call.1 = f32[2,512,364] call(arg_tuple.11), to_apply=call_1
1514 conditional = f32[2,512,364] conditional(pred.1, arg_tuple.11, arg_tuple.11),
1515 true_computation=on_true, false_computation=on_false
1516 ROOT add = f32[2,512,364] add(call.1, conditional)
1517 }
1518 )";
1519
1520 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
1521 ParseAndReturnVerifiedModule(module_str));
1522 ComputationLayout computation_layout(
1523 m->entry_computation()->ComputeProgramShape());
1524 *computation_layout.mutable_parameter_layout(0) =
1525 ShapeLayout(ShapeUtil::MakeShape(PRED, {}));
1526 *computation_layout.mutable_parameter_layout(1) =
1527 ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {93184, 4}, {0, 1}));
1528 *computation_layout.mutable_result_layout() = ShapeLayout(
1529 ShapeUtil::MakeShapeWithLayout(F32, {2, 512, 364}, {0, 1, 2}));
1530 ChannelLayoutConstraints channel_constraints;
1531 LayoutAssignment layout_assignment(
1532 &computation_layout,
1533 /*channel_constraints=*/&channel_constraints,
1534 /*reverse_computation_order=*/true);
1535 EXPECT_IS_OK(layout_assignment.Run(m.get()).status());
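// With reverse_computation_order=true the called and branch computations are
// laid out after the entry computation, so each reshape can keep the entry's
// {0,1,2} result layout, avoiding copies at the call and conditional
// boundaries.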
1536 const HloInstruction* call_1 = FindInstruction(m.get(), "reshape.8494");
1537 ExpectLayoutIs(call_1->shape(), {0, 1, 2});
1538 const HloInstruction* on_true = FindInstruction(m.get(), "reshape.8493");
1539 ExpectLayoutIs(on_true->shape(), {0, 1, 2});
1540 const HloInstruction* on_false = FindInstruction(m.get(), "reshape.9717");
1541 ExpectLayoutIs(on_false->shape(), {0, 1, 2});
1542 }
1543
1544 // Test the ability to propagate operand constraints across multiple operations.
1545 TEST_F(LayoutAssignmentTest, PropagateOperandLayout) {
1546 const char* module_str = R"(
1547 HloModule ComputationPropagateOperandLayout
1548
1549 %scalar_add_computation.1 (scalar_lhs.1: f32[], scalar_rhs.1: f32[]) -> f32[] {
1550 %scalar_lhs.1 = f32[]{:T(256)} parameter(0)
1551 %scalar_rhs.1 = f32[]{:T(256)} parameter(1)
1552 ROOT %add.25 = f32[]{:T(256)} add(f32[]{:T(256)} %scalar_lhs.1, f32[]{:T(256)} %scalar_rhs.1)
1553 }
1554
1555 ENTRY main {
1556 %convolution-base-dilated = f32[64,243,243,10]{0,3,2,1:T(8,128)} parameter(0)
1557 %reduce.13 = f32[64,243,243]{2,1,0:T(8,128)} parameter(1)
1558 %divide.2 = f32[64,243,243,10]{2,3,1,0:T(8,128)} parameter(2)
1559 %reshape.32 = f32[384,10,1,1]{0,1,3,2:T(8,128)} parameter(3)
1560 %subtract = f32[64,243,243,384]{3,0,2,1:T(8,128)} parameter(4)
1561 %constant = f32[]{:T(256)} constant(3779136)
1562 %broadcast.71 = f32[64,243,243,384]{3,0,2,1:T(8,128)} parameter(5)
1563 %broadcast.46 = f32[64,243,243,10] broadcast(f32[64,243,243] %reduce.13), dimensions={0,1,2}
1564 %subtract.14 = f32[64,243,243,10] subtract(f32[64,243,243,10] %convolution-base-dilated, f32[64,243,243,10] %broadcast.46)
1565 %multiply.22 = f32[64,243,243,10] multiply(f32[64,243,243,10] %subtract.14, f32[64,243,243,10]{2,3,1,0:T(8,128)} %divide.2)
1566 %convolution.9 = f32[64,243,243,384] convolution(f32[64,243,243,10] %multiply.22, f32[384,10,1,1] %reshape.32), window={size=1x1}, dim_labels=01bf_oi01->01bf
1567 %reduce.14 = f32[384] reduce(f32[64,243,243,384] %convolution.9, f32[]{:T(256)} %constant), dimensions={0,1,2}, to_apply=%scalar_add_computation.1
1568 %multiply.24 = f32[64,243,243,384] multiply(f32[64,243,243,384] %convolution.9, f32[64,243,243,384] %subtract)
1569 %reduce.15 = f32[384] reduce(f32[64,243,243,384] %multiply.24, f32[] %constant), dimensions={0,1,2}, to_apply=%scalar_add_computation.1
1570 %broadcast.47 = f32[64,243,243,384] broadcast(f32[] %constant), dimensions={}
1571 %multiply.23 = f32[64,243,243,384] multiply(f32[64,243,243,384] %convolution.9, f32[64,243,243,384] %broadcast.47)
1572 %broadcast.48 = f32[64,243,243,384]{3,2,1,0:T(8,128)} broadcast(f32[384]{0:T(512)} %reduce.14), dimensions={3}
1573 %subtract.15 = f32[64,243,243,384] subtract(f32[64,243,243,384] %multiply.23, f32[64,243,243,384] %broadcast.48)
1574 %broadcast.50 = f32[64,243,243,384] broadcast(f32[384] %reduce.15), dimensions={3}
1575 %multiply.25 = f32[64,243,243,384] multiply(f32[64,243,243,384] %broadcast.50, f32[64,243,243,384] %subtract)
1576 %divide.7 = f32[64,243,243,384]{3,0,2,1:T(8,128)} divide(f32[64,243,243,384] %multiply.25, f32[64,243,243,384] %broadcast.71)
1577 ROOT %subtract.17 = f32[64,243,243,384]{3,0,2,1:T(8,128)} subtract(f32[64,243,243,384] %subtract.15, f32[64,243,243,384] %divide.7)
1578 }
1579 )";
1580
1581 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
1582 ParseAndReturnVerifiedModule(module_str));
1583 ComputationLayout computation_layout(
1584 m->entry_computation()->ComputeProgramShape());
1585 *computation_layout.mutable_parameter_layout(0) = ShapeLayout(
1586 ShapeUtil::MakeShapeWithLayout(F32, {64, 243, 243, 10}, {0, 3, 2, 1}));
1587 *computation_layout.mutable_parameter_layout(1) = ShapeLayout(
1588 ShapeUtil::MakeShapeWithLayout(F32, {64, 243, 243}, {2, 1, 0}));
1589 *computation_layout.mutable_parameter_layout(2) = ShapeLayout(
1590 ShapeUtil::MakeShapeWithLayout(F32, {64, 243, 243, 10}, {2, 3, 1, 0}));
1591 *computation_layout.mutable_parameter_layout(3) = ShapeLayout(
1592 ShapeUtil::MakeShapeWithLayout(F32, {384, 10, 1, 1}, {0, 1, 3, 2}));
1593 *computation_layout.mutable_parameter_layout(4) = ShapeLayout(
1594 ShapeUtil::MakeShapeWithLayout(F32, {64, 243, 243, 384}, {3, 0, 2, 1}));
1595 *computation_layout.mutable_result_layout() = ShapeLayout(
1596 ShapeUtil::MakeShapeWithLayout(F32, {64, 243, 243, 384}, {3, 0, 2, 1}));
1597 ChannelLayoutConstraints channel_constraints;
1598 LayoutAssignment layout_assignment(
1599 &computation_layout,
1600 /*channel_constraints=*/&channel_constraints);
1601 EXPECT_IS_OK(layout_assignment.Run(m.get()).status());
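// The {2,3,1,0} operand constraint coming from %divide.2 should propagate back
// through %multiply.22 to %subtract.14 and %broadcast.46, while %subtract.15
// should pick up the {3,0,2,1} layout required by the entry result.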
1602 const HloInstruction* subtract_15 = FindInstruction(m.get(), "subtract.15");
1603 ExpectLayoutIs(subtract_15->shape(), {3, 0, 2, 1});
1604 const HloInstruction* broadcast_46 = FindInstruction(m.get(), "broadcast.46");
1605 ExpectLayoutIs(broadcast_46->shape(), {2, 3, 1, 0});
1606 const HloInstruction* subtract_14 = FindInstruction(m.get(), "subtract.14");
1607 ExpectLayoutIs(subtract_14->shape(), {2, 3, 1, 0});
1608 }
1609
1610 } // namespace
1611 } // namespace xla
1612