/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/computation_placer.h"

#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/global_device_id.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
#include "tensorflow/stream_executor/host/host_platform_id.h"
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"

using absl::StrAppend;
using absl::StrCat;

namespace xla {

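// Returns the (replica, computation) logical ID at which the given global
// device appears in this assignment. It is an error for a device to appear
// zero times or more than once.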
StatusOr<DeviceAssignment::LogicalID> DeviceAssignment::LogicalIdForDevice(
    GlobalDeviceId device_id) const {
  std::optional<DeviceAssignment::LogicalID> logical_id;
  for (int r = 0; r < replica_count(); ++r) {
    for (int c = 0; c < computation_count(); ++c) {
      if ((*this)(r, c) == device_id.value()) {
        if (logical_id.has_value()) {
          return InternalError(
              "Device %d appears twice in DeviceAssignment: %s",
              device_id.value(), ToString());
        }
        logical_id.emplace(DeviceAssignment::LogicalID{r, c});
      }
    }
  }
  if (logical_id.has_value()) {
    return *logical_id;
  } else {
    return InternalError("Device %d doesn't appear in DeviceAssignment: %s",
                         device_id.value(), ToString());
  }
}

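// Convenience wrapper around LogicalIdForDevice that returns only the
// replica ID for the given device.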
StatusOr<int> DeviceAssignment::ReplicaIdForDevice(
    GlobalDeviceId device_id) const {
  TF_ASSIGN_OR_RETURN(const LogicalID logical_id,
                      LogicalIdForDevice(device_id));
  return logical_id.replica_id;
}

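// Builds the inverse mapping of this assignment: global device ID to
// (replica, computation) logical ID. Unlike LogicalIdForDevice, duplicate
// devices are not diagnosed here; the last occurrence wins.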
absl::flat_hash_map<GlobalDeviceId, DeviceAssignment::LogicalID>
DeviceAssignment::GetDeviceToLogicalIdMap() const {
  absl::flat_hash_map<GlobalDeviceId, DeviceAssignment::LogicalID>
      device_to_logical_id;
  for (int r = 0; r < replica_count(); ++r) {
    for (int c = 0; c < computation_count(); ++c) {
      GlobalDeviceId device_id((*this)(r, c));
      device_to_logical_id[device_id] = DeviceAssignment::LogicalID{r, c};
    }
  }
  return device_to_logical_id;
}

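// Serializes this assignment into `proto`: one ComputationDevice message per
// computation, each listing that computation's device ID for every replica.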
Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const {
  proto->set_replica_count(replica_count());
  proto->set_computation_count(computation_count());
  for (int computation = 0; computation < computation_count(); ++computation) {
    DeviceAssignmentProto::ComputationDevice* computation_device =
        proto->add_computation_devices();
    for (int replica = 0; replica < replica_count(); ++replica) {
      computation_device->add_replica_device_ids((*this)(replica, computation));
    }
  }
  return OkStatus();
}

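// Inverse of Serialize: reconstructs a DeviceAssignment from `proto`,
// validating that both counts are positive and that every computation lists
// exactly replica_count device IDs.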
/* static */ StatusOr<std::unique_ptr<DeviceAssignment>>
DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) {
  TF_RET_CHECK(proto.computation_devices_size() == proto.computation_count());
  if (proto.replica_count() <= 0 || proto.computation_count() <= 0) {
    return InvalidArgument(
        "Invalid device assignment topology: replica_count=%d, "
        "computation_count=%d",
        proto.replica_count(), proto.computation_count());
  }
  auto assignment = std::make_unique<DeviceAssignment>(
      proto.replica_count(), proto.computation_count());
  for (int computation = 0; computation < proto.computation_count();
       ++computation) {
    const auto& computation_device = proto.computation_devices(computation);
    TF_RET_CHECK(computation_device.replica_device_ids_size() ==
                 proto.replica_count());
    for (int replica = 0; replica < proto.replica_count(); ++replica) {
      (*assignment)(replica, computation) =
          computation_device.replica_device_ids(replica);
    }
  }
  return std::move(assignment);
}

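// Renders the assignment in a human-readable form: one line per computation,
// listing the device ID used by each replica of that computation.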
std::string DeviceAssignment::ToString() const {
  std::string output = StrCat("Computations: ", computation_count(),
                              " Replicas: ", replica_count(), "\n");
  for (int computation = 0; computation < computation_count(); ++computation) {
    StrAppend(&output, "Computation ", computation, ": ");
    for (int replica = 0; replica < replica_count(); ++replica) {
      StrAppend(&output, operator()(replica, computation), " ");
    }
    StrAppend(&output, "\n");
  }
  return output;
}

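// Default placement policy: the devices of one computation are assigned
// contiguously, so replica `replica` of computation `computation` runs on
// device `computation * replica_count + replica`.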
StatusOr<int> ComputationPlacer::DeviceId(int replica, int computation,
                                          int replica_count,
                                          int computation_count) {
  TF_RET_CHECK(replica < replica_count);
  TF_RET_CHECK(computation < computation_count);

  return computation * replica_count + replica;
}

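// Builds a replica_count x computation_count DeviceAssignment using the
// DeviceId policy above. For example, with replica_count=2 and
// computation_count=3 the result is:
//   Computation 0: devices 0 1
//   Computation 1: devices 2 3
//   Computation 2: devices 4 5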
StatusOr<DeviceAssignment> ComputationPlacer::AssignDevices(
    int replica_count, int computation_count) {
  DeviceAssignment assignment(replica_count, computation_count);
  for (int replica = 0; replica < replica_count; ++replica) {
    for (int computation = 0; computation < computation_count; ++computation) {
      TF_ASSIGN_OR_RETURN(
          int device_id,
          DeviceId(replica, computation, replica_count, computation_count));
      assignment(replica, computation) = device_id;
    }
  }
  return std::move(assignment);
}

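// Registers a ComputationPlacer factory for `platform_id`. Each platform may
// be registered at most once; a duplicate registration is a fatal error.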
/* static */ void ComputationPlacer::RegisterComputationPlacer(
    se::Platform::Id platform_id,
    ComputationPlacerCreationFunction creation_function) {
  absl::MutexLock lock(&ComputationPlacer::platform_computation_placer_mutex_);
  auto* computation_placers = GetPlatformComputationPlacers();
  CHECK(computation_placers->find(platform_id) == computation_placers->end());
  (*computation_placers)[platform_id].creation_function = creation_function;
}

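// Returns the ComputationPlacer registered for `platform`, constructing it on
// first use. Fails with NotFound if no factory was registered, which usually
// means the platform's registration translation unit was not linked in.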
/* static */ StatusOr<ComputationPlacer*> ComputationPlacer::GetForPlatform(
    const se::Platform* platform) {
  absl::MutexLock lock(&ComputationPlacer::platform_computation_placer_mutex_);
  auto* computation_placers = GetPlatformComputationPlacers();

  auto it = computation_placers->find(platform->id());
  if (it == computation_placers->end()) {
    return NotFound(
        "could not find registered computation placer for platform %s -- check "
        "target linkage",
        platform->Name());
  }

  if (it->second.placer == nullptr) {
    // Lazily create the computation placer the first time it is needed.
    it->second.placer = (*it->second.creation_function)();
  }

  return it->second.placer.get();
}

/* static */ absl::Mutex ComputationPlacer::platform_computation_placer_mutex_(
    absl::kConstInit);

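// Owns the singleton map from platform ID to registration state. The map is
// heap-allocated and never deleted, the usual pattern for registries that
// must remain valid regardless of static destruction order.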
/* static */ std::map<se::Platform::Id, ComputationPlacer::State>*
ComputationPlacer::GetPlatformComputationPlacers() {
  static auto* r = new std::map<se::Platform::Id, ComputationPlacer::State>;
  return r;
}

}  // namespace xla

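// Register the default ComputationPlacer for the host, CUDA, and ROCm
// platforms at static-initialization time.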
static std::unique_ptr<xla::ComputationPlacer> CreateComputationPlacer() {
  return std::make_unique<xla::ComputationPlacer>();
}

static bool InitModule() {
  xla::ComputationPlacer::RegisterComputationPlacer(
      stream_executor::host::kHostPlatformId, &CreateComputationPlacer);
  xla::ComputationPlacer::RegisterComputationPlacer(
      stream_executor::cuda::kCudaPlatformId, &CreateComputationPlacer);
  xla::ComputationPlacer::RegisterComputationPlacer(
      stream_executor::rocm::kROCmPlatformId, &CreateComputationPlacer);
  return true;
}
static bool module_initialized = InitModule();
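
// Example use (an illustrative sketch, not code from this file; assumes the
// caller already holds an se::Platform* and runs inside a function that
// returns Status):
//
//   TF_ASSIGN_OR_RETURN(xla::ComputationPlacer* placer,
//                       xla::ComputationPlacer::GetForPlatform(platform));
//   TF_ASSIGN_OR_RETURN(xla::DeviceAssignment assignment,
//                       placer->AssignDevices(/*replica_count=*/2,
//                                             /*computation_count=*/1));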