// xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/jit/device_util.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_
#define TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_

#include <functional>
#include <memory>
#include <optional>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/types.h"

30 namespace tensorflow {
31 namespace jit {
32 class DeviceInfoCache;
33 class DeviceSet;
34 
35 // Instances of DeviceId represent TensorFlow devices as integers.
36 //
37 // This helps avoid having to manipulate device names as strings when
38 // auto-clustering.
39 class DeviceId {
40  public:
41   DeviceId(DeviceId&&) = default;
42   DeviceId(const DeviceId&) = default;
43   DeviceId& operator=(const DeviceId&) = default;
44 
45   bool operator==(const DeviceId& other) const { return id() == other.id(); }
46   bool operator!=(const DeviceId& other) const { return !(*this == other); }
47 
48  private:
49   int id_;
50 
DeviceId(int id)51   explicit DeviceId(int id) : id_(id) {}
52 
id()53   int id() const { return id_; }
54 
55   friend class DeviceInfoCache;
56   friend class DeviceSet;
57 };
58 
59 // A set of DeviceIds, represented as a bitmap.
60 class DeviceSet {
61  public:
62   void Insert(DeviceId device_id);
63   void UnionWith(const DeviceSet& other);
64   bool IsEmpty() const;
65 
66   // Calls `func` on each DeviceId in the set.  Stops iterating early if `func`
67   // return false.
68   //
69   // TODO(sanjoy): Change this to take a typed std::function if that's
70   // performance neutral.
71   template <typename FnTy>
ForEach(FnTy func)72   void ForEach(FnTy func) const {
73     // This is really a poor man's iterator, we should consider writing a proper
74     // iterator if this ends up being used widely.
75     for (int word_index = 0, end = storage_.size(); word_index < end;
76          word_index++) {
77       uint64 word = storage_[word_index];
78       while (word != 0) {
79         uint64 only_lowest_bit_set = word & -word;
80         // The number of trailing zeros in a non-zero word is the index of the
81         // least significant 1.
82         int bit_index = ctz_uint64(word);
83         if (!func(DeviceId(word_index * kWordSize + bit_index))) {
84           return;
85         }
86         word ^= only_lowest_bit_set;
87       }
88     }
89   }
90 
91  private:
ctz_uint64(uint64 x)92   static int ctz_uint64(uint64 x) {
93     DCHECK_NE(x, 0);
94 #ifdef __GNUC__
95     return __builtin_ctzl(x);
96 #else
97     int result = 0u;
98     while ((x & 1u) == 0u) {
99       x >>= 1;
100       ++result;
101     }
102     return result;
103 #endif
104   }
105 
106   absl::InlinedVector<uint64, 1> storage_;
107 
108   const int kWordSize = 64;
109 };

// Caches some miscellaneous information about TF devices.  Thread compatible.
class DeviceInfoCache {
 public:
  // Returns true if `device` was recorded as a GPU device when it was interned
  // via GetIdFor.
  bool IsGpu(DeviceId device) const { return is_gpu_[device.id()]; }

  // Returns true if `device` was recorded as a CPU device when it was interned
  // via GetIdFor.
  bool IsCpu(DeviceId device) const { return is_cpu_[device.id()]; }

  // Returns the full device name `device` was interned under.  The returned
  // view is valid for the lifetime of this DeviceInfoCache.
  absl::string_view GetNameFor(DeviceId device) const {
    return names_[device.id()];
  }

  // Interns `name` and returns its DeviceId, populating the per-device caches
  // below on first sight of the name.  Defined out of line.
  StatusOr<DeviceId> GetIdFor(absl::string_view name);

  using DeviceRegistration = const XlaOpRegistry::DeviceRegistration;

  // Returns the cached XLA compilation-device registration for `device`.
  // NOTE(review): presumably null when no XLA backend is registered for the
  // device type -- confirm against GetIdFor's definition in the .cc file.
  DeviceRegistration* GetCompilationDevice(DeviceId device) const {
    return id_to_compilation_device_[device.id()];
  }

  // As above, but interns `name` first; fails if `name` cannot be interned.
  StatusOr<DeviceRegistration*> GetCompilationDevice(absl::string_view name) {
    TF_ASSIGN_OR_RETURN(DeviceId device_id, GetIdFor(name));
    return GetCompilationDevice(device_id);
  }

  // Returns the DeviceType for `device`.  The reference is owned by this cache
  // and stays valid for its lifetime.
  const DeviceType& GetDeviceTypeFor(DeviceId device) const {
    return *id_to_device_type_[device.id()];
  }

  using DeviceTypeConstRef = std::reference_wrapper<const DeviceType>;

  // As above, but interns `device_name` first.  Wrapped in reference_wrapper
  // because StatusOr cannot hold a plain reference.
  StatusOr<DeviceTypeConstRef> GetDeviceTypeFor(absl::string_view device_name) {
    TF_ASSIGN_OR_RETURN(DeviceId device_id, GetIdFor(device_name));
    return std::cref(*id_to_device_type_[device_id.id()]);
  }

  // Renders `device_set` as a human-readable string of device names.
  string DebugString(const DeviceSet& device_set) const;

 private:
  absl::flat_hash_map<string, DeviceId> name_to_id_;

  // These fields are populated for a device in GetIdFor, *before* we give out a
  // DeviceId.  They are all indexed by DeviceId::id().
  std::vector<const XlaOpRegistry::DeviceRegistration*>
      id_to_compilation_device_;
  std::vector<std::unique_ptr<DeviceType>> id_to_device_type_;
  std::vector<string> names_;
  std::vector<bool> is_cpu_;
  std::vector<bool> is_gpu_;
};

}  // namespace jit

// Returns the DeviceType corresponding to 'device'.
Status DeviceNameToDeviceType(const string& device, DeviceType* device_type);

// Picks the device for which XLA should compile a cluster that contains
// operations placed in devices in `devices`.  For instance a cluster that
// contains operations solely placed on the CPU will be compiled into a CPU
// executable by XLA, whereas a cluster that contains operations placed on the
// CPU and also operations placed on the GPU will be compiled into a GPU
// executable.
//
// Returns a non-OK Status if no unambiguous choice of device exists.
//
// We choose the device using the following rules:
//
//  - It is an error for `device_names` to contain more than one device of the
//    same type.
//  - GPU is preferred over CPU.
//  - If `allow_mixing_unknown_and_cpu` is true then unknown devices are
//    preferred over CPU.
//  - XLA devices count as "unrecognized devices".
//
// This set of rules above implicitly assume that XLA:GPU can compile all
// operations in the cluster that XLA:CPU can compile, and if
// `allow_mixing_unknown_and_cpu` then the unrecognized device can also compile
// all operations in the cluster that XLA:CPU can compile.
//
// We provide the `allow_mixing_unknown_and_cpu` knob so that we can do both of
// the following things:
//
// - Let MarkForCompilationPass not inject CPU-placed operations into clusters
//   that will run on unknown devices (because the unknown XLA backend may not
//   support every operation supported by CPU).
// - Let BuildXlaOpsPass successfully infer a compilation device for a cluster
//   that contains nodes placed on both the CPU and on unknown devices.  In this
//   case it is the responsibility of the optimization pass that injected the
//   CPU nodes into the cluster to ensure that these nodes can be compiled by
//   the unknown XLA backend.
StatusOr<jit::DeviceId> PickDeviceForXla(
    const jit::DeviceInfoCache& device_info_cache,
    const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu);

// This is like `PickDeviceForXla` except that it returns nullopt (instead of a
// non-OK Status) if no unambiguous choice of device exists.
//
// We return a failing Status for errors unrelated to the device choice
// algorithm itself.
StatusOr<std::optional<jit::DeviceId>> MaybePickDeviceForXla(
    const jit::DeviceInfoCache& device_info_cache,
    const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu);
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_