1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Propagates information about tensor layouts across operations.""" 16 17import contextlib 18import logging 19import os 20import threading 21from typing import Any, List, Sequence, Set 22 23import numpy as np 24 25from tensorflow.core.framework import attr_value_pb2 26from tensorflow.dtensor.python import gen_dtensor_ops 27from tensorflow.dtensor.python import layout as layout_lib 28from tensorflow.python import _pywrap_dtensor_device 29from tensorflow.python.eager import context 30from tensorflow.python.eager import core 31from tensorflow.python.framework import device as tf_device 32from tensorflow.python.framework import dtypes 33from tensorflow.python.framework import ops 34from tensorflow.python.framework import sparse_tensor 35from tensorflow.python.ops import resource_variable_ops 36 37_DT_CLIENT_ID = "DTENSOR_CLIENT_ID" 38_DT_NUM_CLIENTS = "DTENSOR_NUM_CLIENTS" 39_DT_JOB_NAME = "DTENSOR_JOB_NAME" 40 41# TODO(allenl): Allow something other than "CUSTOM" so we don't need device 42# numbering hacks to avoid collisions between parallel devices and dtensor 43# devices. 
# Module-level counter handing out unique "CUSTOM" device numbers so that
# concurrently created DTensorDevices (and parallel devices) don't collide.
# Guarded by a lock because device construction may happen on any thread.
_next_device_number = 0
_next_device_number_lock = threading.Lock()


class DTensorDevice(object):
  """Wraps a custom device which attempts to propagate tensor layouts."""

  def __init__(self, meshes: List[layout_lib.Mesh], is_async=True):
    """Create a new DTensorDevice which executes ops on `underlying_device`.

    Args:
      meshes: A list of `Mesh` objects indicating groups of devices to execute
        on. These may also be registered lazily.
      is_async: Indicates whether DTensor operations on this client will return
        immediately (with "non-ready" handles) or block until executed. This is
        on by default and is exposed as an option for ease of debugging.

    Raises:
      TypeError: If `meshes` contains anything that is not a `Mesh`.
    """
    if any(not isinstance(mesh, layout_lib.Mesh) for mesh in meshes):
      raise TypeError(
          "Expected a flat list of Mesh objects, got {}".format(meshes))
    global _next_device_number
    ctx = context.context()
    # Reserve a unique custom-device number under the lock, then allocate and
    # register the custom device with the eager context.
    with _next_device_number_lock:
      self.name = "{}/device:CUSTOM:{}".format(ctx.host_address_space(),
                                               _next_device_number)
      _next_device_number += 1
    device, device_info = _pywrap_dtensor_device.Allocate(self.name)
    context.register_custom_device(device, self.name, device_info)

    self._device_info = device_info
    # Encoded layout (bytes) currently forced on op outputs, or None.
    self._current_output_layout = None
    # Mesh used for ops with no inferred mesh, or None.
    self._current_default_mesh = None
    self._is_async = is_async
    self._meshes = set()
    self._mesh_lock = threading.Lock()
    for mesh in meshes:
      self._register_mesh(mesh)

# FIXME(b/241819185): Reuse the logic in api.py.
# LINT.IfChange
  def _num_clients(self):
    """Returns number of clients in current DTensor cluster."""
    # If missing, 1 is a good default.
    return int(os.environ.get(_DT_NUM_CLIENTS, "1"))
# LINT.ThenChange(//tensorflow/dtensor/cc/dtensor_utils.cc)

# LINT.IfChange
  def _client_id(self):
    """Returns current client ID (int) in current DTensor cluster."""
    return int(os.environ.get(_DT_CLIENT_ID, "0"))
# LINT.ThenChange(//tensorflow/dtensor/cc/dtensor_utils.cc)

  def _job_name(self):
    """Returns the DTensor Borg job name."""
    # If missing, the program is likely running locally or on Forge.
    return os.environ.get(_DT_JOB_NAME,
                          "localhost" if self._num_clients() == 1 else "worker")

  def _full_job_name(self):
    """Returns the fully qualified TF job name for this task."""
    return self._job_name() + "/replica:0/task:" + str(self._client_id())

  def _create_host_array(self, shape, host_id):
    """Returns ID and device lists that can be used to create a host mesh.

    Args:
      shape: The global device-id array shape for the mesh.
      host_id: This host's index; used to offset the local device ids.

    Returns:
      A (global_device_ids, local_device_ids, local_device_list) tuple.
    """
    num_global_devices = np.prod(shape)
    global_device_ids = np.arange(num_global_devices).reshape(shape)
    # One CPU:0 device per host.
    local_device_list = [
        tf_device.DeviceSpec(
            job=self._full_job_name(), device_type="CPU", device_index=0)
    ]
    num_local_devices = len(local_device_list)
    local_device_ids = [
        x + host_id * num_local_devices for x in range(num_local_devices)
    ]
    return global_device_ids, local_device_ids, local_device_list

  def _create_embedding_host_mesh(self, tpu_mesh: layout_lib.Mesh):
    """Returns Embedding host mesh for each client.

    Args:
      tpu_mesh: The TPU mesh to derive a CPU host mesh from.

    Returns:
      A 1D CPU `Mesh` with one entry per client, or None when `tpu_mesh` has
      no local `TPU:0` device.

    Raises:
      ValueError: If `tpu_mesh` is not a TPU mesh.
    """
    if tpu_mesh.device_type().upper() != "TPU":
      raise ValueError("Must pass input of a tpu mesh.")

    # Global device ids are global host ids, while local device ids contains
    # local host id.

    ts_local_device_ids = []
    ts_local_devices = []
    for local_device_str in tpu_mesh.local_devices():
      # We only need to keep TPU:0 for each client.
      if not local_device_str.endswith("TPU:0"):
        continue

      device_spec = tf_device.DeviceSpec.from_string(local_device_str)
      ts_local_device_ids.append(device_spec.task)
      ts_local_devices.append(device_spec.replace(device_type="CPU"))

    # Bug fix: the second clause previously re-tested ts_local_device_ids,
    # so an empty ts_local_devices was never detected.
    if not ts_local_device_ids or not ts_local_devices:
      logging.info(
          "Cannot create tpu system mesh as %s has no `TPU:0` local device "
          "found", tpu_mesh.to_string())
      return None

    ts_global_device_ids = np.arange(self._num_clients())
    # TODO(zhonglinhan): parse global device specs as input when not None.
    return layout_lib.Mesh(
        dim_names=[tpu_mesh.dim_names[0]],  # 1D mesh.
        global_device_ids=ts_global_device_ids,
        local_device_ids=ts_local_device_ids,
        local_devices=ts_local_devices)

  def _register_mesh(self, mesh: layout_lib.Mesh):
    """Idempotently register `mesh` with the dtensor device."""
    with self._mesh_lock:
      if mesh not in self._meshes:
        _pywrap_dtensor_device.AddMesh(self._device_info, mesh.to_string(),
                                       self._is_async, False)
        self._meshes.add(mesh)
        if mesh.device_type().upper() == "TPU":
          # For a TPU mesh also register its 1:1 host mesh and, when possible,
          # an embedding host mesh.
          logging.info(
              "Registering virtual 1:1 mapped host mesh %s for mesh %s",
              mesh.host_mesh().to_string(), mesh.to_string())
          _pywrap_dtensor_device.AddMesh(self._device_info,
                                         mesh.host_mesh().to_string(),
                                         self._is_async, True)
          self._meshes.add(mesh.host_mesh())
          embedding_host_mesh = self._create_embedding_host_mesh(mesh)
          if embedding_host_mesh:
            logging.info(
                "Registering embedding host mesh %s on each client for mesh %s",
                embedding_host_mesh.to_string(), mesh.to_string())
            _pywrap_dtensor_device.AddMesh(self._device_info,
                                           embedding_host_mesh.to_string(),
                                           self._is_async, False)
            self._meshes.add(embedding_host_mesh)

  @property
  def meshes(self) -> Set[layout_lib.Mesh]:
    # All meshes registered so far (including derived host meshes).
    return self._meshes

  def copy_to_mesh(self, tensor, new_layout, source_layout=None) -> ops.Tensor:
    """Copy `tensor` to `device` with the given layout."""
    self._register_mesh(new_layout.mesh)
    with ops.device(self.name):
      return gen_dtensor_ops.copy_to_mesh(
          tensor,
          layout=new_layout.to_string(),
          source_layout=source_layout.to_string() if source_layout else "")

  def pack(self, tensors: Sequence[Any], layout: layout_lib.Layout) -> Any:
    """Packs tensors into a DTensor handle on this DTensor device.

    Packing and unpacking are inverse operations:

    ```
    * unpack(pack(tensors)) == tensors
    * pack(unpack(dtensor)) == dtensor
    ```

    Refer to `dtensor.pack` for more information.

    Args:
      tensors: The list of tensors to pack into a DTensor.
      layout: The layout of the DTensor to be created.

    Returns:
      A DTensor created from the individual component tensors.

    Raises:
      RuntimeError: When not called eagerly.
      TypeError: When a Variable is passed, or when SparseTensors and dense
        Tensors are mixed, or when input SparseTensors differ in shape.
    """
    if not context.executing_eagerly():
      raise RuntimeError("Pack must be called eagerly.")
    if any(
        issubclass(type(t), resource_variable_ops.BaseResourceVariable)
        for t in tensors):
      raise TypeError(
          "Received Variable input to Pack, Variable is not supported.")
    self._register_mesh(layout.mesh)
    with ops.device(self.name):
      if all(isinstance(t, sparse_tensor.SparseTensor) for t in tensors):
        if not all(t.shape == tensors[0].shape for t in tensors):
          raise TypeError("All input SparseTensors to Pack must be same shape.")
        is_sparse = True
        # Flatten each SparseTensor into indices + values + dense_shapes so the
        # C++ Pack can reassemble them.
        tensors = [t.indices for t in tensors] + [t.values for t in tensors] + [
            ops.convert_to_tensor(t.shape, dtype=dtypes.int64) for t in tensors
        ]
      elif any(isinstance(t, sparse_tensor.SparseTensor) for t in tensors):
        raise TypeError("Cannot Pack SparseTensors with Tensors.")
      else:
        is_sparse = False
      try:
        return _pywrap_dtensor_device.Pack(
            context.context()._handle,  # pylint: disable=protected-access
            tensors,
            layout.to_string(),
            self._device_info,
            is_sparse)
      except core._NotOkStatusException as e:  # pylint: disable=protected-access
        raise core._status_to_exception(e) from None  # pylint: disable=protected-access

  def unpack(self, dtensor: Any) -> Sequence[Any]:
    """Unpacks a DTensor handle on this DTensor device.

    Packing and unpacking are inverse operations:

    ```
    * unpack(pack(tensors)) == tensors
    * pack(unpack(dtensor)) == dtensor
    ```

    Refer to `dtensor.unpack` for more information.

    Args:
      dtensor: The DTensor to unpack.

    Returns:
      The raw underlying tensor components of the DTensor.

    Raises:
      RuntimeError: When not called eagerly.
      TypeError: When a Variable is passed.
    """
    if not context.executing_eagerly():
      raise RuntimeError("Unpack must be called eagerly.")
    if issubclass(type(dtensor), resource_variable_ops.BaseResourceVariable):
      raise TypeError(
          "Received Variable input to unpack, Variable is not supported.")
    try:
      tensors = _pywrap_dtensor_device.Unpack(
          context.context()._handle,  # pylint: disable=protected-access
          dtensor,
          self._device_info)
    except core._NotOkStatusException as e:  # pylint: disable=protected-access
      raise core._status_to_exception(e) from None  # pylint: disable=protected-access

    is_sparse = _pywrap_dtensor_device.IsSparseDTensor(
        context.context()._handle,  # pylint: disable=protected-access
        dtensor,
        self._device_info)
    if is_sparse:
      # Sparse components come back as [indices..., values..., dense_shapes...]
      # in thirds; rebuild one SparseTensor per triple.
      third = len(tensors) // 3
      result = []
      for i in range(third):
        result.append(
            sparse_tensor.SparseTensor(tensors[i],
                                       tensors[i + third],
                                       tensors[i + 2 * third]))
      return result
    else:
      return tensors

  def fetch_layout(self, dtensor: Any) -> layout_lib.Layout:
    """Fetches the layout of the DTensor.

    Args:
      dtensor: The DTensor whose layout is to be fetched.

    Returns:
      The `Layout` of this DTensor.

    Raises:
      RuntimeError: When not called eagerly.
    """
    if not context.executing_eagerly():
      raise RuntimeError("FetchLayout must be called eagerly.")
    if issubclass(type(dtensor), resource_variable_ops.BaseResourceVariable):
      # Variables are read so the layout is fetched from the value tensor.
      dtensor = dtensor.read_value()
    try:
      layout_string = _pywrap_dtensor_device.FetchLayout(
          context.context()._handle,  # pylint: disable=protected-access
          dtensor,
          self._device_info)
    except core._NotOkStatusException as e:  # pylint: disable=protected-access
      raise core._status_to_exception(e) from None  # pylint: disable=protected-access
    return layout_lib.Layout.from_string(layout_string)

  def set_same_shape_policy(self, enabled):
    """Guess layouts using the layouts of other tensors with the same shape.

    This is the default behavior, and is quite safe. The `default_layout` scope
    overrides shape-based guesses.

    Args:
      enabled: A boolean indicating whether to use the policy.
    """
    _pywrap_dtensor_device.SetSameShapePolicy(self._device_info, enabled)

  def set_tpu_core_ids(self, mesh_name, tpu_core_ids):
    """Sets the singleton global device ID-to-physical core ID map.

    Args:
      mesh_name: The name of a mesh. If empty, set the default mapping.
      tpu_core_ids: TPU core IDs sorted by TF task/device ordinal.
    """
    _pywrap_dtensor_device.SetTPUCoreIDs(self._device_info, mesh_name,
                                         tpu_core_ids)

  def clear_tpu_core_ids(self):
    """Clears the TPU core ID map set via `set_tpu_core_ids`."""
    _pywrap_dtensor_device.ClearTPUCoreIDs(self._device_info)

  def tpu_core_ids_to_locations(self, tpu_core_ids):
    """Translates TPU core IDs to TPU core locations.

    Args:
      tpu_core_ids: A list of TPU core IDs. Each one is an unsigned integer.

    Returns:
      A list of corresponding TPU core locations.
    """
    return _pywrap_dtensor_device.TPUCoreIDsToLocations(
        context.context()._handle,  # pylint: disable=protected-access
        self._device_info,
        tpu_core_ids)

  def tpu_core_locations_to_ids(self, tpu_core_locations):
    """Translates TPU core locations to TPU core IDs.

    Args:
      tpu_core_locations: A list of TPU core locations. Each one is a list of
        four unsigned integers, [x, y, z, core].

    Returns:
      A list of corresponding TPU core IDs.
    """
    return _pywrap_dtensor_device.TPUCoreLocationsToIDs(
        context.context()._handle,  # pylint: disable=protected-access
        self._device_info,
        tpu_core_locations)

  def _get_function_cache_hit_and_miss_count(self):
    """Returns the number of cache hit and miss for function compilation.

    Returns:
      A dictionary keyed with miss and hit, corresponding to the cache hit and
      miss count.
    """
    return _pywrap_dtensor_device.GetFunctionCacheHitAndMissCount(
        context.context()._handle,  # pylint: disable=protected-access
        self._device_info)

  @contextlib.contextmanager
  def _experimental_default_mesh(self, mesh: layout_lib.Mesh):
    """Sets a default mesh for all ops in the scope.

    Note: This is an internal helper method, which is not user facing api.

    Useful for requesting a specific mesh for ops which would have no inferred
    layout, e.g. tf.zeros.

    Args:
      mesh: A Mesh to be used for ops without Mesh.

    Yields:
      Nothing.
    """
    previous_default = self._current_default_mesh
    self._register_mesh(mesh)
    _pywrap_dtensor_device.ExperimentalSetDefaultMesh(
        self._device_info,
        mesh.to_string().encode("utf-8"))
    self._current_default_mesh = mesh
    try:
      yield
    finally:
      # Restore the previous default mesh even if the scope body raised,
      # mirroring the try/finally structure of `_default_layout`.
      _pywrap_dtensor_device.ExperimentalClearDefaultMesh(self._device_info)
      if previous_default:
        _pywrap_dtensor_device.ExperimentalSetDefaultMesh(
            self._device_info,
            previous_default.to_string().encode("utf-8"))
      self._current_default_mesh = previous_default

  @contextlib.contextmanager
  def _default_layout(self, layout: layout_lib.Layout):
    """Sets a default output layout for all ops in the scope.

    Note: This is an internal helper method, which is not user facing api.

    Useful for requesting a specific layout for ops which would have no inferred
    layout, e.g. tf.zeros.

    Caveats:

    - Currently only affects the first output of an op. For Op with multiple
      outputs, this does not support yet.

    - All Ops in the scope will be attached with the same layout. This might not
      be valid as the rank is different. The current suggestion is: Try to wrap
      the raw op wheneven possible.

    Args:
      layout: A Layout for the outputs of all operations in this scope.

    Yields:
      Nothing.
    """
    previous_default = None
    previous_graph_size = None
    graph = None

    self._register_mesh(layout.mesh)
    try:
      previous_default = self._current_output_layout
      self._current_output_layout = layout.to_string().encode("utf-8")
      _pywrap_dtensor_device.ExperimentalSetDefaultLayout(
          self._device_info, self._current_output_layout)
      if context.executing_eagerly():
        with ops.device(self.name):
          yield
      else:
        # Custom devices currently don't affect graph building, so we need a
        # separate way to indicate layouts.
        #
        # TODO(allenl): Remove this case once the DTensor device is active
        # during tracing.
        graph = ops.get_default_graph()
        previous_graph_size = len(graph.get_operations())
        yield
    finally:
      if graph is not None:
        # Tag operations added under this scope
        for operation in graph.get_operations()[previous_graph_size:]:
          # Set layout directly on the Op itself.
          operation._set_attr(  # pylint: disable=protected-access
              "_layout",
              attr_value_pb2.AttrValue(
                  list=attr_value_pb2.AttrValue.ListValue(
                      s=[self._current_output_layout])))
          operation._set_attr(  # pylint: disable=protected-access
              "_mesh",
              attr_value_pb2.AttrValue(
                  s=layout.mesh.to_string().encode("utf-8")))

      self._current_output_layout = previous_default
      if self._current_output_layout is None:
        _pywrap_dtensor_device.ExperimentalClearDefaultLayout(self._device_info)
      else:
        _pywrap_dtensor_device.ExperimentalSetDefaultLayout(
            self._device_info, self._current_output_layout.decode("utf-8"))