1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/map_inliner.h"
17
18 #include <memory>
19 #include <utility>
20
21 #include "tensorflow/compiler/xla/literal.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/test.h"
28 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
29 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31
32 namespace op = xla::testing::opcode_matchers;
33
34 namespace xla {
35 namespace {
36
// Fixture for the MapInliner tests: HloTestBase itself, aliased so the tests
// read as `TEST_F(MapInlinerTest, ...)`. The tests below rely on its
// CreateNewVerifiedModule() and ExecuteAndTransfer() helpers.
using MapInlinerTest = HloTestBase;
38
39 // Test that `map` with `max` is transformed to `max`
TEST_F(MapInlinerTest,MapMax)40 TEST_F(MapInlinerTest, MapMax) {
41 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
42
43 auto max_builder = HloComputation::Builder(TestName());
44 auto param1 = max_builder.AddInstruction(
45 HloInstruction::CreateParameter(0, r0f32, "x"));
46 auto param2 = max_builder.AddInstruction(
47 HloInstruction::CreateParameter(1, r0f32, "y"));
48 max_builder.AddInstruction(HloInstruction::CreateBinary(
49 param1->shape(), HloOpcode::kMaximum, param1, param2));
50 auto max_f32 = max_builder.Build();
51
52 auto builder = HloComputation::Builder("MapMaxFunction");
53 auto lhs = builder.AddInstruction(HloInstruction::CreateConstant(
54 LiteralUtil::CreateR1<float>({1, 2, 3, 4})));
55 auto rhs = builder.AddInstruction(HloInstruction::CreateConstant(
56 LiteralUtil::CreateR1<float>({4, 3, 2, 1})));
57 builder.AddInstruction(
58 HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get()));
59
60 auto computation = builder.Build();
61 auto hlo_module = CreateNewVerifiedModule();
62 hlo_module->AddEmbeddedComputation(std::move(max_f32));
63 hlo_module->AddEntryComputation(std::move(computation));
64
65 MapInliner inliner;
66 EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
67 EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
68 op::Maximum(lhs, rhs));
69
70 // Verify execution on CPU.
71 auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
72 auto expected = LiteralUtil::CreateR1<float>({4, 3, 3, 4});
73 EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
74 }
75
76 // Test that `constant` function is changed to `broadcast`.
TEST_F(MapInlinerTest,MapConstant)77 TEST_F(MapInlinerTest, MapConstant) {
78 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
79
80 auto const2_builder = HloComputation::Builder(TestName());
81 auto param1 = const2_builder.AddInstruction(
82 HloInstruction::CreateParameter(0, r0f32, "x"));
83 (void)param1;
84 const2_builder.AddInstruction(
85 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0f)));
86 auto const2_f32 = const2_builder.Build();
87
88 auto builder = HloComputation::Builder("MapConstFunction");
89 auto lhs = builder.AddInstruction(HloInstruction::CreateConstant(
90 LiteralUtil::CreateR2<float>({{1, 2, 3, 4}, {5, 6, 7, 8}})));
91 builder.AddInstruction(
92 HloInstruction::CreateMap(lhs->shape(), {lhs}, const2_f32.get()));
93
94 auto computation = builder.Build();
95 auto hlo_module = CreateNewVerifiedModule();
96 hlo_module->AddEmbeddedComputation(std::move(const2_f32));
97 hlo_module->AddEntryComputation(std::move(computation));
98 HloInstruction* root = hlo_module->entry_computation()->root_instruction();
99 MapInliner inliner;
100 EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
101 root = hlo_module->entry_computation()->root_instruction();
102 EXPECT_THAT(root, op::Broadcast(op::Constant()));
103
104 // Verify execution on CPU.
105 auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
106 auto expected = LiteralUtil::CreateR2<float>({{2, 2, 2, 2}, {2, 2, 2, 2}});
107 EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
108 }
109
// Tests that inlining respects parameter ordinals rather than the order in
// which parameters were added to the mapped computation: map(lhs, rhs) over
// f(p0, p1) = p1 - p0 must become subtract(rhs, lhs).
TEST_F(MapInlinerTest, MapSubtractOppositeOrder) {
  Shape r0f32 = ShapeUtil::MakeShape(F32, {});

  // The parameter ordinals are deliberately the opposite of their position as
  // subtract operands: `x` is parameter 1 (bound to rhs) and `y` is
  // parameter 0 (bound to lhs), so the computation is x - y = rhs - lhs.
  auto sub_builder = HloComputation::Builder(TestName());
  auto x_param1 = sub_builder.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32, "x"));
  auto y_param0 = sub_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "y"));
  sub_builder.AddInstruction(HloInstruction::CreateBinary(
      x_param1->shape(), HloOpcode::kSubtract, x_param1, y_param0));
  auto sub_f32 = sub_builder.Build();

  // Entry computation: map(lhs, rhs) over the scalar subtract computation.
  auto builder = HloComputation::Builder("MapSubFunction");
  auto lhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1, 2, 3, 4})));
  auto rhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({4, 3, 2, 1})));
  builder.AddInstruction(
      HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, sub_f32.get()));

  auto computation = builder.Build();
  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEmbeddedComputation(std::move(sub_f32));
  hlo_module->AddEntryComputation(std::move(computation));

  MapInliner inliner;
  EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              op::Subtract(rhs, lhs));

  // Verify the inlined module computes rhs - lhs element-wise.
  auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
  auto expected = LiteralUtil::CreateR1<float>({3, 1, -1, -3});
  // LiteralTestUtil::Equal takes (expected, actual).
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
147
// Tests that a `map` whose computation just returns one of its parameters is
// replaced by the corresponding map operand itself.
TEST_F(MapInlinerTest, MapParameter) {
  Shape r0f32 = ShapeUtil::MakeShape(F32, {});

  // Computation with parameters p0 and p1 and no other instructions. Build()
  // is given no explicit root, so the last-added instruction (p1) is expected
  // to be the root; that is why the map should collapse to `rhs` below.
  auto param_builder = HloComputation::Builder(TestName());
  param_builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32, "p0"));
  param_builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32, "p1"));
  auto param_f32 = param_builder.Build();

  // Entry computation: map(lhs, rhs) over scalar constants.
  auto builder = HloComputation::Builder("MapParamFunction");
  auto lhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1)));
  auto rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(4)));
  builder.AddInstruction(
      HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, param_f32.get()));

  auto computation = builder.Build();
  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEmbeddedComputation(std::move(param_f32));
  hlo_module->AddEntryComputation(std::move(computation));

  MapInliner inliner;
  EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), rhs);

  // Verify the module now simply yields the value of `rhs`.
  auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
  auto expected = LiteralUtil::CreateR0<float>(4);
  // LiteralTestUtil::Equal takes (expected, actual).
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
178
179 } // namespace
180 } // namespace xla
181