1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_ 17 #define TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_ 18 19 #include <functional> 20 #include <memory> 21 22 #include "absl/container/flat_hash_map.h" 23 #include "absl/strings/string_view.h" 24 #include "absl/types/span.h" 25 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 26 #include "tensorflow/compiler/xla/status_macros.h" 27 #include "tensorflow/compiler/xla/statusor.h" 28 #include "tensorflow/core/framework/types.h" 29 30 namespace tensorflow { 31 namespace jit { 32 class DeviceInfoCache; 33 class DeviceSet; 34 35 // Instances of DeviceId represent TensorFlow devices as integers. 36 // 37 // This helps avoid having to manipulate device names as strings when 38 // auto-clustering. 39 class DeviceId { 40 public: 41 DeviceId(DeviceId&&) = default; 42 DeviceId(const DeviceId&) = default; 43 DeviceId& operator=(const DeviceId&) = default; 44 45 bool operator==(const DeviceId& other) const { return id() == other.id(); } 46 bool operator!=(const DeviceId& other) const { return !(*this == other); } 47 48 private: 49 int id_; 50 DeviceId(int id)51 explicit DeviceId(int id) : id_(id) {} 52 id()53 int id() const { return id_; } 54 55 friend class DeviceInfoCache; 56 friend class DeviceSet; 57 }; 58 59 // A set of DeviceIds, represented as a bitmap. 60 class DeviceSet { 61 public: 62 void Insert(DeviceId device_id); 63 void UnionWith(const DeviceSet& other); 64 bool IsEmpty() const; 65 66 // Calls `func` on each DeviceId in the set. Stops iterating early if `func` 67 // return false. 68 // 69 // TODO(sanjoy): Change this to take a typed std::function if that's 70 // performance neutral. 71 template <typename FnTy> ForEach(FnTy func)72 void ForEach(FnTy func) const { 73 // This is really a poor man's iterator, we should consider writing a proper 74 // iterator if this ends up being used widely. 75 for (int word_index = 0, end = storage_.size(); word_index < end; 76 word_index++) { 77 uint64 word = storage_[word_index]; 78 while (word != 0) { 79 uint64 only_lowest_bit_set = word & -word; 80 // The number of trailing zeros in a non-zero word is the index of the 81 // least significant 1. 82 int bit_index = ctz_uint64(word); 83 if (!func(DeviceId(word_index * kWordSize + bit_index))) { 84 return; 85 } 86 word ^= only_lowest_bit_set; 87 } 88 } 89 } 90 91 private: ctz_uint64(uint64 x)92 static int ctz_uint64(uint64 x) { 93 DCHECK_NE(x, 0); 94 #ifdef __GNUC__ 95 return __builtin_ctzl(x); 96 #else 97 int result = 0u; 98 while ((x & 1u) == 0u) { 99 x >>= 1; 100 ++result; 101 } 102 return result; 103 #endif 104 } 105 106 absl::InlinedVector<uint64, 1> storage_; 107 108 const int kWordSize = 64; 109 }; 110 111 // Caches some miscellaneous information about TF devices. Thread compatible. 112 class DeviceInfoCache { 113 public: IsGpu(DeviceId device)114 bool IsGpu(DeviceId device) const { return is_gpu_[device.id()]; } IsCpu(DeviceId device)115 bool IsCpu(DeviceId device) const { return is_cpu_[device.id()]; } 116 GetNameFor(DeviceId device)117 absl::string_view GetNameFor(DeviceId device) const { 118 return names_[device.id()]; 119 } 120 121 StatusOr<DeviceId> GetIdFor(absl::string_view name); 122 123 using DeviceRegistration = const XlaOpRegistry::DeviceRegistration; 124 GetCompilationDevice(DeviceId device)125 DeviceRegistration* GetCompilationDevice(DeviceId device) const { 126 return id_to_compilation_device_[device.id()]; 127 } 128 GetCompilationDevice(absl::string_view name)129 StatusOr<DeviceRegistration*> GetCompilationDevice(absl::string_view name) { 130 TF_ASSIGN_OR_RETURN(DeviceId device_id, GetIdFor(name)); 131 return GetCompilationDevice(device_id); 132 } 133 GetDeviceTypeFor(DeviceId device)134 const DeviceType& GetDeviceTypeFor(DeviceId device) const { 135 return *id_to_device_type_[device.id()]; 136 } 137 138 using DeviceTypeConstRef = std::reference_wrapper<const DeviceType>; 139 GetDeviceTypeFor(absl::string_view device_name)140 StatusOr<DeviceTypeConstRef> GetDeviceTypeFor(absl::string_view device_name) { 141 TF_ASSIGN_OR_RETURN(DeviceId device_id, GetIdFor(device_name)); 142 return std::cref(*id_to_device_type_[device_id.id()]); 143 } 144 145 string DebugString(const DeviceSet& device_set) const; 146 147 private: 148 absl::flat_hash_map<string, DeviceId> name_to_id_; 149 150 // These fields are populated for a device in GetIdFor, *before* we give out a 151 // DeviceId. 152 std::vector<const XlaOpRegistry::DeviceRegistration*> 153 id_to_compilation_device_; 154 std::vector<std::unique_ptr<DeviceType>> id_to_device_type_; 155 std::vector<string> names_; 156 std::vector<bool> is_cpu_; 157 std::vector<bool> is_gpu_; 158 }; 159 160 } // namespace jit 161 162 // Returns the DeviceType corresponding to 'device'. 163 Status DeviceNameToDeviceType(const string& device, DeviceType* device_type); 164 165 // Picks the device for which XLA should compile a cluster that contains 166 // operations placed in devices in `devices`. For instance a cluster that 167 // contains operations solely placed on the CPU will be compiled into a CPU 168 // executable by XLA, whereas a cluster that contains operations placed on the 169 // CPU and also operations placed on the GPU will be compiled into a GPU 170 // executable. 171 // 172 // Returns a non-OK Status if no unambiguous choice of device exists. 173 // 174 // We choose the device using the following rules: 175 // 176 // - It is an error for `device_names` to contain more than one device of the 177 // same type. 178 // - GPU is preferred over CPU. 179 // - If `allow_mixing_unknown_and_cpu` is true then unknown devices are 180 // preferred over CPU. 181 // - XLA devices count as "unrecognized devices". 182 // 183 // This set of rules above implicitly assume that XLA:GPU can compile all 184 // operations in the cluster that XLA:CPU can compile, and if 185 // `allow_mixing_unknown_and_cpu` then the unrecognized device can also compile 186 // all operations in the cluster that XLA:CPU can compile. 187 // 188 // We provide the `allow_mixing_unknown_and_cpu` knob so that we can do both of 189 // the following things: 190 // 191 // - Let MarkForCompilationPass not inject CPU-placed operations into clusters 192 // that will run on unknown devices (because the unknown XLA backend may not 193 // support every operation supported by CPU). 194 // - Let BuildXlaOpsPass successfully infer a compilation device for a cluster 195 // that contains nodes placed on both the CPU and on unknown devices. In this 196 // case it is the responsibility of the optimization pass that injected the 197 // CPU nodes into the cluster to ensure that these nodes can be compiled by 198 // the unknown XLA backend. 199 StatusOr<jit::DeviceId> PickDeviceForXla( 200 const jit::DeviceInfoCache& device_info_cache, 201 const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu); 202 203 // This is like `PickDeviceForXla` except that it returns nullopt (instead of a 204 // non-OK Status) if no unambiguous choice of device exists. 205 // 206 // We return a failing Status for errors unrelated to the device choice 207 // algorithm itself. 208 StatusOr<std::optional<jit::DeviceId>> MaybePickDeviceForXla( 209 const jit::DeviceInfoCache& device_info_cache, 210 const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu); 211 } // namespace tensorflow 212 213 #endif // TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_ 214