# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class to represent a device."""

from tensorflow.python.util.tf_export import tf_export
from tensorflow.python import pywrap_tfe

# EPU stands for TPU embedding for now. Subject to change in future.
_VALID_DEVICE_TYPES = frozenset({"CPU", "GPU", "TPU", "CUSTOM", "EPU"})

# ==============================================================================
# == Global Implementation Details =============================================
# ==============================================================================
# Memoization tables shared by every DeviceSpec instance:
#   spec string (or None) -> (job, replica, task, device_type, device_index)
_STRING_TO_COMPONENTS_CACHE = {}
#   (job, replica, task, device_type, device_index) -> canonical spec string
_COMPONENTS_TO_STRING_CACHE = {}


def _as_str_or_none(inp):
  """Returns `str(inp)`, or None when `inp` is None."""
  return None if inp is None else str(inp)


def _as_int_or_none(inp):
  """Returns `int(inp)`, or None when `inp` is None."""
  return None if inp is None else int(inp)


def _as_device_str_or_none(device_type):
  """Normalizes a device type to a string, or None when it is None."""
  # For backwards compatibility only, we support lowercase variants of
  # cpu and gpu but turn them into uppercase here.
  if device_type in ("cpu", "gpu"):
    return device_type.upper()
  return _as_str_or_none(device_type)


@tf_export("DeviceSpec", v1=[])
class DeviceSpecV2(object):
  """Represents a (possibly partial) specification for a TensorFlow device.

  `DeviceSpec`s are used throughout TensorFlow to describe where state is stored
  and computations occur. Using `DeviceSpec` allows you to parse device spec
  strings to verify their validity, merge them or compose them programmatically.

  Example:

  ```python
  # Place the operations on device "GPU:0" in the "ps" job.
  device_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0)
  with tf.device(device_spec.to_string()):
    # Both my_var and squared_var will be placed on /job:ps/device:GPU:0.
    my_var = tf.Variable(..., name="my_variable")
    squared_var = tf.square(my_var)
  ```

  With eager execution disabled (by default in TensorFlow 1.x and by calling
  disable_eager_execution() in TensorFlow 2.x), the following syntax
  can be used:

  ```python
  tf.compat.v1.disable_eager_execution()

  # Same as previous
  device_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0)
  # No need of .to_string() method.
  with tf.device(device_spec):
    my_var = tf.Variable(..., name="my_variable")
    squared_var = tf.square(my_var)
  ```

  If a `DeviceSpec` is partially specified, it will be merged with other
  `DeviceSpec`s according to the scope in which it is defined. `DeviceSpec`
  components defined in inner scopes take precedence over those defined in
  outer scopes.

  ```python
  gpu0_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0)
  with tf.device(DeviceSpec(job="train").to_string()):
    with tf.device(gpu0_spec.to_string()):
      # Nodes created here will be assigned to /job:ps/device:GPU:0.
    with tf.device(DeviceSpec(device_type="GPU", device_index=1).to_string()):
      # Nodes created here will be assigned to /job:train/device:GPU:1.
  ```

  A `DeviceSpec` consists of 5 components -- each of
  which is optionally specified:

  * Job: The job name.
  * Replica: The replica index.
  * Task: The task index.
  * Device type: The device type string (e.g. "CPU" or "GPU").
  * Device index: The device index.
  """

  __slots__ = ("_job", "_replica", "_task", "_device_type", "_device_index",
               "_as_string", "_hash")

  def __init__(self,
               job=None,
               replica=None,
               task=None,
               device_type=None,
               device_index=None):
    """Create a new `DeviceSpec` object.

    Args:
      job: string. Optional job name.
      replica: int. Optional replica index.
      task: int. Optional task index.
      device_type: Optional device type string (e.g. "CPU" or "GPU"). The
        lowercase spellings "cpu" and "gpu" are accepted and upper-cased.
      device_index: int. Optional device index. If left unspecified, device
        represents 'any' device_index.
    """
    self._job = _as_str_or_none(job)
    self._replica = _as_int_or_none(replica)
    self._task = _as_int_or_none(task)
    self._device_type = _as_device_str_or_none(device_type)
    self._device_index = _as_int_or_none(device_index)
    # The string form and its hash are computed once, up front, from the
    # normalized components.
    self._as_string = self._components_to_string(
        job=self._job,
        replica=self._replica,
        task=self._task,
        device_type=self._device_type,
        device_index=self._device_index)
    self._hash = hash(self.to_string())

  def to_string(self):
    """Return a string representation of this `DeviceSpec`.

    Returns:
      a string of the form
      /job:<name>/replica:<id>/task:<id>/device:<device_type>:<id>.
    """
    return self._as_string

  @classmethod
  def from_string(cls, spec):
    """Construct a `DeviceSpec` from a string.

    Args:
      spec: a string of the form
        /job:<name>/replica:<id>/task:<id>/device:CPU:<id> or
        /job:<name>/replica:<id>/task:<id>/device:GPU:<id> as cpu and gpu are
        mutually exclusive. All entries are optional.

    Returns:
      A DeviceSpec.

    Raises:
      ValueError: if the spec was not valid.
    """
    return cls(*cls._string_to_components(spec))

  def parse_from_string(self, spec):
    """Parse a `DeviceSpec` name into its components.

    **2.x behavior change**:

    In TensorFlow 1.x, this function mutates its own state and returns itself.
    In 2.x, DeviceSpecs are immutable, and this function will return a
      DeviceSpec which contains the spec.

    * Recommended:

      ```
      # my_spec and my_updated_spec are unrelated.
      my_spec = tf.DeviceSpec.from_string("/CPU:0")
      my_updated_spec = tf.DeviceSpec.from_string("/GPU:0")
      with tf.device(my_updated_spec):
        ...
      ```

    * Will work in 1.x and 2.x (though deprecated in 2.x):

      ```
      my_spec = tf.DeviceSpec.from_string("/CPU:0")
      my_updated_spec = my_spec.parse_from_string("/GPU:0")
      with tf.device(my_updated_spec):
        ...
      ```

    * Will NOT work in 2.x:

      ```
      my_spec = tf.DeviceSpec.from_string("/CPU:0")
      my_spec.parse_from_string("/GPU:0")  # <== Will not update my_spec
      with tf.device(my_spec):
        ...
      ```

    In general, `DeviceSpec.from_string` should completely replace
    `DeviceSpec.parse_from_string`, and `DeviceSpec.replace` should
    completely replace setting attributes directly.

    Args:
      spec: an optional string of the form
        /job:<name>/replica:<id>/task:<id>/device:CPU:<id> or
        /job:<name>/replica:<id>/task:<id>/device:GPU:<id> as cpu and gpu are
        mutually exclusive. All entries are optional.

    Returns:
      The `DeviceSpec`.

    Raises:
      ValueError: if the spec was not valid.
    """
    return self.from_string(spec)

  def make_merged_spec(self, dev):
    """Returns a new DeviceSpec which incorporates `dev`.

    When combining specs, `dev` will take precedence over the current spec.
    So for instance:
    ```
    first_spec = tf.DeviceSpec(job=0, device_type="CPU")
    second_spec = tf.DeviceSpec(device_type="GPU")
    combined_spec = first_spec.make_merged_spec(second_spec)
    ```

    is equivalent to:
    ```
    combined_spec = tf.DeviceSpec(job=0, device_type="GPU")
    ```

    Args:
      dev: a `DeviceSpec`

    Returns:
      A new `DeviceSpec` which combines `self` and `dev`
    """
    return self.__class__(*self._get_combined_properties(dev))

  def replace(self, **kwargs):
    """Convenience method for making a new DeviceSpec by overriding fields.

    For instance:
    ```
    my_spec = DeviceSpec(job="my_job", device_type="CPU")
    my_updated_spec = my_spec.replace(device_type="GPU")
    my_other_spec = my_spec.replace(device_type=None)
    ```

    Args:
      **kwargs: This method takes the same args as the DeviceSpec constructor

    Returns:
      A DeviceSpec with the fields specified in kwargs overridden.
    """
    init_kwargs = dict(
        job=self.job,
        replica=self.replica,
        task=self.task,
        device_type=self.device_type,
        device_index=self.device_index)

    # Explicitly provided kwargs take precedence.
    init_kwargs.update(kwargs)
    return self.__class__(**init_kwargs)

  @property
  def job(self):
    """The job name as a string, or None if unspecified."""
    return self._job

  @property
  def replica(self):
    """The replica index as an int, or None if unspecified."""
    return self._replica

  @property
  def task(self):
    """The task index as an int, or None if unspecified."""
    return self._task

  @property
  def device_type(self):
    """The device type string, or None if unspecified."""
    return self._device_type

  @property
  def device_index(self):
    """The device index as an int, or None if unspecified."""
    return self._device_index

  def _get_combined_properties(self, dev):
    """Combine the current DeviceSpec with another DeviceSpec.

    The combination of DeviceSpecs gives priority to `dev`: each component is
    taken from `dev` unless it is None there, in which case the value from
    `self` is used.

    Args:
      dev: a `DeviceSpec`

    Returns:
      A tuple of (job, replica, task, device_type, device_index) which
      represents the combination of self and dev.
    """
    return (
        dev.job if dev.job is not None else self.job,
        dev.replica if dev.replica is not None else self.replica,
        dev.task if dev.task is not None else self.task,
        dev.device_type if dev.device_type is not None else self.device_type,
        dev.device_index if dev.device_index is not None else self.device_index,
    )

  @staticmethod
  def _get_valid_device_types():
    """Returns the set of device type names accepted in bare spec segments.

    The result is the union of the built-in `_VALID_DEVICE_TYPES` and the type
    of every device reported by
    `pywrap_tfe.TF_ListPluggablePhysicalDevices()`.
    """
    # NOTE(review): assumes each reported name is a bytes object shaped like
    # b"<prefix>:<TYPE>:<index>"; only the second colon-separated field is
    # used. Confirm against the pywrap_tfe API.
    pluggable_device_types = {
        device.decode().split(":")[1]
        for device in pywrap_tfe.TF_ListPluggablePhysicalDevices()
    }
    return pluggable_device_types | _VALID_DEVICE_TYPES

  @staticmethod
  def _string_to_components(spec=None):
    """Stateless portion of device spec string parsing.

    The spec is split on "/" into segments and each segment on ":". Recognized
    segments are `job:<name>`, `replica:<id>`, `task:<id>`,
    `device:<type>:<index>` and the bare forms `<type>` / `<type>:<index>`
    where `<type>` (case-insensitively) is a valid device type. An index of
    "*" leaves the device index unspecified. Empty segments are ignored.

    Results are memoized in `_STRING_TO_COMPONENTS_CACHE`, keyed by the
    original `spec` argument.

    Args:
      spec: An optional string specifying a device specification.

    Returns:
      The parsed components of `spec`. Note that the result of this function
      must go through attribute setters of DeviceSpec, and should therefore NOT
      be used directly.

    Raises:
      ValueError: if more than one device type is given, if a segment is not
        recognized, or if a device index is not "*" and not an integer.
    """
    cached_result = _STRING_TO_COMPONENTS_CACHE.get(spec)
    if cached_result is not None:
      return cached_result

    raw_spec = spec  # keep a copy of the original to update the cache
    job, replica, task, device_type, device_index = None, None, None, None, None

    spec = spec or ""
    splits = [x.split(":") for x in spec.split("/")]
    valid_device_types = DeviceSpecV2._get_valid_device_types()
    for y in splits:
      # `str.split` always yields at least one field, so `y` is never empty.
      ly = len(y)
      # NOTE(taylorrobie): these will go through setters later.
      if ly == 2 and y[0] == "job":
        job = y[1]
      elif ly == 2 and y[0] == "replica":
        replica = y[1]
      elif ly == 2 and y[0] == "task":
        task = y[1]
      elif ly in (1, 2) and y[0].upper() in valid_device_types:
        # Bare device form, e.g. "GPU" or "gpu:0".
        if device_type is not None:
          raise ValueError("Multiple device types are not allowed "
                           f"while parsing the device spec: {spec}.")
        device_type = y[0].upper()
        if ly == 2 and y[1] != "*":
          device_index = int(y[1])
      elif ly == 3 and y[0] == "device":
        # Fully qualified device form, e.g. "device:GPU:0".
        if device_type is not None:
          raise ValueError("Multiple device types are not allowed "
                           f"while parsing the device spec: {spec}.")
        device_type = y[1]
        if y[2] != "*":
          device_index = int(y[2])
      elif y[0]:
        # Anything else that is not an empty segment (such as the one produced
        # by a leading "/") is an error.
        raise ValueError(f"Unknown attribute '{y[0]}' is encountered "
                         f"while parsing the device spec: '{spec}'.")

    output = (job, replica, task, device_type, device_index)
    _STRING_TO_COMPONENTS_CACHE[raw_spec] = output
    return output

  @staticmethod
  def _components_to_string(job, replica, task, device_type, device_index):
    """Stateless portion of `to_string` (separated to allow caching).

    Components that are None are omitted. When `device_type` is given but
    `device_index` is None, the index is rendered as "*".

    Args:
      job: job name string, or None.
      replica: replica index, or None.
      task: task index, or None.
      device_type: device type string, or None.
      device_index: device index, or None.

    Returns:
      The spec string; empty when every component is None.
    """
    key = (job, replica, task, device_type, device_index)
    cached_result = _COMPONENTS_TO_STRING_CACHE.get(key)
    if cached_result is not None:
      return cached_result

    output = []
    if job is not None:
      output.append("/job:" + job)
    if replica is not None:
      output.append("/replica:" + str(replica))
    if task is not None:
      output.append("/task:" + str(task))
    if device_type is not None:
      # Unlike the others, device_index is stored as an int.
      device_index_string = "*" if device_index is None else str(device_index)
      output.append(f"/device:{device_type}:{device_index_string}")

    output = "".join(output)
    _COMPONENTS_TO_STRING_CACHE[key] = output
    return output

  def __eq__(self, other):
    """Checks whether `other` describes the same device as this instance.

    Args:
      other: Another DeviceSpec

    Returns:
      `True` if `other` is an instance of this object's class and its
      `to_string()` output equals this instance's. `False` otherwise.
    """
    return (isinstance(other, self.__class__) and
            self.to_string() == other.to_string())

  def __hash__(self):
    return self._hash


@tf_export(v1=["DeviceSpec"])  # pylint: disable=missing-docstring
class DeviceSpecV1(DeviceSpecV2):
  __doc__ = DeviceSpecV2.__doc__
  # All storage is declared by DeviceSpecV2. Re-listing the same names here
  # would create shadowing descriptors and duplicate per-instance storage, so
  # the subclass declares no additional slots.
  __slots__ = ()

  # Each setter normalizes the new value and clears the cached string and hash
  # so that they are recomputed on next use.
  @DeviceSpecV2.job.setter
  def job(self, job):
    self._job = _as_str_or_none(job)
    self._as_string, self._hash = None, None

  @DeviceSpecV2.replica.setter
  def replica(self, replica):
    self._replica = _as_int_or_none(replica)
    self._as_string, self._hash = None, None

  @DeviceSpecV2.task.setter
  def task(self, task):
    self._task = _as_int_or_none(task)
    self._as_string, self._hash = None, None

  @DeviceSpecV2.device_type.setter
  def device_type(self, device_type):
    self._device_type = _as_device_str_or_none(device_type)
    self._as_string, self._hash = None, None

  @DeviceSpecV2.device_index.setter
  def device_index(self, device_index):
    self._device_index = _as_int_or_none(device_index)
    self._as_string, self._hash = None, None

  def __hash__(self):
    # Recomputed lazily because the setters reset the cached hash to None.
    if self._hash is None:
      self._hash = hash(self.to_string())
    return self._hash

  def to_string(self):
    # Recomputed lazily because the setters reset the cached string to None.
    if self._as_string is None:
      self._as_string = self._components_to_string(
          job=self.job,
          replica=self.replica,
          task=self.task,
          device_type=self.device_type,
          device_index=self.device_index)
    return self._as_string

  def parse_from_string(self, spec):
    # Assign through the property setters so values are normalized and the
    # cached string/hash are invalidated.
    (self.job, self.replica, self.task, self.device_type,
     self.device_index) = self._string_to_components(spec)

    return self

  def merge_from(self, dev):
    """Merge the properties of "dev" into this `DeviceSpec`.

    Note: Will be removed in TensorFlow 2.x since DeviceSpecs will become
          immutable.

    Args:
      dev: a `DeviceSpec`.
    """
    (self.job, self.replica, self.task, self.device_type,
     self.device_index) = self._get_combined_properties(dev)

  # Use parent class docstrings for public methods.
  to_string.__doc__ = DeviceSpecV2.to_string.__doc__
  parse_from_string.__doc__ = DeviceSpecV2.parse_from_string.__doc__