1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
17 
18 #include <cstdint>
19 #include <tuple>
20 
21 #include "llvm/Support/FormatVariadic.h"
22 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
23 #include "mlir/IR/Builders.h"  // from @llvm-project
24 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
25 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
26 #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
27 #include "tensorflow/core/lib/core/status_test_util.h"
28 #include "tensorflow/core/platform/test.h"
29 #include "tensorflow/core/protobuf/tpu/topology.pb.h"
30 #include "tensorflow/core/util/device_name_utils.h"
31 
32 namespace tensorflow {
33 namespace {
34 
35 using Device = DeviceNameUtils::ParsedName;
36 
DeviceNamesToParsedNames(llvm::ArrayRef<std::string> device_names,llvm::SmallVectorImpl<Device> * parsed_devices)37 bool DeviceNamesToParsedNames(llvm::ArrayRef<std::string> device_names,
38                               llvm::SmallVectorImpl<Device>* parsed_devices) {
39   parsed_devices->reserve(device_names.size());
40   for (const auto& device_name : device_names) {
41     Device parsed_name;
42     if (!DeviceNameUtils::ParseFullName(device_name, &parsed_name))
43       return false;
44 
45     parsed_devices->push_back(parsed_name);
46   }
47   return true;
48 }
49 
50 using DeviceNames = llvm::SmallVector<std::string, 8>;
51 
52 struct ParameterizedDeviceSetTest
53     : ::testing::TestWithParam<std::tuple<DeviceNames, std::string>> {};
54 
// Checks that GetTPUCompilationAndExecutionDevices fails on the parameter's
// device set and reports exactly the parameter's error message. Topology and
// device assignment are left empty; one replica with one core is requested.
TEST_P(ParameterizedDeviceSetTest, BadDeviceSet) {
  const DeviceNames& device_names = std::get<0>(GetParam());
  const std::string& expected_error = std::get<1>(GetParam());

  llvm::SmallVector<Device, 8> parsed_devices;
  ASSERT_TRUE(DeviceNamesToParsedNames(device_names, &parsed_devices));

  const std::string empty_topology;
  const std::vector<int64_t> empty_device_assignment;
  auto result = GetTPUCompilationAndExecutionDevices(
      parsed_devices, /*num_replicas=*/1, /*num_cores_per_replica=*/1,
      empty_topology, empty_device_assignment);

  ASSERT_FALSE(result.ok());
  EXPECT_EQ(result.status().error_message(), expected_error);
}
67 
68 INSTANTIATE_TEST_SUITE_P(
69     BadDeviceSet, ParameterizedDeviceSetTest,
70     ::testing::Values(
71         std::make_tuple<DeviceNames, std::string>(
72             {"/job:localhost/replica:0/task:0/device:CPU:0"},
73             "no TPU_SYSTEM devices found"),
74         std::make_tuple<DeviceNames, std::string>(
75             {"/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0",
76              "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0"},
77             "found TPU_SYSTEM devices with conflicting jobs 'localhost' and "
78             "'worker'"),
79         std::make_tuple<DeviceNames, std::string>(
80             {"/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0",
81              "/job:localhost/replica:1/task:0/device:TPU_SYSTEM:0"},
82             "found TPU_SYSTEM devices with conflicting replicas '0' and '1'"),
83         std::make_tuple<DeviceNames, std::string>(
84             {"/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0",
85              "/job:localhost/replica:0/task:0/device:TPU:0",
86              "/job:localhost/replica:0/task:0/device:TPU:1",
87              "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0",
88              "/job:localhost/replica:0/task:1/device:TPU:0"},
89             "expected the number of TPU devices per host to be 2, got 1")));
90 
91 struct ParameterizedMetadataTest
92     : ::testing::TestWithParam<std::tuple<int, int, std::string,
93                                           std::vector<int64_t>, std::string>> {
94 };
95 
// Checks that GetTPUCompilationAndExecutionDevices rejects the parameter's
// replica/core counts, topology and device assignment with exactly the
// parameter's error message. The device set is fixed: job "worker" with two
// tasks, each exposing one TPU_SYSTEM device and one TPU device.
TEST_P(ParameterizedMetadataTest, BadMetadata) {
  llvm::SmallVector<Device, 8> devices;
  ASSERT_TRUE(DeviceNamesToParsedNames(
      {"/job:worker/replica:0/task:0/device:TPU_SYSTEM:0",
       "/job:worker/replica:0/task:0/device:TPU:0",
       "/job:worker/replica:0/task:1/device:TPU_SYSTEM:0",
       "/job:worker/replica:0/task:1/device:TPU:0"},
      &devices));

  const int num_replicas = std::get<0>(GetParam());
  const int num_cores_per_replica = std::get<1>(GetParam());
  const std::string& topology_attr = std::get<2>(GetParam());
  const std::vector<int64_t>& device_assignment_attr = std::get<3>(GetParam());
  const std::string& expected_error = std::get<4>(GetParam());

  auto status_or = GetTPUCompilationAndExecutionDevices(
      devices, num_replicas, num_cores_per_replica, topology_attr,
      device_assignment_attr);
  ASSERT_FALSE(status_or.ok());
  EXPECT_EQ(status_or.status().error_message(), expected_error);
}
114 
TopologyWithMeshShape(llvm::ArrayRef<int> mesh_shape)115 std::string TopologyWithMeshShape(llvm::ArrayRef<int> mesh_shape) {
116   tpu::TopologyProto topology_proto;
117   for (int mesh_dim : mesh_shape) topology_proto.add_mesh_shape(mesh_dim);
118   return topology_proto.SerializeAsString();
119 }
120 
TopologyWithMeshShapeAndTasks(llvm::ArrayRef<int> mesh_shape,int num_tasks,int num_tpu_devices_per_task)121 std::string TopologyWithMeshShapeAndTasks(llvm::ArrayRef<int> mesh_shape,
122                                           int num_tasks,
123                                           int num_tpu_devices_per_task) {
124   tpu::TopologyProto topology_proto;
125   for (int mesh_dim : mesh_shape) topology_proto.add_mesh_shape(mesh_dim);
126   topology_proto.set_num_tasks(num_tasks);
127   topology_proto.set_num_tpu_devices_per_task(num_tpu_devices_per_task);
128   return topology_proto.SerializeAsString();
129 }
130 
TopologyWithDeviceCoordinates(llvm::ArrayRef<int> device_coordinates)131 std::string TopologyWithDeviceCoordinates(
132     llvm::ArrayRef<int> device_coordinates) {
133   tpu::TopologyProto topology_proto;
134   topology_proto.add_mesh_shape(2);
135   topology_proto.add_mesh_shape(1);
136   topology_proto.add_mesh_shape(1);
137   topology_proto.add_mesh_shape(1);
138   topology_proto.set_num_tasks(2);
139   topology_proto.set_num_tpu_devices_per_task(1);
140   for (int device_coordinate : device_coordinates)
141     topology_proto.add_device_coordinates(device_coordinate);
142   return topology_proto.SerializeAsString();
143 }
144 
145 INSTANTIATE_TEST_SUITE_P(
146     BadFullMeshMetadata, ParameterizedMetadataTest,
147     ::testing::Values(
148         std::make_tuple(
149             2, 1, "", std::vector<int64_t>{0},
150             "'device_assignment' must not be set when 'topology' is not set"),
151         std::make_tuple(8, 1, "", std::vector<int64_t>(),
152                         "'num_replicas' must be equal to 1 or 2, got 8"),
153         std::make_tuple(2, 2, "", std::vector<int64_t>(),
154                         "'num_cores_per_replica' must be equal to 1, got 2")));
155 
156 INSTANTIATE_TEST_SUITE_P(
157     BadGeneralTopologyMetadata, ParameterizedMetadataTest,
158     ::testing::Values(
159         std::make_tuple(
160             2, 1, "BAD_TOPOLOGY", std::vector<int64_t>(),
161             "failed to parse 'topology' attribute to TopologyProto"),
162         std::make_tuple(4, 2, TopologyWithMeshShape({0}),
163                         std::vector<int64_t>(),
164                         "'topology' 'mesh_shape' must be rank 4, got rank 1"),
165         std::make_tuple(
166             2, 1, TopologyWithMeshShape({2, 0, 1, 2}), std::vector<int64_t>(),
167             "'topology' 'mesh_shape' dimension 1 must be positive, got 0"),
168         std::make_tuple(2, 1, TopologyWithMeshShapeAndTasks({1, 1, 1, 1}, 1, 1),
169                         std::vector<int64_t>(),
170                         "number of tasks from available TPU devices must be "
171                         "'num_tasks' in 'topology' (1), got 2"),
172         std::make_tuple(2, 1, TopologyWithMeshShapeAndTasks({1, 1, 1, 1}, 2, 2),
173                         std::vector<int64_t>(),
174                         "number of TPU devices available per task must be "
175                         "'num_tpu_devices_per_task' in 'topology' (2), got 1"),
176         std::make_tuple(
177             2, 1, TopologyWithDeviceCoordinates({}), std::vector<int64_t>(),
178             "length of 'device_coordinates' in 'topology' must be 'num_tasks' "
179             "* 'num_tpus_per_task' * 4 (2 * 1 * 4), got 0"),
180         std::make_tuple(
181             2, 1, TopologyWithDeviceCoordinates({-1, 0, 0, 0, 1, 0, 0, 0}),
182             std::vector<int64_t>(),
183             "device coordinate (-1, 0, 0, 0) in 'topology' is outside "
184             "of mesh shape (2, 1, 1, 1)"),
185         std::make_tuple(
186             2, 1, TopologyWithDeviceCoordinates({2, 0, 0, 0, 1, 0, 0, 0}),
187             std::vector<int64_t>(),
188             "device coordinate (2, 0, 0, 0) in 'topology' is outside "
189             "of mesh shape (2, 1, 1, 1)"),
190         std::make_tuple(
191             2, 1, TopologyWithDeviceCoordinates({0, -1, 0, 0, 1, 0, 0, 0}),
192             std::vector<int64_t>(),
193             "device coordinate (0, -1, 0, 0) in 'topology' is outside "
194             "of mesh shape (2, 1, 1, 1)"),
195         std::make_tuple(
196             2, 1, TopologyWithDeviceCoordinates({0, 1, 0, 0, 1, 0, 0, 0}),
197             std::vector<int64_t>(),
198             "device coordinate (0, 1, 0, 0) in 'topology' is outside "
199             "of mesh shape (2, 1, 1, 1)"),
200         std::make_tuple(
201             2, 1, TopologyWithDeviceCoordinates({0, 0, 0, -1, 1, 0, 0, 0}),
202             std::vector<int64_t>(),
203             "device coordinate (0, 0, 0, -1) in 'topology' is outside "
204             "of mesh shape (2, 1, 1, 1)"),
205         std::make_tuple(
206             2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 1, 1, 0, 0, 0}),
207             std::vector<int64_t>(),
208             "device coordinate (0, 0, 0, 1) in 'topology' is outside "
209             "of mesh shape (2, 1, 1, 1)"),
210         std::make_tuple(
211             2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 0, 0, 0, 0}),
212             std::vector<int64_t>(),
213             "'topology' has duplicate device coordinate (0, 0, 0, 0)")));
214 
215 INSTANTIATE_TEST_SUITE_P(
216     BadGeneralDeviceAssignmentMetadata, ParameterizedMetadataTest,
217     ::testing::Values(
218         std::make_tuple(2, 1,
219                         TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
220                         std::vector<int64_t>(),
221                         "length of 'device_assignment' must be 'num_replicas' "
222                         "* 'num_cores_per_replica' * 4 (2 * 1 * 4), got 0"),
223         std::make_tuple(
224             2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
225             std::vector<int64_t>{-1, 0, 0, 0, 0, 0, 0, 0},
226             "device coordinate (-1, 0, 0, 0) in 'device_assignment' "
227             "is outside of mesh shape (2, 1, 1, 1)"),
228         std::make_tuple(
229             2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
230             std::vector<int64_t>{2, 0, 0, 0, 0, 0, 0, 0},
231             "device coordinate (2, 0, 0, 0) in 'device_assignment' is "
232             "outside of mesh shape (2, 1, 1, 1)"),
233         std::make_tuple(
234             2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
235             std::vector<int64_t>{0, -1, 0, 0, 0, 0, 0, 0},
236             "device coordinate (0, -1, 0, 0) in 'device_assignment' "
237             "is outside of mesh shape (2, 1, 1, 1)"),
238         std::make_tuple(
239             2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
240             std::vector<int64_t>{0, 1, 0, 0, 0, 0, 0, 0},
241             "device coordinate (0, 1, 0, 0) in 'device_assignment' is "
242             "outside of mesh shape (2, 1, 1, 1)"),
243         std::make_tuple(
244             2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
245             std::vector<int64_t>{0, 0, 0, -1, 0, 0, 0, 0},
246             "device coordinate (0, 0, 0, -1) in 'device_assignment' "
247             "is outside of mesh shape (2, 1, 1, 1)"),
248         std::make_tuple(
249             2, 1, TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
250             std::vector<int64_t>{0, 0, 0, 1, 0, 0, 0, 0},
251             "device coordinate (0, 0, 0, 1) in 'device_assignment' is "
252             "outside of mesh shape (2, 1, 1, 1)"),
253         std::make_tuple(2, 1,
254                         TopologyWithDeviceCoordinates({0, 0, 0, 0, 1, 0, 0, 0}),
255                         std::vector<int64_t>{0, 0, 0, 0, 0, 0, 0, 0},
256                         "'device_assignment' has duplicate device coordinate "
257                         "(0, 0, 0, 0)")));
258 
MakeDeviceSet(int num_tasks,int num_devices_per_task)259 std::vector<std::string> MakeDeviceSet(int num_tasks,
260                                        int num_devices_per_task) {
261   std::vector<std::string> devices{
262       "/job:localhost/replica:0/task:0/device:CPU:0"};
263   devices.reserve(num_tasks * num_devices_per_task + num_tasks + 1);
264 
265   for (int task = 0; task < num_tasks; ++task) {
266     devices.push_back(
267         llvm::formatv("/job:worker/replica:0/task:{0}/device:CPU:0", task)
268             .str());
269     devices.push_back(
270         llvm::formatv("/job:worker/replica:0/task:{0}/device:TPU_SYSTEM:0",
271                       task)
272             .str());
273     for (int device = 0; device < num_devices_per_task; ++device)
274       devices.push_back(
275           llvm::formatv("/job:worker/replica:0/task:{0}/device:TPU:{1}", task,
276                         device)
277               .str());
278   }
279 
280   return devices;
281 }
282 
// The topology has mesh shape (2, 1, 1, 1) but lists only one device, at
// (0, 0, 0, 0). The device assignment asks for (1, 0, 0, 0), and the test
// expects an error saying no TPU device was found for that coordinate.
TEST(TPURewriteDeviceUtilTest,
     BadGeneralDeviceAssignmentMetadataMissingDevice) {
  tpu::TopologyProto topology;
  for (const int dimension : {2, 1, 1, 1}) topology.add_mesh_shape(dimension);
  topology.set_num_tasks(1);
  topology.set_num_tpu_devices_per_task(1);
  // A single device coordinate: (0, 0, 0, 0).
  for (int component = 0; component < 4; ++component)
    topology.add_device_coordinates(0);
  const std::string serialized_topology = topology.SerializeAsString();

  const std::vector<int64_t> device_assignment{1, 0, 0, 0};

  const std::vector<std::string> device_names =
      MakeDeviceSet(/*num_tasks=*/1, /*num_devices_per_task=*/1);
  llvm::SmallVector<Device, 8> parsed_devices;
  ASSERT_TRUE(DeviceNamesToParsedNames(device_names, &parsed_devices));

  auto result = GetTPUCompilationAndExecutionDevices(
      parsed_devices, /*num_replicas=*/1, /*num_cores_per_replica=*/1,
      serialized_topology, device_assignment);

  ASSERT_FALSE(result.ok());
  EXPECT_EQ(result.status().error_message(),
            "no TPU device found for 'device_assignment' device coordinate (1, "
            "0, 0, 0)");
}
316 
TEST(TPURewriteDeviceUtilTest,ValidFullMeshDeviceAssignment)317 TEST(TPURewriteDeviceUtilTest, ValidFullMeshDeviceAssignment) {
318   llvm::SmallVector<Device, 8> devices;
319   std::vector<std::string> device_names =
320       MakeDeviceSet(/*num_tasks=*/2, /*num_devices_per_task=*/4);
321   ASSERT_TRUE(DeviceNamesToParsedNames(device_names, &devices));
322   std::string topology_attr;
323   std::vector<int64_t> device_assignment_attr;
324 
325   auto status_or = GetTPUCompilationAndExecutionDevices(
326       devices, /*num_replicas=*/8, /*num_cores_per_replica=*/1, topology_attr,
327       device_assignment_attr);
328 
329   TF_ASSERT_OK(status_or.status());
330 
331   const auto& tpu_device_assignment = status_or.ValueOrDie();
332   EXPECT_EQ(tpu_device_assignment.compilation_device,
333             "/job:worker/replica:0/task:0/device:CPU:0");
334   const auto& tpu_devices = tpu_device_assignment.tpu_devices;
335   ASSERT_EQ(tpu_devices.size(), 8);
336   for (const auto& replica_tpu_devices : tpu_devices)
337     ASSERT_EQ(replica_tpu_devices.size(), 1);
338 
339   EXPECT_EQ(tpu_devices[0][0].device,
340             "/job:worker/replica:0/task:0/device:TPU:0");
341   EXPECT_EQ(tpu_devices[0][0].host,
342             "/job:worker/replica:0/task:0/device:CPU:0");
343   EXPECT_EQ(tpu_devices[1][0].device,
344             "/job:worker/replica:0/task:0/device:TPU:1");
345   EXPECT_EQ(tpu_devices[1][0].host,
346             "/job:worker/replica:0/task:0/device:CPU:0");
347   EXPECT_EQ(tpu_devices[2][0].device,
348             "/job:worker/replica:0/task:0/device:TPU:2");
349   EXPECT_EQ(tpu_devices[2][0].host,
350             "/job:worker/replica:0/task:0/device:CPU:0");
351   EXPECT_EQ(tpu_devices[3][0].device,
352             "/job:worker/replica:0/task:0/device:TPU:3");
353   EXPECT_EQ(tpu_devices[3][0].host,
354             "/job:worker/replica:0/task:0/device:CPU:0");
355   EXPECT_EQ(tpu_devices[4][0].device,
356             "/job:worker/replica:0/task:1/device:TPU:0");
357   EXPECT_EQ(tpu_devices[4][0].host,
358             "/job:worker/replica:0/task:1/device:CPU:0");
359   EXPECT_EQ(tpu_devices[5][0].device,
360             "/job:worker/replica:0/task:1/device:TPU:1");
361   EXPECT_EQ(tpu_devices[5][0].host,
362             "/job:worker/replica:0/task:1/device:CPU:0");
363   EXPECT_EQ(tpu_devices[6][0].device,
364             "/job:worker/replica:0/task:1/device:TPU:2");
365   EXPECT_EQ(tpu_devices[6][0].host,
366             "/job:worker/replica:0/task:1/device:CPU:0");
367   EXPECT_EQ(tpu_devices[7][0].device,
368             "/job:worker/replica:0/task:1/device:TPU:3");
369   EXPECT_EQ(tpu_devices[7][0].host,
370             "/job:worker/replica:0/task:1/device:CPU:0");
371 
372   EXPECT_FALSE(tpu_device_assignment.xla_device_assignment.has_value());
373 }
374 
TEST(TPURewriteDeviceUtilTest,ValidGeneralDeviceAssignmentMesh2x2x2)375 TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh2x2x2) {
376   tpu::TopologyProto topology_proto;
377   {
378     topology_proto.add_mesh_shape(2);
379     topology_proto.add_mesh_shape(2);
380     topology_proto.add_mesh_shape(1);
381     topology_proto.add_mesh_shape(2);
382     topology_proto.set_num_tasks(2);
383     topology_proto.set_num_tpu_devices_per_task(4);
384     topology_proto.add_device_coordinates(0);
385     topology_proto.add_device_coordinates(0);
386     topology_proto.add_device_coordinates(0);
387     topology_proto.add_device_coordinates(0);
388     topology_proto.add_device_coordinates(0);
389     topology_proto.add_device_coordinates(1);
390     topology_proto.add_device_coordinates(0);
391     topology_proto.add_device_coordinates(0);
392     topology_proto.add_device_coordinates(1);
393     topology_proto.add_device_coordinates(1);
394     topology_proto.add_device_coordinates(0);
395     topology_proto.add_device_coordinates(0);
396     topology_proto.add_device_coordinates(1);
397     topology_proto.add_device_coordinates(0);
398     topology_proto.add_device_coordinates(0);
399     topology_proto.add_device_coordinates(0);
400     topology_proto.add_device_coordinates(1);
401     topology_proto.add_device_coordinates(0);
402     topology_proto.add_device_coordinates(0);
403     topology_proto.add_device_coordinates(1);
404     topology_proto.add_device_coordinates(1);
405     topology_proto.add_device_coordinates(1);
406     topology_proto.add_device_coordinates(0);
407     topology_proto.add_device_coordinates(1);
408     topology_proto.add_device_coordinates(0);
409     topology_proto.add_device_coordinates(1);
410     topology_proto.add_device_coordinates(0);
411     topology_proto.add_device_coordinates(1);
412     topology_proto.add_device_coordinates(0);
413     topology_proto.add_device_coordinates(0);
414     topology_proto.add_device_coordinates(0);
415     topology_proto.add_device_coordinates(1);
416   }
417 
418   std::string topology_attr = topology_proto.SerializeAsString();
419   std::vector<int64_t> device_assignment_attr{0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0,
420                                               0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0,
421                                               0, 1, 1, 1, 0, 0, 1, 1, 0, 1};
422 
423   llvm::SmallVector<Device, 8> devices;
424   std::vector<std::string> device_names =
425       MakeDeviceSet(/*num_tasks=*/2, /*num_devices_per_task=*/4);
426   ASSERT_TRUE(DeviceNamesToParsedNames(device_names, &devices));
427 
428   auto status_or = GetTPUCompilationAndExecutionDevices(
429       devices, /*num_replicas=*/4, /*num_cores_per_replica=*/2, topology_attr,
430       device_assignment_attr);
431 
432   TF_ASSERT_OK(status_or.status());
433 
434   const auto& tpu_device_assignment = status_or.ValueOrDie();
435   EXPECT_EQ(tpu_device_assignment.compilation_device,
436             "/job:worker/replica:0/task:0/device:CPU:0");
437   const auto& tpu_devices = tpu_device_assignment.tpu_devices;
438   ASSERT_EQ(tpu_devices.size(), 4);
439   for (const auto& replica_tpu_devices : tpu_devices)
440     ASSERT_EQ(replica_tpu_devices.size(), 2);
441 
442   EXPECT_EQ(tpu_devices[0][0].device,
443             "/job:worker/replica:0/task:0/device:TPU:0");
444   EXPECT_EQ(tpu_devices[0][0].host,
445             "/job:worker/replica:0/task:0/device:CPU:0");
446   EXPECT_EQ(tpu_devices[0][1].device,
447             "/job:worker/replica:0/task:1/device:TPU:3");
448   EXPECT_EQ(tpu_devices[0][1].host,
449             "/job:worker/replica:0/task:1/device:CPU:0");
450   EXPECT_EQ(tpu_devices[1][0].device,
451             "/job:worker/replica:0/task:0/device:TPU:1");
452   EXPECT_EQ(tpu_devices[1][0].host,
453             "/job:worker/replica:0/task:0/device:CPU:0");
454   EXPECT_EQ(tpu_devices[1][1].device,
455             "/job:worker/replica:0/task:1/device:TPU:2");
456   EXPECT_EQ(tpu_devices[1][1].host,
457             "/job:worker/replica:0/task:1/device:CPU:0");
458   EXPECT_EQ(tpu_devices[2][0].device,
459             "/job:worker/replica:0/task:0/device:TPU:3");
460   EXPECT_EQ(tpu_devices[2][0].host,
461             "/job:worker/replica:0/task:0/device:CPU:0");
462   EXPECT_EQ(tpu_devices[2][1].device,
463             "/job:worker/replica:0/task:1/device:TPU:0");
464   EXPECT_EQ(tpu_devices[2][1].host,
465             "/job:worker/replica:0/task:1/device:CPU:0");
466   EXPECT_EQ(tpu_devices[3][0].device,
467             "/job:worker/replica:0/task:0/device:TPU:2");
468   EXPECT_EQ(tpu_devices[3][0].host,
469             "/job:worker/replica:0/task:0/device:CPU:0");
470   EXPECT_EQ(tpu_devices[3][1].device,
471             "/job:worker/replica:0/task:1/device:TPU:1");
472   EXPECT_EQ(tpu_devices[3][1].host,
473             "/job:worker/replica:0/task:1/device:CPU:0");
474 
475   auto& xla_device_assignment = tpu_device_assignment.xla_device_assignment;
476   ASSERT_TRUE(xla_device_assignment.has_value());
477   EXPECT_EQ(xla_device_assignment->replica_count(), 4);
478   EXPECT_EQ(xla_device_assignment->computation_count(), 2);
479   ASSERT_EQ(xla_device_assignment->computation_devices_size(), 2);
480   const auto& computation_device_0 =
481       xla_device_assignment->computation_devices(0);
482   ASSERT_EQ(computation_device_0.replica_device_ids_size(), 4);
483   const auto& computation_device_1 =
484       xla_device_assignment->computation_devices(1);
485   ASSERT_EQ(computation_device_1.replica_device_ids_size(), 4);
486 
487   EXPECT_EQ(computation_device_0.replica_device_ids(0), 0);
488   EXPECT_EQ(computation_device_0.replica_device_ids(1), 4);
489   EXPECT_EQ(computation_device_0.replica_device_ids(2), 2);
490   EXPECT_EQ(computation_device_0.replica_device_ids(3), 6);
491   EXPECT_EQ(computation_device_1.replica_device_ids(0), 1);
492   EXPECT_EQ(computation_device_1.replica_device_ids(1), 5);
493   EXPECT_EQ(computation_device_1.replica_device_ids(2), 3);
494   EXPECT_EQ(computation_device_1.replica_device_ids(3), 7);
495 }
496 
TEST(TPURewriteDeviceUtilTest,ValidGeneralDeviceAssignmentMesh1x2x1x3)497 TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh1x2x1x3) {
498   tpu::TopologyProto topology_proto;
499   {
500     topology_proto.add_mesh_shape(1);
501     topology_proto.add_mesh_shape(2);
502     topology_proto.add_mesh_shape(1);
503     topology_proto.add_mesh_shape(3);
504     topology_proto.set_num_tasks(3);
505     topology_proto.set_num_tpu_devices_per_task(2);
506     topology_proto.add_device_coordinates(0);
507     topology_proto.add_device_coordinates(0);
508     topology_proto.add_device_coordinates(0);
509     topology_proto.add_device_coordinates(0);
510     topology_proto.add_device_coordinates(0);
511     topology_proto.add_device_coordinates(1);
512     topology_proto.add_device_coordinates(0);
513     topology_proto.add_device_coordinates(0);
514     topology_proto.add_device_coordinates(0);
515     topology_proto.add_device_coordinates(1);
516     topology_proto.add_device_coordinates(0);
517     topology_proto.add_device_coordinates(1);
518     topology_proto.add_device_coordinates(0);
519     topology_proto.add_device_coordinates(0);
520     topology_proto.add_device_coordinates(0);
521     topology_proto.add_device_coordinates(1);
522     topology_proto.add_device_coordinates(0);
523     topology_proto.add_device_coordinates(0);
524     topology_proto.add_device_coordinates(0);
525     topology_proto.add_device_coordinates(2);
526     topology_proto.add_device_coordinates(0);
527     topology_proto.add_device_coordinates(1);
528     topology_proto.add_device_coordinates(0);
529     topology_proto.add_device_coordinates(2);
530   }
531 
532   std::string topology_attr = topology_proto.SerializeAsString();
533   std::vector<int64_t> device_assignment_attr{
534       0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 2, 0, 1, 0, 2, 0, 0, 0, 0, 0, 1, 0, 0};
535 
536   llvm::SmallVector<Device, 8> devices;
537   std::vector<std::string> device_names =
538       MakeDeviceSet(/*num_tasks=*/3, /*num_devices_per_task=*/2);
539   ASSERT_TRUE(DeviceNamesToParsedNames(device_names, &devices));
540 
541   auto status_or = GetTPUCompilationAndExecutionDevices(
542       devices, /*num_replicas=*/2, /*num_cores_per_replica=*/3, topology_attr,
543       device_assignment_attr);
544 
545   TF_ASSERT_OK(status_or.status());
546 
547   auto& tpu_device_assignment = status_or.ValueOrDie();
548   EXPECT_EQ(tpu_device_assignment.compilation_device,
549             "/job:worker/replica:0/task:0/device:CPU:0");
550 
551   auto& tpu_devices = tpu_device_assignment.tpu_devices;
552   ASSERT_EQ(tpu_devices.size(), 2);
553   for (const auto& replica_tpu_devices : tpu_devices)
554     ASSERT_EQ(replica_tpu_devices.size(), 3);
555 
556   EXPECT_EQ(tpu_devices[0][0].device,
557             "/job:worker/replica:0/task:1/device:TPU:1");
558   EXPECT_EQ(tpu_devices[0][0].host,
559             "/job:worker/replica:0/task:1/device:CPU:0");
560   EXPECT_EQ(tpu_devices[0][1].device,
561             "/job:worker/replica:0/task:1/device:TPU:0");
562   EXPECT_EQ(tpu_devices[0][1].host,
563             "/job:worker/replica:0/task:1/device:CPU:0");
564   EXPECT_EQ(tpu_devices[0][2].device,
565             "/job:worker/replica:0/task:2/device:TPU:0");
566   EXPECT_EQ(tpu_devices[0][2].host,
567             "/job:worker/replica:0/task:2/device:CPU:0");
568   EXPECT_EQ(tpu_devices[1][0].device,
569             "/job:worker/replica:0/task:2/device:TPU:1");
570   EXPECT_EQ(tpu_devices[1][0].host,
571             "/job:worker/replica:0/task:2/device:CPU:0");
572   EXPECT_EQ(tpu_devices[1][1].device,
573             "/job:worker/replica:0/task:0/device:TPU:0");
574   EXPECT_EQ(tpu_devices[1][1].host,
575             "/job:worker/replica:0/task:0/device:CPU:0");
576   EXPECT_EQ(tpu_devices[1][2].device,
577             "/job:worker/replica:0/task:0/device:TPU:1");
578   EXPECT_EQ(tpu_devices[1][2].host,
579             "/job:worker/replica:0/task:0/device:CPU:0");
580 
581   auto& xla_device_assignment = tpu_device_assignment.xla_device_assignment;
582   ASSERT_TRUE(xla_device_assignment.has_value());
583   EXPECT_EQ(xla_device_assignment->replica_count(), 2);
584   EXPECT_EQ(xla_device_assignment->computation_count(), 3);
585   ASSERT_EQ(xla_device_assignment->computation_devices_size(), 3);
586   const auto& computation_device_0 =
587       xla_device_assignment->computation_devices(0);
588   ASSERT_EQ(computation_device_0.replica_device_ids_size(), 2);
589   const auto& computation_device_1 =
590       xla_device_assignment->computation_devices(1);
591   ASSERT_EQ(computation_device_1.replica_device_ids_size(), 2);
592   const auto& computation_device_2 =
593       xla_device_assignment->computation_devices(2);
594   ASSERT_EQ(computation_device_2.replica_device_ids_size(), 2);
595 
596   EXPECT_EQ(computation_device_0.replica_device_ids(0), 1);
597   EXPECT_EQ(computation_device_0.replica_device_ids(1), 5);
598   EXPECT_EQ(computation_device_1.replica_device_ids(0), 4);
599   EXPECT_EQ(computation_device_1.replica_device_ids(1), 0);
600   EXPECT_EQ(computation_device_2.replica_device_ids(0), 2);
601   EXPECT_EQ(computation_device_2.replica_device_ids(1), 3);
602 }
603 
// GetDeviceCoordinates on an I64 array attribute {1, 2, 3} is expected to
// succeed and yield the values 1, 2, 3 in order.
TEST(TPURewriteDeviceUtilTest, TestGetDeviceCoordinates) {
  mlir::MLIRContext context;
  mlir::Builder builder(&context);
  auto device_assignment_attr = builder.getI64ArrayAttr({1, 2, 3});

  auto status_or_device_coordinates =
      GetDeviceCoordinates(device_assignment_attr);
  TF_ASSERT_OK(status_or_device_coordinates.status());

  const auto& device_coordinates = status_or_device_coordinates.value();
  // Stop here on a short result rather than indexing out of bounds below.
  ASSERT_EQ(device_coordinates.size(), 3);
  EXPECT_EQ(device_coordinates[0], 1);
  EXPECT_EQ(device_coordinates[1], 2);
  EXPECT_EQ(device_coordinates[2], 3);
}
616 
// GetDeviceCoordinates on an F32 array attribute is expected to fail and
// report the first element as not being an int.
TEST(TPURewriteDeviceUtilTest, TestInvalidAttrForDeviceAssignmentDisallowed) {
  mlir::MLIRContext context;
  mlir::Builder builder(&context);
  auto float_array_attr = builder.getF32ArrayAttr({1.0, 2.0, 3.0});

  auto coordinates_or = GetDeviceCoordinates(float_array_attr);

  ASSERT_FALSE(coordinates_or.ok());
  EXPECT_EQ(coordinates_or.status().error_message(),
            "bad 'device_assignment' attribute at index 0, not an int");
}
627 
// A cluster whose cores-per-replica attribute is 1 (with an empty topology
// string and an empty device assignment array) is expected to report no model
// parallelism.
TEST(TPURewriteDeviceUtilTest, TestHasModelParallelismFalse) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto unknown_loc = mlir::UnknownLoc::get(&context);
  mlir::OwningOpRef<mlir::ModuleOp> owning_module =
      mlir::ModuleOp::create(unknown_loc);
  mlir::OpBuilder builder(owning_module->getBodyRegion());

  // The cluster op is created with no result types.
  llvm::SmallVector<mlir::Type, 8> no_result_types;
  auto cluster = builder.create<mlir::tf_device::ClusterOp>(unknown_loc,
                                                            no_result_types);

  auto num_cores_attr = builder.getIntegerAttr(builder.getIntegerType(64), 1);
  cluster->setAttr(kNumCoresPerReplicaAttr, num_cores_attr);
  cluster->setAttr(kTopologyAttr, builder.getStringAttr(""));
  cluster->setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));

  EXPECT_FALSE(HasModelParallelism(cluster));
}
645 
// A cluster whose num-cores-per-replica attribute is the 64-bit integer 5 is
// expected to be reported as using model parallelism.
TEST(TPURewriteDeviceUtilTest, TestHasModelParallelismTrue) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto loc = mlir::UnknownLoc::get(&context);
  mlir::OwningOpRef<mlir::ModuleOp> module = mlir::ModuleOp::create(loc);
  mlir::OpBuilder builder(module->getBodyRegion());

  // The cluster is created without any result types.
  llvm::SmallVector<mlir::Type, 8> no_results;
  auto cluster_op =
      builder.create<mlir::tf_device::ClusterOp>(loc, no_results);

  const auto five_cores =
      builder.getIntegerAttr(builder.getIntegerType(64), 5);
  cluster_op->setAttr(kNumCoresPerReplicaAttr, five_cores);
  cluster_op->setAttr(kTopologyAttr, builder.getStringAttr(""));
  cluster_op->setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));

  EXPECT_TRUE(HasModelParallelism(cluster_op));
}
663 
// A cluster that carries topology and device-assignment attributes but no
// num-cores-per-replica attribute is expected not to be reported as using
// model parallelism.
//
// The num-cores-per-replica attribute is deliberately left unset here: setting
// it to 1 would make this test identical to TestHasModelParallelismFalse and
// leave the missing-attribute path untested.
// NOTE(review): this relies on HasModelParallelism treating an absent
// attribute as "no model parallelism", as the test name states -- confirm
// against the implementation.
TEST(TPURewriteDeviceUtilTest,
     TestHasModelParallelismFalseMissingCoresPerReplicaAttr) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  mlir::OwningOpRef<mlir::ModuleOp> module_ref =
      mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
  mlir::OpBuilder builder(module_ref->getBodyRegion());

  // The cluster is created without any result types.
  llvm::SmallVector<mlir::Type, 8> result_types;
  auto cluster = builder.create<mlir::tf_device::ClusterOp>(
      mlir::UnknownLoc::get(&context), result_types);
  cluster->setAttr(kTopologyAttr, builder.getStringAttr(""));
  cluster->setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));

  EXPECT_FALSE(HasModelParallelism(cluster));
}
682 
// The cluster only carries an (empty) device-assignment attribute; neither
// the num-cores-per-replica nor the topology attribute is set. Looking up the
// host device for outside compilation is expected to fail.
TEST(TPURewriteDeviceUtilTest,
     TestGetHostFailNumCoresPerReplicaMissingAttributes) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto loc = mlir::UnknownLoc::get(&context);
  mlir::OwningOpRef<mlir::ModuleOp> module = mlir::ModuleOp::create(loc);
  mlir::OpBuilder builder(module->getBodyRegion());

  llvm::SmallVector<mlir::Type, 8> no_results;
  auto cluster_op =
      builder.create<mlir::tf_device::ClusterOp>(loc, no_results);
  cluster_op->setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));

  // No runtime devices are registered for this lookup.
  mlir::TF::RuntimeDevices no_devices;
  std::string host;
  const auto lookup =
      GetHostDeviceOutsideComputation(no_devices, cluster_op, &host);
  EXPECT_TRUE(mlir::failed(lookup));
}
700 
// The cluster only carries the num-cores-per-replica attribute (set to 1);
// neither the topology nor the device-assignment attribute is set. Looking up
// the host device for outside compilation is expected to fail.
TEST(TPURewriteDeviceUtilTest, TestGetHostFailDeviceMissingAttributes) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto loc = mlir::UnknownLoc::get(&context);
  mlir::OwningOpRef<mlir::ModuleOp> module = mlir::ModuleOp::create(loc);
  mlir::OpBuilder builder(module->getBodyRegion());

  llvm::SmallVector<mlir::Type, 8> no_results;
  auto cluster_op =
      builder.create<mlir::tf_device::ClusterOp>(loc, no_results);
  const auto one_core = builder.getIntegerAttr(builder.getIntegerType(64), 1);
  cluster_op->setAttr(kNumCoresPerReplicaAttr, one_core);

  // No runtime devices are registered for this lookup.
  mlir::TF::RuntimeDevices no_devices;
  std::string host;
  const auto lookup =
      GetHostDeviceOutsideComputation(no_devices, cluster_op, &host);
  EXPECT_TRUE(mlir::failed(lookup));
}
718 
// The cluster has num-cores-per-replica (1) and an empty device assignment
// but no topology attribute. Looking up the host device for outside
// compilation is expected to fail.
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailMissingTopology) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto loc = mlir::UnknownLoc::get(&context);
  mlir::OwningOpRef<mlir::ModuleOp> module = mlir::ModuleOp::create(loc);
  mlir::OpBuilder builder(module->getBodyRegion());

  llvm::SmallVector<mlir::Type, 8> no_results;
  auto cluster_op =
      builder.create<mlir::tf_device::ClusterOp>(loc, no_results);
  const auto one_core = builder.getIntegerAttr(builder.getIntegerType(64), 1);
  cluster_op->setAttr(kNumCoresPerReplicaAttr, one_core);
  cluster_op->setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));

  // No runtime devices are registered for this lookup.
  mlir::TF::RuntimeDevices no_devices;
  std::string host;
  const auto lookup =
      GetHostDeviceOutsideComputation(no_devices, cluster_op, &host);
  EXPECT_TRUE(mlir::failed(lookup));
}
738 
// The cluster has num-cores-per-replica (1) and an empty topology string but
// no device-assignment attribute. Looking up the host device for outside
// compilation is expected to fail.
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailMissingDeviceAssignment) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto loc = mlir::UnknownLoc::get(&context);
  mlir::OwningOpRef<mlir::ModuleOp> module = mlir::ModuleOp::create(loc);
  mlir::OpBuilder builder(module->getBodyRegion());

  llvm::SmallVector<mlir::Type, 8> no_results;
  auto cluster_op =
      builder.create<mlir::tf_device::ClusterOp>(loc, no_results);
  const auto one_core = builder.getIntegerAttr(builder.getIntegerType(64), 1);
  cluster_op->setAttr(kNumCoresPerReplicaAttr, one_core);
  cluster_op->setAttr(kTopologyAttr, builder.getStringAttr(""));

  // No runtime devices are registered for this lookup.
  mlir::TF::RuntimeDevices no_devices;
  std::string host;
  const auto lookup =
      GetHostDeviceOutsideComputation(no_devices, cluster_op, &host);
  EXPECT_TRUE(mlir::failed(lookup));
}
758 
// The device-assignment attribute is an array holding a single string rather
// than integers. Looking up the host device for outside compilation is
// expected to fail.
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailBadDeviceAssignment) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto loc = mlir::UnknownLoc::get(&context);
  mlir::OwningOpRef<mlir::ModuleOp> module = mlir::ModuleOp::create(loc);
  mlir::OpBuilder builder(module->getBodyRegion());

  llvm::SmallVector<mlir::Type, 8> no_results;
  auto cluster_op =
      builder.create<mlir::tf_device::ClusterOp>(loc, no_results);
  const auto one_core = builder.getIntegerAttr(builder.getIntegerType(64), 1);
  cluster_op->setAttr(kNumCoresPerReplicaAttr, one_core);
  cluster_op->setAttr(kTopologyAttr, builder.getStringAttr(""));
  const auto string_assignment = builder.getStrArrayAttr(
      llvm::ArrayRef<llvm::StringRef>({"bad_device_assigment"}));
  cluster_op->setAttr(kDeviceAssignmentAttr, string_assignment);

  // No runtime devices are registered for this lookup.
  mlir::TF::RuntimeDevices no_devices;
  std::string host;
  const auto lookup =
      GetHostDeviceOutsideComputation(no_devices, cluster_op, &host);
  EXPECT_TRUE(mlir::failed(lookup));
}
781 
// The module's "tf.devices" attribute lists the single string
// "bad_device_name". With runtime devices read from that module, looking up
// the host device for outside compilation is expected to fail.
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailBadDeviceName) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto loc = mlir::UnknownLoc::get(&context);
  mlir::OwningOpRef<mlir::ModuleOp> module_ref = mlir::ModuleOp::create(loc);
  mlir::ModuleOp module = *module_ref;
  mlir::OpBuilder builder(module.getBodyRegion());

  const auto device_list = builder.getStrArrayAttr(
      llvm::ArrayRef<llvm::StringRef>({"bad_device_name"}));
  module->setAttr("tf.devices", device_list);

  llvm::SmallVector<mlir::Type, 8> no_results;
  auto cluster_op =
      builder.create<mlir::tf_device::ClusterOp>(loc, no_results);
  const auto one_core = builder.getIntegerAttr(builder.getIntegerType(64), 1);
  cluster_op->setAttr(kNumCoresPerReplicaAttr, one_core);
  cluster_op->setAttr(kTopologyAttr, builder.getStringAttr(""));
  cluster_op->setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));

  mlir::TF::RuntimeDevices module_devices;
  // The outcome of reading the devices is intentionally not checked; only the
  // host-device lookup below is asserted on.
  (void)GetDevicesFromOp(module, &module_devices);

  std::string host;
  const auto lookup =
      GetHostDeviceOutsideComputation(module_devices, cluster_op, &host);
  EXPECT_TRUE(mlir::failed(lookup));
}
807 
// A cluster placed at the start of a tf_device.replicate body (2 replicas, no
// device map) with no attributes of its own: the host-device lookup is
// expected to succeed and yield kTPUReplicatedHost.
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceTPUReplicate) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto loc = mlir::UnknownLoc::get(&context);
  mlir::OwningOpRef<mlir::ModuleOp> module = mlir::ModuleOp::create(loc);
  mlir::OpBuilder builder(module->getBodyRegion());

  llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<llvm::StringRef, 4>>
      no_device_map;
  auto replicate_op = builder.create<mlir::tf_device::ReplicateOp>(
      loc, /*num_replicas=*/2, no_device_map,
      llvm::ArrayRef<std::pair<mlir::ValueRange, mlir::Type>>{},
      mlir::ValueRange{}, mlir::TypeRange{});

  // Subsequent ops are inserted at the beginning of the replicate body.
  auto& replicate_block = replicate_op.body().front();
  builder.setInsertionPoint(&replicate_block, replicate_block.begin());

  llvm::SmallVector<mlir::Type, 8> no_results;
  auto cluster_op =
      builder.create<mlir::tf_device::ClusterOp>(loc, no_results);

  // No runtime devices are registered for this lookup.
  mlir::TF::RuntimeDevices no_devices;
  std::string host;
  const auto lookup =
      GetHostDeviceOutsideComputation(no_devices, cluster_op, &host);
  EXPECT_TRUE(mlir::succeeded(lookup));
  EXPECT_EQ(host, kTPUReplicatedHost);
}
834 
// The module lists a TPU_SYSTEM and a TPU device under job "localhost" and a
// CPU device under job "worker". For a cluster outside any replicate op with
// one core per replica, the host-device lookup is expected to succeed and
// yield the CPU:0 device of job "localhost".
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceNotReplicated) {
  mlir::MLIRContext context;
  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
  const auto loc = mlir::UnknownLoc::get(&context);
  mlir::OwningOpRef<mlir::ModuleOp> module_ref = mlir::ModuleOp::create(loc);
  mlir::ModuleOp module = *module_ref;
  mlir::OpBuilder builder(module.getBodyRegion());

  const auto device_list =
      builder.getStrArrayAttr(llvm::ArrayRef<llvm::StringRef>(
          {"/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0",
           "/job:localhost/replica:0/task:0/device:TPU:0",
           "/job:worker/replica:0/task:0/device:CPU:0"}));
  module->setAttr("tf.devices", device_list);

  llvm::SmallVector<mlir::Type, 8> no_results;
  auto cluster_op =
      builder.create<mlir::tf_device::ClusterOp>(loc, no_results);
  const auto one_core = builder.getIntegerAttr(builder.getIntegerType(64), 1);
  cluster_op->setAttr(kNumCoresPerReplicaAttr, one_core);
  cluster_op->setAttr(kTopologyAttr, builder.getStringAttr(""));
  cluster_op->setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));

  mlir::TF::RuntimeDevices module_devices;
  // The outcome of reading the devices is intentionally not checked; only the
  // host-device lookup below is asserted on.
  (void)GetDevicesFromOp(module, &module_devices);

  std::string host;
  const auto lookup =
      GetHostDeviceOutsideComputation(module_devices, cluster_op, &host);
  EXPECT_TRUE(mlir::succeeded(lookup));
  EXPECT_EQ(host, "/job:localhost/replica:0/task:0/device:CPU:0");
}
863 
// IsTPUDevice is expected to accept a fully qualified TPU:0 device name and to
// reject both a fully qualified CPU:0 device name and the literal
// "INVALID_DEVICE".
TEST(TPURewriteDeviceUtilTest, TestIsTPUDevice) {
  const char* const tpu_device = "/job:localhost/replica:0/task:0/device:TPU:0";
  const char* const cpu_device = "/job:localhost/replica:0/task:0/device:CPU:0";
  const char* const malformed_device = "INVALID_DEVICE";

  EXPECT_TRUE(IsTPUDevice(tpu_device));
  EXPECT_FALSE(IsTPUDevice(cpu_device));
  EXPECT_FALSE(IsTPUDevice(malformed_device));
}
869 
870 }  // anonymous namespace
871 }  // namespace tensorflow
872