xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/computation_placer.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_PLACER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_PLACER_H_
18 
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/service/global_device_id.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/stream_executor/platform.h"
32 
33 namespace xla {
34 
35 // Class that represents the device assignment for a set of XLA replicated
36 // computations. For R replicas and C computations, R * C devices are required
37 // execute the computation in parallel. The assigned device ids can be accessed
38 // by assignment(replica, computation).
39 class DeviceAssignment : public Array2D<int> {
40  public:
DeviceAssignment()41   DeviceAssignment() {}
DeviceAssignment(int replica_count,int computation_count)42   DeviceAssignment(int replica_count, int computation_count)
43       : Array2D<int>(replica_count, computation_count, -1) {
44     CHECK_GT(replica_count, 0);
45     CHECK_GT(computation_count, 0);
46   }
47 
replica_count()48   int replica_count() const { return height(); }
computation_count()49   int computation_count() const { return width(); }
50 
51   // The logical ID of a device is its (replica ID, computation ID) pair.
52   struct LogicalID {
53     int replica_id;
54     int computation_id;
55   };
56 
57   // Finds the (replica ID, computation ID) pair for the given device.
58   StatusOr<LogicalID> LogicalIdForDevice(GlobalDeviceId device_id) const;
59   // Finds the replica ID for the given device.
60   StatusOr<int> ReplicaIdForDevice(GlobalDeviceId device_id) const;
61   // Returns a map from device ID to logical ID. Querying this map is much more
62   // efficient than `LogicalIdForDevice` if queried repeatedly.
63   absl::flat_hash_map<GlobalDeviceId, LogicalID> GetDeviceToLogicalIdMap()
64       const;
65 
66   // Protocol buffer serialization and deserialization.
67   Status Serialize(DeviceAssignmentProto* proto) const;
68 
69   // Return a std::unique_ptr<DeviceAssignment> instead of a DeviceAssignment
70   // directly because one of the supported TF platforms (mac) does not compile
71   // due to a StatusOr of an incomplete type (DeviceAssignment).
72   static StatusOr<std::unique_ptr<DeviceAssignment>> Deserialize(
73       const DeviceAssignmentProto& proto);
74 
75   std::string ToString() const;
76 };
77 
78 // A generic implementation of the XLA computation placer, which assigns device
79 // ids to a set of replicated computations.
80 class ComputationPlacer {
81  public:
ComputationPlacer()82   ComputationPlacer() {}
~ComputationPlacer()83   virtual ~ComputationPlacer() {}
84 
85   // Returns the device id assigned to the given replica and computation
86   // instance for [replica_count x computation_count] setup. The returned device
87   // id must match the assignment from PlaceReplicatedComputation().
88   virtual StatusOr<int> DeviceId(int replica, int computation,
89                                  int replica_count, int computation_count);
90 
91   // Returns the device ids assigned to a set of replicated computations, given
92   // the number of replicas and the number of computations.
93   virtual StatusOr<DeviceAssignment> AssignDevices(int replica_count,
94                                                    int computation_count);
95 
96   using ComputationPlacerCreationFunction =
97       std::unique_ptr<ComputationPlacer> (*)();
98 
99   // Registers a computation placer creation function for a particular platform.
100   static void RegisterComputationPlacer(
101       se::Platform::Id platform_id,
102       ComputationPlacerCreationFunction creation_function);
103 
104   // Returns the computation placer singleton pointer if it is available for the
105   // given platform, or an error status if it is not.
106   static StatusOr<ComputationPlacer*> GetForPlatform(
107       const se::Platform* platform);
108 
109  private:
110   // The mutex that guards the platform-to-computation placer map.
111   static absl::Mutex platform_computation_placer_mutex_;
112 
113   // State kept for each kind of ComputationPlacer. Registration functions set
114   // up creation_function, and then we use that to lazily create "placer" the
115   // first time GetForPlatform is invoked for a particular id.
116   struct State {
117     std::unique_ptr<ComputationPlacer> placer;
118     ComputationPlacerCreationFunction creation_function = nullptr;
119   };
120 
121   // Map from platform kind to computation placer singleton.
122   static std::map<se::Platform::Id, State>* GetPlatformComputationPlacers();
123 
124   ComputationPlacer(const ComputationPlacer&) = delete;
125   ComputationPlacer& operator=(const ComputationPlacer&) = delete;
126 };
127 
128 }  // namespace xla
129 
130 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_PLACER_H_
131