xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/map_inliner_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/map_inliner.h"
17 
18 #include <memory>
19 #include <utility>
20 
21 #include "tensorflow/compiler/xla/literal.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/test.h"
28 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
29 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 
32 namespace op = xla::testing::opcode_matchers;
33 
34 namespace xla {
35 namespace {
36 
37 using MapInlinerTest = HloTestBase;
38 
39 // Test that `map` with `max` is transformed to `max`
TEST_F(MapInlinerTest,MapMax)40 TEST_F(MapInlinerTest, MapMax) {
41   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
42 
43   auto max_builder = HloComputation::Builder(TestName());
44   auto param1 = max_builder.AddInstruction(
45       HloInstruction::CreateParameter(0, r0f32, "x"));
46   auto param2 = max_builder.AddInstruction(
47       HloInstruction::CreateParameter(1, r0f32, "y"));
48   max_builder.AddInstruction(HloInstruction::CreateBinary(
49       param1->shape(), HloOpcode::kMaximum, param1, param2));
50   auto max_f32 = max_builder.Build();
51 
52   auto builder = HloComputation::Builder("MapMaxFunction");
53   auto lhs = builder.AddInstruction(HloInstruction::CreateConstant(
54       LiteralUtil::CreateR1<float>({1, 2, 3, 4})));
55   auto rhs = builder.AddInstruction(HloInstruction::CreateConstant(
56       LiteralUtil::CreateR1<float>({4, 3, 2, 1})));
57   builder.AddInstruction(
58       HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get()));
59 
60   auto computation = builder.Build();
61   auto hlo_module = CreateNewVerifiedModule();
62   hlo_module->AddEmbeddedComputation(std::move(max_f32));
63   hlo_module->AddEntryComputation(std::move(computation));
64 
65   MapInliner inliner;
66   EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
67   EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
68               op::Maximum(lhs, rhs));
69 
70   // Verify execution on CPU.
71   auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
72   auto expected = LiteralUtil::CreateR1<float>({4, 3, 3, 4});
73   EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
74 }
75 
76 // Test that `constant` function is changed to `broadcast`.
TEST_F(MapInlinerTest,MapConstant)77 TEST_F(MapInlinerTest, MapConstant) {
78   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
79 
80   auto const2_builder = HloComputation::Builder(TestName());
81   auto param1 = const2_builder.AddInstruction(
82       HloInstruction::CreateParameter(0, r0f32, "x"));
83   (void)param1;
84   const2_builder.AddInstruction(
85       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0f)));
86   auto const2_f32 = const2_builder.Build();
87 
88   auto builder = HloComputation::Builder("MapConstFunction");
89   auto lhs = builder.AddInstruction(HloInstruction::CreateConstant(
90       LiteralUtil::CreateR2<float>({{1, 2, 3, 4}, {5, 6, 7, 8}})));
91   builder.AddInstruction(
92       HloInstruction::CreateMap(lhs->shape(), {lhs}, const2_f32.get()));
93 
94   auto computation = builder.Build();
95   auto hlo_module = CreateNewVerifiedModule();
96   hlo_module->AddEmbeddedComputation(std::move(const2_f32));
97   hlo_module->AddEntryComputation(std::move(computation));
98   HloInstruction* root = hlo_module->entry_computation()->root_instruction();
99   MapInliner inliner;
100   EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
101   root = hlo_module->entry_computation()->root_instruction();
102   EXPECT_THAT(root, op::Broadcast(op::Constant()));
103 
104   // Verify execution on CPU.
105   auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
106   auto expected = LiteralUtil::CreateR2<float>({{2, 2, 2, 2}, {2, 2, 2, 2}});
107   EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
108 }
109 
// Test that inlining binds map operands by parameter number, not by the order
// in which the mapped computation's parameters were created. The computation
// computes p1 - p0, so map(lhs, rhs) must become subtract(rhs, lhs).
TEST_F(MapInlinerTest, MapSubtractOppositeOrder) {
  Shape r0f32 = ShapeUtil::MakeShape(F32, {});

  // Note that the parameter ordinals are in the opposite order to their
  // position as operands of the subtract.
  auto sub_builder = HloComputation::Builder(TestName());
  auto param1 = sub_builder.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32, "x"));
  auto param0 = sub_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "y"));
  sub_builder.AddInstruction(HloInstruction::CreateBinary(
      param1->shape(), HloOpcode::kSubtract, param1, param0));
  auto sub_f32 = sub_builder.Build();

  // Entry computation: map(lhs, rhs); lhs feeds parameter 0, rhs parameter 1.
  auto builder = HloComputation::Builder("MapSubFunction");
  auto lhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1, 2, 3, 4})));
  auto rhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({4, 3, 2, 1})));
  builder.AddInstruction(
      HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, sub_f32.get()));

  auto computation = builder.Build();
  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEmbeddedComputation(std::move(sub_f32));
  hlo_module->AddEntryComputation(std::move(computation));

  MapInliner inliner;
  EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              op::Subtract(rhs, lhs));

  // Execute the rewritten module: rhs - lhs elementwise.
  auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
  auto expected = LiteralUtil::CreateR1<float>({3, 1, -1, -3});
  // LiteralTestUtil::Equal takes (expected, actual).
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
147 
// Test that a `map` whose mapped computation just returns one of its
// parameters is replaced by the corresponding map operand.
TEST_F(MapInlinerTest, MapParameter) {
  Shape r0f32 = ShapeUtil::MakeShape(F32, {});

  // Mapped computation over two f32 scalars with no other instructions. No
  // explicit root is passed to Build(); the test relies on the last added
  // instruction (p1) being the root, which is why `rhs` is expected below.
  auto param_builder = HloComputation::Builder(TestName());
  param_builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32, "p0"));
  param_builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32, "p1"));
  auto param_f32 = param_builder.Build();

  // Entry computation: map over two scalar constants.
  auto builder = HloComputation::Builder("MapParamFunction");
  auto lhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1)));
  auto rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(4)));
  builder.AddInstruction(
      HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, param_f32.get()));

  auto computation = builder.Build();
  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEmbeddedComputation(std::move(param_f32));
  hlo_module->AddEntryComputation(std::move(computation));

  MapInliner inliner;
  EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), rhs);

  // Execute the rewritten module: the result is the second operand.
  auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
  auto expected = LiteralUtil::CreateR0<float>(4);
  // LiteralTestUtil::Equal takes (expected, actual).
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
178 
179 }  // namespace
180 }  // namespace xla
181