1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"

#include <cstddef>
#include <cstdint>
#include <string>
#include <tuple>
#include <vector>

#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/tpu/topology.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
31
32 namespace tensorflow {
33 namespace {
34
35 using Device = DeviceNameUtils::ParsedName;
36
DeviceNamesToParsedNames(llvm::ArrayRef<std::string> device_names,llvm::SmallVectorImpl<Device> * parsed_devices)37 bool DeviceNamesToParsedNames(llvm::ArrayRef<std::string> device_names,
38 llvm::SmallVectorImpl<Device>* parsed_devices) {
39 parsed_devices->reserve(device_names.size());
40 for (const auto& device_name : device_names) {
41 Device parsed_name;
42 if (!DeviceNameUtils::ParseFullName(device_name, &parsed_name))
43 return false;
44
45 parsed_devices->push_back(parsed_name);
46 }
47 return true;
48 }
49
50 using DeviceNames = llvm::SmallVector<std::string, 8>;
51
52 struct ParameterizedDeviceSetTest
53 : ::testing::TestWithParam<std::tuple<DeviceNames, std::string>> {};
54
// Checks that GetTPUCompilationAndExecutionDevices rejects the parameterized
// device set and reports exactly the expected error message. The call uses one
// replica, one core per replica, and empty topology / device assignment.
TEST_P(ParameterizedDeviceSetTest, BadDeviceSet) {
  const DeviceNames& device_names = std::get<0>(GetParam());
  const std::string& expected_error = std::get<1>(GetParam());

  llvm::SmallVector<Device, 8> parsed_devices;
  ASSERT_TRUE(DeviceNamesToParsedNames(device_names, &parsed_devices));

  std::string topology_attr;
  std::vector<int64_t> device_assignment_attr;
  auto result = GetTPUCompilationAndExecutionDevices(
      parsed_devices, /*num_replicas=*/1, /*num_cores_per_replica=*/1,
      topology_attr, device_assignment_attr);

  ASSERT_FALSE(result.ok());
  EXPECT_EQ(result.status().error_message(), expected_error);
}
67
68 INSTANTIATE_TEST_SUITE_P(
69 BadDeviceSet, ParameterizedDeviceSetTest,
70 ::testing::Values(
71 std::make_tuple<DeviceNames, std::string>(
72 {"/job:localhost/replica:0/task:0/device:CPU:0"},
73 "no TPU_SYSTEM devices found"),
74 std::make_tuple<DeviceNames, std::string>(
75 {"/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0",
76 "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0"},
77 "found TPU_SYSTEM devices with conflicting jobs 'localhost' and "
78 "'worker'"),
79 std::make_tuple<DeviceNames, std::string>(
80 {"/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0",
81 "/job:localhost/replica:1/task:0/device:TPU_SYSTEM:0"},
82 "found TPU_SYSTEM devices with conflicting replicas '0' and '1'"),
83 std::make_tuple<DeviceNames, std::string>(
84 {"/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0",
85 "/job:localhost/replica:0/task:0/device:TPU:0",
86 "/job:localhost/replica:0/task:0/device:TPU:1",
87 "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0",
88 "/job:localhost/replica:0/task:1/device:TPU:0"},
89 "expected the number of TPU devices per host to be 2, got 1")));
90
91 struct ParameterizedMetadataTest
92 : ::testing::TestWithParam<std::tuple<int, int, std::string,
93 std::vector<int64_t>, std::string>> {
94 };
95
// Checks that GetTPUCompilationAndExecutionDevices rejects the parameterized
// metadata and reports exactly the expected error message.
//
// Tuple layout: <0> num_replicas, <1> num_cores_per_replica, <2> topology,
// <3> device_assignment, <4> expected error message.
TEST_P(ParameterizedMetadataTest, BadMetadata) {
  // Fixed device set: job "worker" with two tasks, each exposing one
  // TPU_SYSTEM device and one TPU device.
  llvm::SmallVector<Device, 8> devices;
  ASSERT_TRUE(DeviceNamesToParsedNames(
      {"/job:worker/replica:0/task:0/device:TPU_SYSTEM:0",
       "/job:worker/replica:0/task:0/device:TPU:0",
       "/job:worker/replica:0/task:1/device:TPU_SYSTEM:0",
       "/job:worker/replica:0/task:1/device:TPU:0"},
      &devices));

  auto status_or = GetTPUCompilationAndExecutionDevices(
      devices, std::get<0>(GetParam()), std::get<1>(GetParam()),
      std::get<2>(GetParam()), std::get<3>(GetParam()));
  ASSERT_FALSE(status_or.ok());
  EXPECT_EQ(status_or.status().error_message(), std::get<4>(GetParam()));
}
114
TopologyWithMeshShape(llvm::ArrayRef<int> mesh_shape)115 std::string TopologyWithMeshShape(llvm::ArrayRef<int> mesh_shape) {
116 tpu::TopologyProto topology_proto;
117 for (int mesh_dim : mesh_shape) topology_proto.add_mesh_shape(mesh_dim);
118 return topology_proto.SerializeAsString();
119 }
120
TopologyWithMeshShapeAndTasks(llvm::ArrayRef<int> mesh_shape,int num_tasks,int num_tpu_devices_per_task)121 std::string TopologyWithMeshShapeAndTasks(llvm::ArrayRef<int> mesh_shape,
122 int num_tasks,
123 int num_tpu_devices_per_task) {
124 tpu::TopologyProto topology_proto;
125 for (int mesh_dim : mesh_shape) topology_proto.add_mesh_shape(mesh_dim);
126 topology_proto.set_num_tasks(num_tasks);
127 topology_proto.set_num_tpu_devices_per_task(num_tpu_devices_per_task);
128 return topology_proto.SerializeAsString();
129 }
130
TopologyWithDeviceCoordinates(llvm::ArrayRef<int> device_coordinates)131 std::string TopologyWithDeviceCoordinates(
132 llvm::ArrayRef<int> device_coordinates) {
133 tpu::TopologyProto topology_proto;
134 topology_proto.add_mesh_shape(2);
135 topology_proto.add_mesh_shape(1);
136 topology_proto.add_mesh_shape(1);
137 topology_proto.add_mesh_shape(1);
138 topology_proto.set_num_tasks(2);
139 topology_proto.set_num_tpu_devices_per_task(1);
140 for (int device_coordinate : device_coordinates)
141 topology_proto.add_device_coordinates(device_coordinate);
142 return topology_proto.SerializeAsString();
143 }
144
145 INSTANTIATE_TEST_SUITE_P(
146 BadFullMeshMetadata, ParameterizedMetadataTest,
147 ::testing::Values(
148 std::make_tuple(
149 2, 1, "", std::vector<int64_t>{0},
150 "'device_assignment' must not be set when 'topology' is not set"),
151 std::make_tuple(8, 1, "", std::vector<int64_t>(),
152 "'num_replicas' must be equal to 1 or 2, got 8"),
153 std::make_tuple(2, 2, "", std::vector<int64_t>(),
154 "'num_cores_per_replica' must be equal to 1, got 2")));
155
156 INSTANTIATE_TEST_SUITE_P(
157 BadGeneralTopologyMetadata, ParameterizedMetadataTest,
158 ::testing::Values(
159 std::make_tuple(
160 2, 1, "BAD_TOPOLOGY", std::vector<int64_t>(),
161 "failed to parse 'topology' attribute to TopologyProto"),
162 std::make_tuple(4, 2, TopologyWithMeshShape({0}),
163 std::vector<int64_t>(),
164 "'topology' 'mesh_shape' must be rank 4, got rank 1"),
165 std::make_tuple(
166 2, 1, TopologyWithMeshShape({2, 0, 1, 2}), std::vector<int64_t>(),
167 "'topology' 'mesh_shape' dimension 1 must be positive, got 0"),
168 std::make_tuple(2, 1, TopologyWithMeshShapeAndTasks({1, 1, 1, 1}, 1, 1),
169 std::vector<int64_t>(),
170 "number of tasks from available TPU devices must be "
171 "'num_tasks' in 'topology' (1), got 2"),
172 std::make_tuple(2, 1, TopologyWithMeshShapeAndTasks({1, 1, 1, 1}, 2, 2),
173 std::vector<int64_t>(),
174 "number of TPU devices available per task must be "
175 "'num_tpu_devices_per_task' in 'topology' (2), got 1"),
176 std::make_tuple(
177 2, 1, TopologyWithDeviceCoordinates({}), std::vector<int64_t>(),
178 "length of 'device_coordinates' in 'topology' must be 'num_tasks' "
179 "* 'num_tpus_per_task' * 4 (2 * 1 * 4), got 0"),
180 std::make_tuple(
181 2, 1, TopologyWithDeviceCoordinates({-1, 0, 0, 0, 1, 0, 0, 0}),
182 std::vector<int64_t>(),
183 "device coordinate (-1, 0, 0, 0) in 'topology' is outside "
184 "of mesh shape (2, 1, 1, 1)"),
185 std::make_tuple(
186 2, 1, TopologyWithDeviceCoordinates({2, 0, 0, 0, 1, 0, 0, 0}),
187 std::vector<int64_t>(),
188 "device coordinate (2, 0, 0, 0) in 'topology' is outside "
189 "of mesh shape (2, 1, 1, 1)"),
190 std::make_tuple(
191 2, 1, TopologyWithDeviceCoordinates({0, -1, 0, 0, 1, 0, 0, 0}),
192 std::vector<int64_t>(),
193 "device coordinate (0, -1, 0, 0) in 'topology' is outside "
194 "of mesh shape (2, 1, 1, 1)"),
195 std::make_tuple(
196 2, 1, TopologyWithDeviceCoordinates({0, 1, 0, 0, 1, 0, 0, 0}),
197 std::vector<int64_t>(),
198 "device coordinate (0, 1, 0, 0) in 'topology' is outside "
199 "of mesh shape (2, 1, 1, 1)"),
200 std::make_tuple(
201 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, -1, 1, 0, 0, 0}),
202 std::vector<int64_t>(),
203 "device coordinate (0, 0, 0, -1) in 'topology' is outside "
204 "of mesh shape (2, 1, 1, 1)"),
205 std::make_tuple(
206 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 1, 1, 0, 0, 0}),
207 std::vector<int64_t>(),
208 "device coordinate (0, 0, 0, 1) in 'topology' is outside "
209 "of mesh shape (2, 1, 1, 1)"),
210 std::make_tuple(
211 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 0, 0, 0, 0}),
212 std::vector<int64_t>(),
213 "'topology' has duplicate device coordinate (0, 0, 0, 0)")));
214
215 INSTANTIATE_TEST_SUITE_P(
216 BadGeneralDeviceAssignmentMetadata, ParameterizedMetadataTest,
217 ::testing::Values(
218 std::make_tuple(2, 1,
219 TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
220 std::vector<int64_t>(),
221 "length of 'device_assignment' must be 'num_replicas' "
222 "* 'num_cores_per_replica' * 4 (2 * 1 * 4), got 0"),
223 std::make_tuple(
224 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
225 std::vector<int64_t>{-1, 0, 0, 0, 0, 0, 0, 0},
226 "device coordinate (-1, 0, 0, 0) in 'device_assignment' "
227 "is outside of mesh shape (2, 1, 1, 1)"),
228 std::make_tuple(
229 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
230 std::vector<int64_t>{2, 0, 0, 0, 0, 0, 0, 0},
231 "device coordinate (2, 0, 0, 0) in 'device_assignment' is "
232 "outside of mesh shape (2, 1, 1, 1)"),
233 std::make_tuple(
234 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
235 std::vector<int64_t>{0, -1, 0, 0, 0, 0, 0, 0},
236 "device coordinate (0, -1, 0, 0) in 'device_assignment' "
237 "is outside of mesh shape (2, 1, 1, 1)"),
238 std::make_tuple(
239 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
240 std::vector<int64_t>{0, 1, 0, 0, 0, 0, 0, 0},
241 "device coordinate (0, 1, 0, 0) in 'device_assignment' is "
242 "outside of mesh shape (2, 1, 1, 1)"),
243 std::make_tuple(
244 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
245 std::vector<int64_t>{0, 0, 0, -1, 0, 0, 0, 0},
246 "device coordinate (0, 0, 0, -1) in 'device_assignment' "
247 "is outside of mesh shape (2, 1, 1, 1)"),
248 std::make_tuple(
249 2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
250 std::vector<int64_t>{0, 0, 0, 1, 0, 0, 0, 0},
251 "device coordinate (0, 0, 0, 1) in 'device_assignment' is "
252 "outside of mesh shape (2, 1, 1, 1)"),
253 std::make_tuple(2, 1,
254 TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
255 std::vector<int64_t>{0, 0, 0, 0, 0, 0, 0, 0},
256 "'device_assignment' has duplicate device coordinate "
257 "(0, 0, 0, 0)")));
258
// Builds a list of fully qualified device names for tests.
//
// The list always starts with the localhost CPU device. For each task in
// [0, num_tasks) of job "worker" it then appends, in order: the task's CPU:0
// device, its TPU_SYSTEM:0 device, and TPU:0 .. TPU:(num_devices_per_task-1).
// Non-positive counts contribute no entries.
std::vector<std::string> MakeDeviceSet(int num_tasks,
                                       int num_devices_per_task) {
  std::vector<std::string> devices{
      "/job:localhost/replica:0/task:0/device:CPU:0"};

  // Each task contributes a CPU and a TPU_SYSTEM entry plus its TPU devices.
  // Clamp negatives first so a negative count cannot wrap into a huge
  // reservation request.
  const std::size_t task_count =
      num_tasks > 0 ? static_cast<std::size_t>(num_tasks) : 0;
  const std::size_t tpus_per_task =
      num_devices_per_task > 0 ? static_cast<std::size_t>(num_devices_per_task)
                               : 0;
  devices.reserve(1 + task_count * (2 + tpus_per_task));

  for (int task = 0; task < num_tasks; ++task) {
    const std::string device_prefix =
        "/job:worker/replica:0/task:" + std::to_string(task) + "/device:";
    devices.push_back(device_prefix + "CPU:0");
    devices.push_back(device_prefix + "TPU_SYSTEM:0");
    for (int device = 0; device < num_devices_per_task; ++device) {
      devices.push_back(device_prefix + "TPU:" + std::to_string(device));
    }
  }

  return devices;
}
282
// A device assignment naming a coordinate that no available TPU device
// occupies must be rejected with a "no TPU device found" error.
TEST(TPURewriteDeviceUtilTest,
     BadGeneralDeviceAssignmentMetadataMissingDevice) {
  // (2, 1, 1, 1) mesh with a single task whose only TPU sits at (0, 0, 0, 0).
  tpu::TopologyProto topology;
  for (int mesh_dim : {2, 1, 1, 1}) topology.add_mesh_shape(mesh_dim);
  topology.set_num_tasks(1);
  topology.set_num_tpu_devices_per_task(1);
  for (int coordinate : {0, 0, 0, 0})
    topology.add_device_coordinates(coordinate);
  const std::string topology_attr = topology.SerializeAsString();

  // (1, 0, 0, 0) is not among the topology's device coordinates.
  const std::vector<int64_t> device_assignment_attr{1, 0, 0, 0};

  const std::vector<std::string> device_names =
      MakeDeviceSet(/*num_tasks=*/1, /*num_devices_per_task=*/1);
  llvm::SmallVector<Device, 8> parsed_devices;
  ASSERT_TRUE(DeviceNamesToParsedNames(device_names, &parsed_devices));

  auto result = GetTPUCompilationAndExecutionDevices(
      parsed_devices, /*num_replicas=*/1, /*num_cores_per_replica=*/1,
      topology_attr, device_assignment_attr);

  ASSERT_FALSE(result.ok());
  EXPECT_EQ(result.status().error_message(),
            "no TPU device found for 'device_assignment' device coordinate (1, "
            "0, 0, 0)");
}
316
// With empty topology and device assignment, eight replicas over two tasks of
// four TPUs each are expected to map to the TPU devices in task/device order,
// and no XLA device assignment is produced.
TEST(TPURewriteDeviceUtilTest, ValidFullMeshDeviceAssignment) {
  const std::vector<std::string> device_names =
      MakeDeviceSet(/*num_tasks=*/2, /*num_devices_per_task=*/4);
  llvm::SmallVector<Device, 8> parsed_devices;
  ASSERT_TRUE(DeviceNamesToParsedNames(device_names, &parsed_devices));

  std::string topology_attr;
  std::vector<int64_t> device_assignment_attr;
  auto status_or = GetTPUCompilationAndExecutionDevices(
      parsed_devices, /*num_replicas=*/8, /*num_cores_per_replica=*/1,
      topology_attr, device_assignment_attr);
  TF_ASSERT_OK(status_or.status());

  const auto& assignment = status_or.ValueOrDie();
  EXPECT_EQ(assignment.compilation_device,
            "/job:worker/replica:0/task:0/device:CPU:0");

  const auto& tpu_devices = assignment.tpu_devices;
  ASSERT_EQ(tpu_devices.size(), 8);
  for (const auto& replica_devices : tpu_devices)
    ASSERT_EQ(replica_devices.size(), 1);

  // Expected TPU device and host for core 0 of each replica, by replica id.
  struct ExpectedPlacement {
    const char* device;
    const char* host;
  };
  const ExpectedPlacement kExpected[8] = {
      {"/job:worker/replica:0/task:0/device:TPU:0",
       "/job:worker/replica:0/task:0/device:CPU:0"},
      {"/job:worker/replica:0/task:0/device:TPU:1",
       "/job:worker/replica:0/task:0/device:CPU:0"},
      {"/job:worker/replica:0/task:0/device:TPU:2",
       "/job:worker/replica:0/task:0/device:CPU:0"},
      {"/job:worker/replica:0/task:0/device:TPU:3",
       "/job:worker/replica:0/task:0/device:CPU:0"},
      {"/job:worker/replica:0/task:1/device:TPU:0",
       "/job:worker/replica:0/task:1/device:CPU:0"},
      {"/job:worker/replica:0/task:1/device:TPU:1",
       "/job:worker/replica:0/task:1/device:CPU:0"},
      {"/job:worker/replica:0/task:1/device:TPU:2",
       "/job:worker/replica:0/task:1/device:CPU:0"},
      {"/job:worker/replica:0/task:1/device:TPU:3",
       "/job:worker/replica:0/task:1/device:CPU:0"},
  };
  for (int replica = 0; replica < 8; ++replica) {
    EXPECT_EQ(tpu_devices[replica][0].device, kExpected[replica].device);
    EXPECT_EQ(tpu_devices[replica][0].host, kExpected[replica].host);
  }

  EXPECT_FALSE(assignment.xla_device_assignment.has_value());
}
374
// Explicit topology and device assignment on a (2, 2, 1, 2) mesh with two
// tasks of four TPUs each, four replicas and two cores per replica. Checks the
// resulting TPU/host placement and the XLA device assignment ids.
TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh2x2x2) {
  tpu::TopologyProto topology;
  for (int mesh_dim : {2, 2, 1, 2}) topology.add_mesh_shape(mesh_dim);
  topology.set_num_tasks(2);
  topology.set_num_tpu_devices_per_task(4);
  // Flat list of 4-component device coordinates, one group per device.
  for (int coordinate : {0, 0, 0, 0,    //
                         0, 1, 0, 0,    //
                         1, 1, 0, 0,    //
                         1, 0, 0, 0,    //
                         1, 0, 0, 1,    //
                         1, 1, 0, 1,    //
                         0, 1, 0, 1,    //
                         0, 0, 0, 1}) {
    topology.add_device_coordinates(coordinate);
  }
  const std::string topology_attr = topology.SerializeAsString();

  // Flat list of 4-component coordinates, one group per (replica, core).
  const std::vector<int64_t> device_assignment_attr{
      0, 0, 0, 0, 0, 0, 0, 1,  //
      0, 1, 0, 0, 0, 1, 0, 1,  //
      1, 0, 0, 0, 1, 0, 0, 1,  //
      1, 1, 0, 0, 1, 1, 0, 1};

  const std::vector<std::string> device_names =
      MakeDeviceSet(/*num_tasks=*/2, /*num_devices_per_task=*/4);
  llvm::SmallVector<Device, 8> parsed_devices;
  ASSERT_TRUE(DeviceNamesToParsedNames(device_names, &parsed_devices));

  auto status_or = GetTPUCompilationAndExecutionDevices(
      parsed_devices, /*num_replicas=*/4, /*num_cores_per_replica=*/2,
      topology_attr, device_assignment_attr);
  TF_ASSERT_OK(status_or.status());

  const auto& assignment = status_or.ValueOrDie();
  EXPECT_EQ(assignment.compilation_device,
            "/job:worker/replica:0/task:0/device:CPU:0");

  const auto& tpu_devices = assignment.tpu_devices;
  ASSERT_EQ(tpu_devices.size(), 4);
  for (const auto& replica_devices : tpu_devices)
    ASSERT_EQ(replica_devices.size(), 2);

  // Expected TPU device and host, indexed as [replica][core].
  struct ExpectedPlacement {
    const char* device;
    const char* host;
  };
  const ExpectedPlacement kExpected[4][2] = {
      {{"/job:worker/replica:0/task:0/device:TPU:0",
        "/job:worker/replica:0/task:0/device:CPU:0"},
       {"/job:worker/replica:0/task:1/device:TPU:3",
        "/job:worker/replica:0/task:1/device:CPU:0"}},
      {{"/job:worker/replica:0/task:0/device:TPU:1",
        "/job:worker/replica:0/task:0/device:CPU:0"},
       {"/job:worker/replica:0/task:1/device:TPU:2",
        "/job:worker/replica:0/task:1/device:CPU:0"}},
      {{"/job:worker/replica:0/task:0/device:TPU:3",
        "/job:worker/replica:0/task:0/device:CPU:0"},
       {"/job:worker/replica:0/task:1/device:TPU:0",
        "/job:worker/replica:0/task:1/device:CPU:0"}},
      {{"/job:worker/replica:0/task:0/device:TPU:2",
        "/job:worker/replica:0/task:0/device:CPU:0"},
       {"/job:worker/replica:0/task:1/device:TPU:1",
        "/job:worker/replica:0/task:1/device:CPU:0"}},
  };
  for (int replica = 0; replica < 4; ++replica) {
    for (int core = 0; core < 2; ++core) {
      EXPECT_EQ(tpu_devices[replica][core].device,
                kExpected[replica][core].device);
      EXPECT_EQ(tpu_devices[replica][core].host,
                kExpected[replica][core].host);
    }
  }

  const auto& xla_device_assignment = assignment.xla_device_assignment;
  ASSERT_TRUE(xla_device_assignment.has_value());
  EXPECT_EQ(xla_device_assignment->replica_count(), 4);
  EXPECT_EQ(xla_device_assignment->computation_count(), 2);
  ASSERT_EQ(xla_device_assignment->computation_devices_size(), 2);
  for (int computation = 0; computation < 2; ++computation) {
    ASSERT_EQ(xla_device_assignment->computation_devices(computation)
                  .replica_device_ids_size(),
              4);
  }

  // Expected device ids, indexed as [computation][replica].
  const int kExpectedIds[2][4] = {{0, 4, 2, 6}, {1, 5, 3, 7}};
  for (int computation = 0; computation < 2; ++computation) {
    const auto& computation_device =
        xla_device_assignment->computation_devices(computation);
    for (int replica = 0; replica < 4; ++replica) {
      EXPECT_EQ(computation_device.replica_device_ids(replica),
                kExpectedIds[computation][replica]);
    }
  }
}
496
// Explicit topology and device assignment on a (1, 2, 1, 3) mesh with three
// tasks of two TPUs each, two replicas and three cores per replica. Checks the
// resulting TPU/host placement and the XLA device assignment ids.
TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh1x2x1x3) {
  tpu::TopologyProto topology;
  for (int mesh_dim : {1, 2, 1, 3}) topology.add_mesh_shape(mesh_dim);
  topology.set_num_tasks(3);
  topology.set_num_tpu_devices_per_task(2);
  // Flat list of 4-component device coordinates, one group per device.
  for (int coordinate : {0, 0, 0, 0,    //
                         0, 1, 0, 0,    //
                         0, 1, 0, 1,    //
                         0, 0, 0, 1,    //
                         0, 0, 0, 2,    //
                         0, 1, 0, 2}) {
    topology.add_device_coordinates(coordinate);
  }
  const std::string topology_attr = topology.SerializeAsString();

  // Flat list of 4-component coordinates, one group per (replica, core).
  const std::vector<int64_t> device_assignment_attr{
      0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 2,  //
      0, 1, 0, 2, 0, 0, 0, 0, 0, 1, 0, 0};

  const std::vector<std::string> device_names =
      MakeDeviceSet(/*num_tasks=*/3, /*num_devices_per_task=*/2);
  llvm::SmallVector<Device, 8> parsed_devices;
  ASSERT_TRUE(DeviceNamesToParsedNames(device_names, &parsed_devices));

  auto status_or = GetTPUCompilationAndExecutionDevices(
      parsed_devices, /*num_replicas=*/2, /*num_cores_per_replica=*/3,
      topology_attr, device_assignment_attr);
  TF_ASSERT_OK(status_or.status());

  const auto& assignment = status_or.ValueOrDie();
  EXPECT_EQ(assignment.compilation_device,
            "/job:worker/replica:0/task:0/device:CPU:0");

  const auto& tpu_devices = assignment.tpu_devices;
  ASSERT_EQ(tpu_devices.size(), 2);
  for (const auto& replica_devices : tpu_devices)
    ASSERT_EQ(replica_devices.size(), 3);

  // Expected TPU device and host, indexed as [replica][core].
  struct ExpectedPlacement {
    const char* device;
    const char* host;
  };
  const ExpectedPlacement kExpected[2][3] = {
      {{"/job:worker/replica:0/task:1/device:TPU:1",
        "/job:worker/replica:0/task:1/device:CPU:0"},
       {"/job:worker/replica:0/task:1/device:TPU:0",
        "/job:worker/replica:0/task:1/device:CPU:0"},
       {"/job:worker/replica:0/task:2/device:TPU:0",
        "/job:worker/replica:0/task:2/device:CPU:0"}},
      {{"/job:worker/replica:0/task:2/device:TPU:1",
        "/job:worker/replica:0/task:2/device:CPU:0"},
       {"/job:worker/replica:0/task:0/device:TPU:0",
        "/job:worker/replica:0/task:0/device:CPU:0"},
       {"/job:worker/replica:0/task:0/device:TPU:1",
        "/job:worker/replica:0/task:0/device:CPU:0"}},
  };
  for (int replica = 0; replica < 2; ++replica) {
    for (int core = 0; core < 3; ++core) {
      EXPECT_EQ(tpu_devices[replica][core].device,
                kExpected[replica][core].device);
      EXPECT_EQ(tpu_devices[replica][core].host,
                kExpected[replica][core].host);
    }
  }

  const auto& xla_device_assignment = assignment.xla_device_assignment;
  ASSERT_TRUE(xla_device_assignment.has_value());
  EXPECT_EQ(xla_device_assignment->replica_count(), 2);
  EXPECT_EQ(xla_device_assignment->computation_count(), 3);
  ASSERT_EQ(xla_device_assignment->computation_devices_size(), 3);
  for (int computation = 0; computation < 3; ++computation) {
    ASSERT_EQ(xla_device_assignment->computation_devices(computation)
                  .replica_device_ids_size(),
              2);
  }

  // Expected device ids, indexed as [computation][replica].
  const int kExpectedIds[3][2] = {{1, 5}, {4, 0}, {2, 3}};
  for (int computation = 0; computation < 3; ++computation) {
    const auto& computation_device =
        xla_device_assignment->computation_devices(computation);
    for (int replica = 0; replica < 2; ++replica) {
      EXPECT_EQ(computation_device.replica_device_ids(replica),
                kExpectedIds[computation][replica]);
    }
  }
}
603
// GetDeviceCoordinates must return the integer values of an I64 array
// attribute in their original order.
TEST(TPURewriteDeviceUtilTest, TestGetDeviceCoordinates) {
  mlir::MLIRContext context;
  mlir::Builder builder(&context);
  auto device_assignment_attr = builder.getI64ArrayAttr({1, 2, 3});

  auto status_or_device_coordinates =
      GetDeviceCoordinates(device_assignment_attr);
  // TF_ASSERT_OK reports the underlying status on failure.
  TF_ASSERT_OK(status_or_device_coordinates.status());

  auto device_coordinates = status_or_device_coordinates.value();
  // Guard the element accesses below: a short result should fail the test
  // cleanly rather than index out of bounds.
  ASSERT_EQ(device_coordinates.size(), 3);
  EXPECT_EQ(device_coordinates[0], 1);
  EXPECT_EQ(device_coordinates[1], 2);
  EXPECT_EQ(device_coordinates[2], 3);
}
616
// GetDeviceCoordinates must reject an array attribute whose elements are not
// integers and name the first offending index in the error.
TEST(TPURewriteDeviceUtilTest, TestInvalidAttrForDeviceAssignmentDisallowed) {
  mlir::MLIRContext context;
  mlir::Builder builder(&context);
  auto float_array_attr = builder.getF32ArrayAttr({1.0, 2.0, 3.0});

  auto result = GetDeviceCoordinates(float_array_attr);

  ASSERT_FALSE(result.ok());
  EXPECT_EQ(result.status().error_message(),
            "bad 'device_assignment' attribute at index 0, not an int");
}
627
// HasModelParallelism must be false for a cluster whose
// kNumCoresPerReplicaAttr is 1.
TEST(TPURewriteDeviceUtilTest, TestHasModelParallelismFalse) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto loc = mlir::UnknownLoc::get(&context);
  mlir::OwningOpRef<mlir::ModuleOp> module = mlir::ModuleOp::create(loc);
  mlir::OpBuilder builder(module->getBodyRegion());

  // Result-less cluster op carrying the three TPU attributes.
  llvm::SmallVector<mlir::Type, 8> no_result_types;
  auto cluster =
      builder.create<mlir::tf_device::ClusterOp>(loc, no_result_types);
  auto num_cores_per_replica =
      builder.getIntegerAttr(builder.getIntegerType(64), 1);
  cluster->setAttr(kNumCoresPerReplicaAttr, num_cores_per_replica);
  cluster->setAttr(kTopologyAttr, builder.getStringAttr(""));
  cluster->setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));

  EXPECT_FALSE(HasModelParallelism(cluster));
}
645
// HasModelParallelism must be true for a cluster whose
// kNumCoresPerReplicaAttr is 5.
TEST(TPURewriteDeviceUtilTest, TestHasModelParallelismTrue) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto loc = mlir::UnknownLoc::get(&context);
  mlir::OwningOpRef<mlir::ModuleOp> module = mlir::ModuleOp::create(loc);
  mlir::OpBuilder builder(module->getBodyRegion());

  // Result-less cluster op carrying the three TPU attributes.
  llvm::SmallVector<mlir::Type, 8> no_result_types;
  auto cluster =
      builder.create<mlir::tf_device::ClusterOp>(loc, no_result_types);
  auto num_cores_per_replica =
      builder.getIntegerAttr(builder.getIntegerType(64), 5);
  cluster->setAttr(kNumCoresPerReplicaAttr, num_cores_per_replica);
  cluster->setAttr(kTopologyAttr, builder.getStringAttr(""));
  cluster->setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));

  EXPECT_TRUE(HasModelParallelism(cluster));
}
663
// HasModelParallelism must be false for a cluster that carries no
// kNumCoresPerReplicaAttr at all.
TEST(TPURewriteDeviceUtilTest,
     TestHasModelParallelismFalseMissingCoresPerReplicaAttr) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  mlir::OwningOpRef<mlir::ModuleOp> module_ref =
      mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
  mlir::OpBuilder builder(module_ref->getBodyRegion());

  llvm::SmallVector<mlir::Type, 8> result_types;
  auto cluster = builder.create<mlir::tf_device::ClusterOp>(
      mlir::UnknownLoc::get(&context), result_types);
  // kNumCoresPerReplicaAttr is deliberately left unset: setting it to 1 would
  // only repeat TestHasModelParallelismFalse and never reach the
  // missing-attribute case this test is named for.
  // NOTE(review): assumes HasModelParallelism handles an absent attribute
  // without asserting; confirm against its implementation.
  cluster->setAttr(kTopologyAttr, builder.getStringAttr(""));
  cluster->setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));

  EXPECT_FALSE(HasModelParallelism(cluster));
}
682
// GetHostDeviceOutsideComputation must fail for a cluster that has only
// kDeviceAssignmentAttr set (no kNumCoresPerReplicaAttr, no kTopologyAttr).
TEST(TPURewriteDeviceUtilTest,
     TestGetHostFailNumCoresPerReplicaMissingAttributes) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto loc = mlir::UnknownLoc::get(&context);
  mlir::OwningOpRef<mlir::ModuleOp> module = mlir::ModuleOp::create(loc);
  mlir::OpBuilder builder(module->getBodyRegion());

  llvm::SmallVector<mlir::Type, 8> no_result_types;
  auto cluster =
      builder.create<mlir::tf_device::ClusterOp>(loc, no_result_types);
  cluster->setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));

  mlir::TF::RuntimeDevices runtime_devices;
  std::string host_device;
  const auto result =
      GetHostDeviceOutsideComputation(runtime_devices, cluster, &host_device);
  EXPECT_TRUE(mlir::failed(result));
}
700
// GetHostDeviceOutsideComputation must fail for a cluster that has only
// kNumCoresPerReplicaAttr set (no kTopologyAttr, no kDeviceAssignmentAttr).
TEST(TPURewriteDeviceUtilTest, TestGetHostFailDeviceMissingAttributes) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto loc = mlir::UnknownLoc::get(&context);
  mlir::OwningOpRef<mlir::ModuleOp> module = mlir::ModuleOp::create(loc);
  mlir::OpBuilder builder(module->getBodyRegion());

  llvm::SmallVector<mlir::Type, 8> no_result_types;
  auto cluster =
      builder.create<mlir::tf_device::ClusterOp>(loc, no_result_types);
  auto num_cores_per_replica =
      builder.getIntegerAttr(builder.getIntegerType(64), 1);
  cluster->setAttr(kNumCoresPerReplicaAttr, num_cores_per_replica);

  mlir::TF::RuntimeDevices runtime_devices;
  std::string host_device;
  const auto result =
      GetHostDeviceOutsideComputation(runtime_devices, cluster, &host_device);
  EXPECT_TRUE(mlir::failed(result));
}
718
// Checks that GetHostDeviceOutsideComputation reports failure when the cluster
// has kNumCoresPerReplicaAttr and kDeviceAssignmentAttr but no kTopologyAttr.
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailMissingTopology) {
  mlir::MLIRContext ctx;
  ctx.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto loc = mlir::UnknownLoc::get(&ctx);
  mlir::OwningOpRef<mlir::ModuleOp> module = mlir::ModuleOp::create(loc);
  mlir::OpBuilder op_builder(module->getBodyRegion());

  llvm::SmallVector<mlir::Type, 8> no_result_types;
  auto cluster_op =
      op_builder.create<mlir::tf_device::ClusterOp>(loc, no_result_types);
  // 64-bit integer attribute with value 1.
  const auto one_core =
      op_builder.getIntegerAttr(op_builder.getIntegerType(64), 1);
  cluster_op->setAttr(kNumCoresPerReplicaAttr, one_core);
  cluster_op->setAttr(kDeviceAssignmentAttr, op_builder.getArrayAttr({}));

  // No runtime devices are registered for this lookup.
  mlir::TF::RuntimeDevices no_devices;
  std::string host;
  const mlir::LogicalResult status =
      GetHostDeviceOutsideComputation(no_devices, cluster_op, &host);
  EXPECT_TRUE(mlir::failed(status));
}
738
// Checks that GetHostDeviceOutsideComputation reports failure when the cluster
// has kNumCoresPerReplicaAttr and kTopologyAttr but no kDeviceAssignmentAttr.
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailMissingDeviceAssignment) {
  mlir::MLIRContext ctx;
  ctx.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto loc = mlir::UnknownLoc::get(&ctx);
  mlir::OwningOpRef<mlir::ModuleOp> module = mlir::ModuleOp::create(loc);
  mlir::OpBuilder op_builder(module->getBodyRegion());

  llvm::SmallVector<mlir::Type, 8> no_result_types;
  auto cluster_op =
      op_builder.create<mlir::tf_device::ClusterOp>(loc, no_result_types);
  // 64-bit integer attribute with value 1.
  const auto one_core =
      op_builder.getIntegerAttr(op_builder.getIntegerType(64), 1);
  cluster_op->setAttr(kNumCoresPerReplicaAttr, one_core);
  // The topology attribute is present but holds an empty string.
  cluster_op->setAttr(kTopologyAttr, op_builder.getStringAttr(""));

  // No runtime devices are registered for this lookup.
  mlir::TF::RuntimeDevices no_devices;
  std::string host;
  const mlir::LogicalResult status =
      GetHostDeviceOutsideComputation(no_devices, cluster_op, &host);
  EXPECT_TRUE(mlir::failed(status));
}
758
// Checks that GetHostDeviceOutsideComputation reports failure when the
// cluster's kDeviceAssignmentAttr is an array holding a single string entry.
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailBadDeviceAssignment) {
  mlir::MLIRContext ctx;
  ctx.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto loc = mlir::UnknownLoc::get(&ctx);
  mlir::OwningOpRef<mlir::ModuleOp> module = mlir::ModuleOp::create(loc);
  mlir::OpBuilder op_builder(module->getBodyRegion());

  llvm::SmallVector<mlir::Type, 8> no_result_types;
  auto cluster_op =
      op_builder.create<mlir::tf_device::ClusterOp>(loc, no_result_types);
  // 64-bit integer attribute with value 1.
  const auto one_core =
      op_builder.getIntegerAttr(op_builder.getIntegerType(64), 1);
  cluster_op->setAttr(kNumCoresPerReplicaAttr, one_core);
  cluster_op->setAttr(kTopologyAttr, op_builder.getStringAttr(""));

  // The device assignment is populated with one string element.
  const llvm::StringRef assignment_entries[] = {"bad_device_assigment"};
  cluster_op->setAttr(kDeviceAssignmentAttr,
                      op_builder.getStrArrayAttr(assignment_entries));

  // No runtime devices are registered for this lookup.
  mlir::TF::RuntimeDevices no_devices;
  std::string host;
  const mlir::LogicalResult status =
      GetHostDeviceOutsideComputation(no_devices, cluster_op, &host);
  EXPECT_TRUE(mlir::failed(status));
}
781
// Checks that GetHostDeviceOutsideComputation reports failure when the
// module's "tf.devices" attribute lists only the string "bad_device_name".
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailBadDeviceName) {
  mlir::MLIRContext ctx;
  ctx.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto loc = mlir::UnknownLoc::get(&ctx);
  mlir::OwningOpRef<mlir::ModuleOp> module = mlir::ModuleOp::create(loc);
  mlir::OpBuilder op_builder(module->getBodyRegion());

  // Attach the device list to the module itself.
  const llvm::StringRef module_devices[] = {"bad_device_name"};
  (*module)->setAttr("tf.devices", op_builder.getStrArrayAttr(module_devices));

  llvm::SmallVector<mlir::Type, 8> no_result_types;
  auto cluster_op =
      op_builder.create<mlir::tf_device::ClusterOp>(loc, no_result_types);
  // 64-bit integer attribute with value 1.
  const auto one_core =
      op_builder.getIntegerAttr(op_builder.getIntegerType(64), 1);
  cluster_op->setAttr(kNumCoresPerReplicaAttr, one_core);
  cluster_op->setAttr(kTopologyAttr, op_builder.getStringAttr(""));
  cluster_op->setAttr(kDeviceAssignmentAttr, op_builder.getArrayAttr({}));

  // The result of GetDevicesFromOp is deliberately discarded; only the host
  // device lookup below is checked.
  mlir::TF::RuntimeDevices module_runtime_devices;
  (void)GetDevicesFromOp(*module, &module_runtime_devices);

  std::string host;
  const mlir::LogicalResult status = GetHostDeviceOutsideComputation(
      module_runtime_devices, cluster_op, &host);
  EXPECT_TRUE(mlir::failed(status));
}
807
// Checks that for a cluster created inside a tf_device.replicate body,
// GetHostDeviceOutsideComputation succeeds and yields kTPUReplicatedHost.
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceTPUReplicate) {
  mlir::MLIRContext ctx;
  ctx.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto loc = mlir::UnknownLoc::get(&ctx);
  mlir::OwningOpRef<mlir::ModuleOp> module = mlir::ModuleOp::create(loc);
  mlir::OpBuilder op_builder(module->getBodyRegion());

  // Build a replicate op with two replicas, an empty device map, and no
  // operands or result types.
  llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<llvm::StringRef, 4>>
      replicate_devices;
  auto replicate_op = op_builder.create<mlir::tf_device::ReplicateOp>(
      loc, /*num_replicas=*/2, replicate_devices,
      llvm::ArrayRef<std::pair<mlir::ValueRange, mlir::Type>>{},
      mlir::ValueRange{}, mlir::TypeRange{});

  // Subsequent ops are inserted at the start of the replicate body's first
  // block.
  mlir::Block& replicate_block = replicate_op.body().front();
  op_builder.setInsertionPoint(&replicate_block, replicate_block.begin());

  // The cluster itself carries no attributes in this test.
  llvm::SmallVector<mlir::Type, 8> no_result_types;
  auto cluster_op =
      op_builder.create<mlir::tf_device::ClusterOp>(loc, no_result_types);

  mlir::TF::RuntimeDevices no_devices;
  std::string host;
  const mlir::LogicalResult status =
      GetHostDeviceOutsideComputation(no_devices, cluster_op, &host);
  EXPECT_TRUE(mlir::succeeded(status));
  EXPECT_EQ(host, kTPUReplicatedHost);
}
834
// Checks that for a cluster that is not nested in a replicate op,
// GetHostDeviceOutsideComputation succeeds and returns
// "/job:localhost/replica:0/task:0/device:CPU:0" given the module-level
// "tf.devices" list set up below.
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceNotReplicated) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  mlir::OwningOpRef<mlir::ModuleOp> module_ref =
      mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
  mlir::OpBuilder builder(module_ref->getBodyRegion());
  // Device list attached to the module: a TPU_SYSTEM device and a TPU device
  // on job "localhost", plus a CPU device on job "worker".
  (*module_ref)
      ->setAttr("tf.devices",
                builder.getStrArrayAttr(llvm::ArrayRef<llvm::StringRef>(
                    {"/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0",
                     "/job:localhost/replica:0/task:0/device:TPU:0",
                     "/job:worker/replica:0/task:0/device:CPU:0"})));

  // Cluster with one core per replica, an empty topology string and an empty
  // device assignment array.
  llvm::SmallVector<mlir::Type, 8> result_types;
  auto cluster = builder.create<mlir::tf_device::ClusterOp>(
      mlir::UnknownLoc::get(&context), result_types);
  cluster->setAttr(kNumCoresPerReplicaAttr,
                   builder.getIntegerAttr(builder.getIntegerType(64), 1));
  cluster->setAttr(kTopologyAttr, builder.getStringAttr(""));
  cluster->setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));

  // Loading the runtime devices is a precondition for the lookup below, so a
  // failure here aborts the test instead of being silently discarded.
  mlir::TF::RuntimeDevices runtime_devices;
  ASSERT_TRUE(
      mlir::succeeded(GetDevicesFromOp(*module_ref, &runtime_devices)));

  std::string host_device;
  EXPECT_TRUE(mlir::succeeded(
      GetHostDeviceOutsideComputation(runtime_devices, cluster, &host_device)));
  EXPECT_EQ(host_device, "/job:localhost/replica:0/task:0/device:CPU:0");
}
863
// Checks IsTPUDevice against a TPU device name, a CPU device name, and a
// string that is not in device-name form.
TEST(TPURewriteDeviceUtilTest, TestIsTPUDevice) {
  const bool tpu_name_result =
      IsTPUDevice("/job:localhost/replica:0/task:0/device:TPU:0");
  EXPECT_TRUE(tpu_name_result);

  const bool cpu_name_result =
      IsTPUDevice("/job:localhost/replica:0/task:0/device:CPU:0");
  EXPECT_FALSE(cpu_name_result);

  const bool invalid_name_result = IsTPUDevice("INVALID_DEVICE");
  EXPECT_FALSE(invalid_name_result);
}
869
870 } // anonymous namespace
871 } // namespace tensorflow
872