/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/computation_placer.h"

#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/global_device_id.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
#include "tensorflow/stream_executor/host/host_platform_id.h"
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"

using absl::StrAppend;
using absl::StrCat;

namespace xla {

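// Returns the (replica, computation) logical id for the given global device
// id, or an error if the device appears zero times or more than once in this
// assignment.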
StatusOr<DeviceAssignment::LogicalID> DeviceAssignment::LogicalIdForDevice(
    GlobalDeviceId device_id) const {
  std::optional<DeviceAssignment::LogicalID> logical_id;
  for (int r = 0; r < replica_count(); ++r) {
    for (int c = 0; c < computation_count(); ++c) {
      if ((*this)(r, c) == device_id.value()) {
        if (logical_id.has_value()) {
          return InternalError(
              "Device %d appears twice in DeviceAssignment: %s",
              device_id.value(), ToString());
        }
        logical_id.emplace(DeviceAssignment::LogicalID{r, c});
      }
    }
  }
  if (logical_id.has_value()) {
    return *logical_id;
  } else {
    return InternalError("Device %d doesn't appear in DeviceAssignment: %s",
                         device_id.value(), ToString());
  }
}

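// Convenience wrapper around LogicalIdForDevice that returns only the replica
// index for the given global device id.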
StatusOr<int> DeviceAssignment::ReplicaIdForDevice(
    GlobalDeviceId device_id) const {
  TF_ASSIGN_OR_RETURN(const LogicalID logical_id,
                      LogicalIdForDevice(device_id));
  return logical_id.replica_id;
}

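// Builds a map from each global device id in the assignment to its
// (replica, computation) logical id.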
absl::flat_hash_map<GlobalDeviceId, DeviceAssignment::LogicalID>
DeviceAssignment::GetDeviceToLogicalIdMap() const {
  absl::flat_hash_map<GlobalDeviceId, DeviceAssignment::LogicalID>
      device_to_logical_id;
  for (int r = 0; r < replica_count(); ++r) {
    for (int c = 0; c < computation_count(); ++c) {
      GlobalDeviceId device_id((*this)(r, c));
      device_to_logical_id[device_id] = DeviceAssignment::LogicalID{r, c};
    }
  }
  return device_to_logical_id;
}

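// Serializes this assignment into `proto`: one ComputationDevice entry per
// computation, listing that computation's device ids in replica order.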
Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const {
  proto->set_replica_count(replica_count());
  proto->set_computation_count(computation_count());
  for (int computation = 0; computation < computation_count(); ++computation) {
    DeviceAssignmentProto::ComputationDevice* computation_device =
        proto->add_computation_devices();
    for (int replica = 0; replica < replica_count(); ++replica) {
      computation_device->add_replica_device_ids((*this)(replica, computation));
    }
  }
  return OkStatus();
}

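// Reconstructs a DeviceAssignment from `proto`, validating that both
// dimensions are positive and that each computation lists exactly
// replica_count device ids.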
/* static */ StatusOr<std::unique_ptr<DeviceAssignment>>
DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) {
  TF_RET_CHECK(proto.computation_devices_size() == proto.computation_count());
  if (proto.replica_count() <= 0 || proto.computation_count() <= 0) {
    return InvalidArgument(
        "Invalid device assignment topology: replica_count=%d, "
        "computation_count=%d",
        proto.replica_count(), proto.computation_count());
  }
  auto assignment = std::make_unique<DeviceAssignment>(
      proto.replica_count(), proto.computation_count());
  for (int computation = 0; computation < proto.computation_count();
       ++computation) {
    const auto& computation_device = proto.computation_devices(computation);
    TF_RET_CHECK(computation_device.replica_device_ids_size() ==
                 proto.replica_count());
    for (int replica = 0; replica < proto.replica_count(); ++replica) {
      (*assignment)(replica, computation) =
          computation_device.replica_device_ids(replica);
    }
  }
  return std::move(assignment);
}

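// Returns a human-readable description of the assignment, one line per
// computation listing its device ids in replica order.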
std::string DeviceAssignment::ToString() const {
  std::string output = StrCat("Computations: ", computation_count(),
                              " Replicas: ", replica_count(), "\n");
  for (int computation = 0; computation < computation_count(); ++computation) {
    StrAppend(&output, "Computation ", computation, ": ");
    for (int replica = 0; replica < replica_count(); ++replica) {
      StrAppend(&output, operator()(replica, computation), " ");
    }
    StrAppend(&output, "\n");
  }
  return output;
}

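// Numbers devices in computation-major order: the device id for
// (replica, computation) is computation * replica_count + replica.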
StatusOr<int> ComputationPlacer::DeviceId(int replica, int computation,
                                          int replica_count,
                                          int computation_count) {
  TF_RET_CHECK(replica < replica_count);
  TF_RET_CHECK(computation < computation_count);

  return computation * replica_count + replica;
}

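// Fills a replica_count x computation_count DeviceAssignment using the
// default DeviceId numbering above. For example, AssignDevices(2, 2) yields
//   Computation 0: 0 1
//   Computation 1: 2 3
// since device id = computation * replica_count + replica.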
StatusOr<DeviceAssignment> ComputationPlacer::AssignDevices(
    int replica_count, int computation_count) {
  DeviceAssignment assignment(replica_count, computation_count);
  for (int replica = 0; replica < replica_count; ++replica) {
    for (int computation = 0; computation < computation_count; ++computation) {
      TF_ASSIGN_OR_RETURN(
          int device_id,
          DeviceId(replica, computation, replica_count, computation_count));
      assignment(replica, computation) = device_id;
    }
  }
  return std::move(assignment);
}

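// Registers a ComputationPlacer factory for the given platform. Each platform
// may be registered at most once; the placer itself is created lazily in
// GetForPlatform.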
/* static */ void ComputationPlacer::RegisterComputationPlacer(
    se::Platform::Id platform_id,
    ComputationPlacerCreationFunction creation_function) {
  absl::MutexLock lock(&ComputationPlacer::platform_computation_placer_mutex_);
  auto* computation_placers = GetPlatformComputationPlacers();
  CHECK(computation_placers->find(platform_id) == computation_placers->end());
  (*computation_placers)[platform_id].creation_function = creation_function;
}

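// Returns the ComputationPlacer registered for `platform`, creating it on
// first use, or NotFound if no factory was registered for that platform.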
/* static */ StatusOr<ComputationPlacer*> ComputationPlacer::GetForPlatform(
    const se::Platform* platform) {
  absl::MutexLock lock(&ComputationPlacer::platform_computation_placer_mutex_);
  auto* computation_placers = GetPlatformComputationPlacers();

  auto it = computation_placers->find(platform->id());
  if (it == computation_placers->end()) {
    return NotFound(
        "could not find registered computation placer for platform %s -- check "
        "target linkage",
        platform->Name());
  }

  if (it->second.placer == nullptr) {
    // Lazily create the computation placer the first time it is needed.
    it->second.placer = (*it->second.creation_function)();
  }

  return it->second.placer.get();
}

/* static */ absl::Mutex ComputationPlacer::platform_computation_placer_mutex_(
    absl::kConstInit);

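// Returns the singleton map from platform id to registration state, guarded
// by platform_computation_placer_mutex_.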
/* static */ std::map<se::Platform::Id, ComputationPlacer::State>*
ComputationPlacer::GetPlatformComputationPlacers() {
  static auto* r = new std::map<se::Platform::Id, ComputationPlacer::State>;
  return r;
}

}  // namespace xla

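// Factory returning the default ComputationPlacer, used for the platforms
// registered below.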
static std::unique_ptr<xla::ComputationPlacer> CreateComputationPlacer() {
  return std::make_unique<xla::ComputationPlacer>();
}

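// Registers the default placer for the host, CUDA, and ROCm platforms at
// static initialization time.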
static bool InitModule() {
  xla::ComputationPlacer::RegisterComputationPlacer(
      stream_executor::host::kHostPlatformId, &CreateComputationPlacer);
  xla::ComputationPlacer::RegisterComputationPlacer(
      stream_executor::cuda::kCudaPlatformId, &CreateComputationPlacer);
  xla::ComputationPlacer::RegisterComputationPlacer(
      stream_executor::rocm::kROCmPlatformId, &CreateComputationPlacer);
  return true;
}
static bool module_initialized = InitModule();