# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A variable which packs a list of variables distributed across devices."""

from tensorflow.python.distribute import device_util
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops


class PackedDistributedVariable(resource_variable_ops.BaseResourceVariable):
  """A variable which packs multiple variables distributed across devices.

  It's only supported when eager execution is enabled.
  For op-by-op execution, use an unpacked handle on the current device; for
  function execution, use the packed handle to reduce the overhead of function
  calls.
  """

  def __init__(self, distributed_variables=None, name=None, **unused_kwargs):
    """Packs a list of variables which are distributed across devices.

    Args:
      distributed_variables: A list of distributed Variables to pack.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
    """
    if not ops.executing_eagerly_outside_functions():
      raise ValueError(
          "PackedDistributedVariable should be created in eager mode.")
    if not distributed_variables:
      raise ValueError("Expect a non-empty list of variables to pack.")
    for i, var in enumerate(distributed_variables):
      if not resource_variable_ops.is_resource_variable(var):
        raise ValueError("Expect a list of ResourceVariables to pack, "
                         "but the %d-th variable is %s" % (i, type(var)))

    self._distributed_variables = distributed_variables
    self._devices = [v.device for v in distributed_variables]
    with ops.init_scope():
      with ops.name_scope(name, "Variable", skip_on_eager=False) as name:
        handle = ops.pack_eager_tensors(
            [var.handle for var in distributed_variables])
        handle_name = ops.name_from_scope_name(name)
        unique_id = "%s_%d" % (handle_name, ops.uid())
        super(PackedDistributedVariable, self).__init__(
            trainable=distributed_variables[0].trainable,
            shape=distributed_variables[0].shape,
            dtype=distributed_variables[0].dtype,
            handle=handle,
            synchronization=distributed_variables[0].synchronization,
            constraint=distributed_variables[0].constraint,
            aggregation=distributed_variables[0].aggregation,
            distribute_strategy=distributed_variables[0]._distribute_strategy,  # pylint: disable=protected-access
            name=name,
            unique_id=unique_id,
            handle_name=handle_name,
            graph_element=None,
            initial_value=None,
            initializer_op=None,
            is_initialized_op=None,
            cached_value=None,
            caching_device=None,
            is_distributed_variables=True)

  @property
  def devices(self):
    return self._devices

  def on_device(self, device):
    return PackedVarAndDevice(self, device)

  def get_var_on_device(self, device):
    for i, d in enumerate(self._devices):
      if d == device:
        return self._distributed_variables[i]
    raise ValueError("Device %s is not found" % device)

  def get_var_on_current_device(self):
    current_device = device_util.canonicalize(device_util.current())
    return self.get_var_on_device(current_device)

  def initial_value(self, device):
    """Returns the Tensor used as the initial value for the variable."""
    return self.get_var_on_device(device).initial_value

  @property
  def handle(self):
    if context.executing_eagerly():
      return self.get_var_on_current_device().handle
    else:
      return self._handle

  @property
  def packed_handle(self):
    return self._handle

  def _read_variable_op(self):
    if context.executing_eagerly():
      return self.get_var_on_current_device().value()
    else:
      return super(PackedDistributedVariable, self)._read_variable_op()

  def value(self):
    return self._read_variable_op()

  def is_initialized(self, name=None):
    if context.executing_eagerly():
      result = self._distributed_variables[0].is_initialized()
      for v in self._distributed_variables[1:-1]:
        result = math_ops.logical_and(result, v.is_initialized())
      result = math_ops.logical_and(
          result, self._distributed_variables[-1].is_initialized(), name=name)
    else:
      with ops.device(self._devices[0]):
        result = super(PackedDistributedVariable, self).is_initialized(name)
      for d in self._devices[1:-1]:
        with ops.device(d):
          initialized = super(PackedDistributedVariable,
                              self).is_initialized(name)
        result = math_ops.logical_and(result, initialized)
      with ops.device(self._devices[-1]):
        initialized = super(PackedDistributedVariable,
                            self).is_initialized(name)
        result = math_ops.logical_and(result, initialized, name=name)
    return result

  def _update(self, update_fn, value, **kwargs):
    # Eagerly, apply the update to the variable component on the current
    # device; inside a function, apply it through the packed handle.
    if context.executing_eagerly():
      return update_fn(self.get_var_on_current_device(), value, **kwargs)
    else:
      return update_fn(super(PackedDistributedVariable, self), value, **kwargs)

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
    return self._update(
        update_fn=assign_sub_fn,
        value=delta,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
    return self._update(
        update_fn=assign_add_fn,
        value=delta,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def assign(self, value, use_locking=None, name=None, read_value=True):
    assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
    return self._update(
        update_fn=assign_fn,
        value=value,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    scatter_sub_fn = lambda var, *a, **kw: var.scatter_sub(*a, **kw)
    return self._update(
        update_fn=scatter_sub_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    scatter_add_fn = lambda var, *a, **kw: var.scatter_add(*a, **kw)
    return self._update(
        update_fn=scatter_add_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    scatter_mul_fn = lambda var, *a, **kw: var.scatter_mul(*a, **kw)
    return self._update(
        update_fn=scatter_mul_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    scatter_div_fn = lambda var, *a, **kw: var.scatter_div(*a, **kw)
    return self._update(
        update_fn=scatter_div_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    scatter_min_fn = lambda var, *a, **kw: var.scatter_min(*a, **kw)
    return self._update(
        update_fn=scatter_min_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    scatter_max_fn = lambda var, *a, **kw: var.scatter_max(*a, **kw)
    return self._update(
        update_fn=scatter_max_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
    return self._update(
        update_fn=scatter_update_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    if context.executing_eagerly():
      return self.get_var_on_current_device()._dense_var_to_tensor(  # pylint: disable=protected-access
          dtype=dtype,
          name=name,
          as_ref=as_ref)
    else:
      return super(PackedDistributedVariable, self)._dense_var_to_tensor(  # pylint: disable=protected-access
          dtype=dtype,
          name=name,
          as_ref=as_ref)


class PackedVarAndDevice(object):
  """Holds a packed distributed variable and a device."""

  def __init__(self, var, device):
    self._var = var
    self._device = device

  def __getattr__(self, name):
    # Exceptions raised inside the context manager can cause a reference
    # cycle.[1] The cycle involves the current frame, which holds the reference
    # to the outer frame. TensorFlow (e.g., iterators) relies on object
    # finalizers to clean up resources. Such references prevent the resource
    # from being deleted and can cause leaks and errors. One corner case is
    # that iterators are kept alive and the garbage collector happens to run
    # after auto control dependencies; this causes the deletion to lose the
    # control dependencies to operations that use such resources.
    #
    # Catching and re-raising the exception seems to work around the issue.
    #
    # [1] https://bugs.python.org/issue43533
    try:
      with ops.device(self._device):
        return getattr(self._var, name)
    except:  # pylint: disable=try-except-raise
      raise

  def var(self):
    return self._var

  def value(self):
    with ops.device(self._device):
      return self._var.value()

  def read_value(self):
    with ops.device(self._device):
      return self._var.read_value()

  @property
  def initial_value(self):
    return self._var.initial_value(self._device)

  def initialized_value(self):
    with ops.device(self._device):
      return self._var.initialized_value()

  @property
  def device(self):
    return self._device

  @property
  def handle(self):
    with ops.device(self._device):
      return self._var.handle

  def on_device_handle(self):
    with ops.device(self._device):
      return self._var.get_var_on_current_device().handle

  @property
  def op(self):
    with ops.device(self._device):
      return self._var.op

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    with ops.device(self._device):
      return self._var.assign_sub(delta, use_locking, name, read_value)

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    with ops.device(self._device):
      return self._var.assign_add(delta, use_locking, name, read_value)

  def assign(self, value, use_locking=None, name=None, read_value=True):
    with ops.device(self._device):
      return self._var.assign(value, use_locking, name, read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_sub(sparse_delta, use_locking, name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_add(sparse_delta, use_locking, name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_mul(sparse_delta, use_locking, name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_div(sparse_delta, use_locking, name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_min(sparse_delta, use_locking, name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_max(sparse_delta, use_locking, name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_update(sparse_delta, use_locking, name)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    with ops.device(self._device):
      return self._var._dense_var_to_tensor(  # pylint: disable=protected-access
          dtype=dtype,
          name=name,
          as_ref=as_ref)

  def _as_graph_element(self):
    return self._var._as_graph_element()  # pylint: disable=protected-access


def _tensor_conversion_packed_var_and_device(var,
                                             dtype=None,
                                             name=None,
                                             as_ref=False):
  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access


ops.register_tensor_conversion_function(
    PackedVarAndDevice, _tensor_conversion_packed_var_and_device)
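
# Usage sketch (illustrative only, not exercised by this module): packing two
# per-device variables and reading back one component. The device names and
# the two-logical-CPU setup below are assumptions made for the sketch; in
# practice a `tf.distribute` strategy creates and owns these variables.
#
#   # Assumes two logical CPU devices have been configured beforehand, e.g.
#   # via tf.config.set_logical_device_configuration.
#   with ops.device("/cpu:0"):
#     v0 = resource_variable_ops.ResourceVariable(1.0)
#   with ops.device("/cpu:1"):
#     v1 = resource_variable_ops.ResourceVariable(2.0)
#   packed = PackedDistributedVariable([v0, v1])
#
#   # Eagerly, reads and updates resolve to the component on the current
#   # device; a PackedVarAndDevice pins them to a specific device instead.
#   with ops.device("/cpu:1"):
#     packed.assign_add(3.0)  # updates the component on /cpu:1
#   on_dev1 = packed.on_device(device_util.canonicalize("/cpu:1"))
#   value_on_dev1 = on_dev1.read_value()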