1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Critical Section object and execution logic.""" 16 17import collections 18import contextlib 19import threading 20 21from tensorflow.python.eager import context 22from tensorflow.python.framework import dtypes 23from tensorflow.python.framework import ops 24from tensorflow.python.ops import array_ops 25from tensorflow.python.ops import control_flow_ops 26from tensorflow.python.ops import gen_resource_variable_ops 27from tensorflow.python.ops import tensor_array_ops 28from tensorflow.python.util import nest 29from tensorflow.python.util import object_identity 30from tensorflow.python.util.tf_export import tf_export 31 32 33__all__ = ["CriticalSection"] 34 35 36# Graph Keys 37CRITICAL_SECTIONS = "critical_sections" 38CRITICAL_SECTION_EXECUTIONS = "critical_section_executions" 39 40 41class _ExecutionSignature( 42 collections.namedtuple("_ExecutionSignature", 43 ("op", "handle", 44 "resources", "exclusive_resource_access"))): 45 """A class storing an `ExecuteInCriticalResource` op and associated attrs.""" 46 pass 47 48 49def _identity(x): 50 """Identity op that recognizes `TensorArray`, `Operation`, and `Tensor`.""" 51 if isinstance(x, tensor_array_ops.TensorArray): 52 return x.identity() 53 elif isinstance(x, ops.Operation): 54 return control_flow_ops.group(x) 55 elif context.executing_eagerly() and x is None: 56 return None 57 else: 58 return array_ops.identity(x) 59 60 61def _get_device_or_colocation(op): 62 return op.device or _get_colocation(op) 63 64 65def _get_colocation(op): 66 """Get colocation symbol from op, if any.""" 67 try: 68 return op.get_attr("_class") 69 except (ValueError, AttributeError): 70 return None 71 72 73_CRITICAL_SECTION_STACK = threading.local() 74 75 76def _get_critical_section_stack(): 77 try: 78 return _CRITICAL_SECTION_STACK.value 79 except AttributeError: 80 _CRITICAL_SECTION_STACK.value = [] 81 return _CRITICAL_SECTION_STACK.value 82 83 84@contextlib.contextmanager 85def _push_critical_section_stack(signature): 86 """Push a CriticalSection._signature to the thread-local stack. 87 88 If the signature is already on the stack, raise an error because it means 89 we're trying to execute inside the same locked CriticalSection, which 90 will create a deadlock. 91 92 Args: 93 signature: Tuple of the type `CriticalSection._signature`. Uniquely 94 identifies a CriticalSection by its `shared_name`, `container`, 95 and device. 96 97 Yields: 98 An empty value. The context is guaranteed to run without deadlock. 99 100 Raises: 101 ValueError: If the signature is already on the stack. 102 RuntimeError: If another thread or function modifies the current stack 103 entry during the yield. 104 """ 105 stack = _get_critical_section_stack() 106 if signature in stack: 107 raise ValueError( 108 f"Attempting to lock a CriticalSection (signature={signature}) in which" 109 " we are already running. This is illegal and may cause deadlocks.") 110 stack.append(signature) 111 try: 112 yield 113 finally: 114 received_signature = stack.pop() 115 if received_signature != signature: 116 raise RuntimeError( 117 "CriticalSection stack inconsistency: expected signature " 118 f"{signature} but received {received_signature}") 119 120 121@tf_export("CriticalSection") 122class CriticalSection: 123 """Critical section. 124 125 A `CriticalSection` object is a resource in the graph which executes subgraphs 126 in **serial** order. A common example of a subgraph one may wish to run 127 exclusively is the one given by the following function: 128 129 ```python 130 v = resource_variable_ops.ResourceVariable(0.0, name="v") 131 132 def count(): 133 value = v.read_value() 134 with tf.control_dependencies([value]): 135 with tf.control_dependencies([v.assign_add(1)]): 136 return tf.identity(value) 137 ``` 138 139 Here, a snapshot of `v` is captured in `value`; and then `v` is updated. 140 The snapshot value is returned. 141 142 If multiple workers or threads all execute `count` in parallel, there is no 143 guarantee that access to the variable `v` is atomic at any point within 144 any thread's calculation of `count`. In fact, even implementing an atomic 145 counter that guarantees that the user will see each value `0, 1, ...,` is 146 currently impossible. 147 148 The solution is to ensure any access to the underlying resource `v` is 149 only processed through a critical section: 150 151 ```python 152 cs = CriticalSection() 153 f1 = cs.execute(count) 154 f2 = cs.execute(count) 155 output = f1 + f2 156 session.run(output) 157 ``` 158 The functions `f1` and `f2` will be executed serially, and updates to `v` 159 will be atomic. 160 161 **NOTES** 162 163 All resource objects, including the critical section and any captured 164 variables of functions executed on that critical section, will be 165 colocated to the same device (host and cpu/gpu). 166 167 When using multiple critical sections on the same resources, there is no 168 guarantee of exclusive access to those resources. This behavior is disallowed 169 by default (but see the kwarg `exclusive_resource_access`). 170 171 For example, running the same function in two separate critical sections 172 will not ensure serial execution: 173 174 ```python 175 v = tf.compat.v1.get_variable("v", initializer=0.0, use_resource=True) 176 def accumulate(up): 177 x = v.read_value() 178 with tf.control_dependencies([x]): 179 with tf.control_dependencies([v.assign_add(up)]): 180 return tf.identity(x) 181 ex1 = CriticalSection().execute( 182 accumulate, 1.0, exclusive_resource_access=False) 183 ex2 = CriticalSection().execute( 184 accumulate, 1.0, exclusive_resource_access=False) 185 bad_sum = ex1 + ex2 186 sess.run(v.initializer) 187 sess.run(bad_sum) # May return 0.0 188 ``` 189 """ 190 191 def __init__(self, name=None, shared_name=None, 192 critical_section_def=None, import_scope=None): 193 """Creates a critical section.""" 194 context.ensure_initialized() 195 if critical_section_def and name is not None: 196 raise ValueError(f"Arguments critical_section_def={critical_section_def} " 197 f"and shared_name={shared_name} are mutually exclusive. " 198 "Please only specify one of them.") 199 if critical_section_def: 200 raise ValueError("Argument `critical_section_def` is not supported.") 201 else: 202 self._init_from_args(name, shared_name) 203 204 def _init_from_args(self, name, shared_name): # pylint: disable=invalid-name 205 """Initialize the CriticalSection from constructor arguments.""" 206 with ops.name_scope(name, "CriticalSection", []) as name: 207 with ops.init_scope(): 208 # pylint: disable=protected-access 209 container = ops.get_default_graph()._container 210 # pylint: enable=protected-access 211 if shared_name is None: 212 shared_name = name 213 if container is None: 214 container = "" 215 self._handle = gen_resource_variable_ops.mutex_v2( 216 shared_name=shared_name, container=container, name=name) 217 # Get a uniquely identifying signature for the handle. 218 self._signature = ( 219 container, 220 # If shared_name is empty, a unique CriticalSection is created. 221 shared_name or id(self._handle), 222 _get_device_or_colocation(self._handle)) 223 224 if not context.executing_eagerly(): 225 ops.add_to_collections(CRITICAL_SECTIONS, self) 226 227 @property 228 def name(self): 229 return self._handle.op.name 230 231 def execute(self, fn, exclusive_resource_access=True, name=None): 232 """Execute function `fn()` inside the critical section. 233 234 `fn` should not accept any arguments. To add extra arguments to when 235 calling `fn` in the critical section, create a lambda: 236 237 ```python 238 critical_section.execute(lambda: fn(*my_args, **my_kwargs)) 239 ``` 240 241 Args: 242 fn: The function to execute. Must return at least one tensor. 243 exclusive_resource_access: Whether the resources required by 244 `fn` should be exclusive to this `CriticalSection`. Default: `True`. 245 You may want to set this to `False` if you will be accessing a 246 resource in read-only mode in two different CriticalSections. 247 name: The name to use when creating the execute operation. 248 249 Returns: 250 The tensors returned from `fn()`. 251 252 Raises: 253 ValueError: If `fn` attempts to lock this `CriticalSection` in any nested 254 or lazy way that may cause a deadlock. 255 ValueError: If `exclusive_resource_access == True` and 256 another `CriticalSection` has an execution requesting the same 257 resources as `fn``. Note, even if `exclusive_resource_access` is 258 `True`, if another execution in another `CriticalSection` was created 259 without `exclusive_resource_access=True`, a `ValueError` will be raised. 260 """ 261 with ops.name_scope(name, "critical_section_execute", []): 262 # Ensure that mutex locking only happens *after* all args and 263 # kwargs have been executed. This avoids certain types of deadlocks. 264 with _push_critical_section_stack(self._signature): 265 lock = gen_resource_variable_ops.mutex_lock(self._handle) 266 267 if not context.executing_eagerly(): 268 # NOTE(ebrevdo): This is to ensure we don't pick up spurious 269 # Operations created by other threads. 270 with ops.get_default_graph()._lock: # pylint: disable=protected-access 271 existing_ops = ops.get_default_graph().get_operations() 272 with ops.control_dependencies([lock]): 273 r = fn() 274 # TODO(ebrevdo): If creating critical sections in a python loop, 275 # this makes graph creation time quadratic. Revisit if this 276 # becomes a problem. 277 created_ops = (set(ops.get_default_graph().get_operations()) 278 .difference(existing_ops)) 279 else: 280 with ops.control_dependencies([lock]): 281 r = fn() 282 283 if not context.executing_eagerly(): 284 self._add_control_dependencies_to_lock(created_ops, lock.op) 285 286 # captured_resources is a list of resources that are directly 287 # accessed only by ops created during fn(), not by any 288 # ancestors of those ops in the graph. 289 captured_resources = object_identity.ObjectIdentitySet([ 290 input_ for op in created_ops 291 for input_ in op.inputs 292 if input_.dtype == dtypes.resource 293 ]) 294 295 # NOTE(ebrevdo): The only time self._is_self_handle() is True 296 # in this call is if one of the recently created ops, within 297 # the execute(), themselves attempt to access the 298 # CriticalSection. This will cause a deadlock. 299 if any(self._is_self_handle(x) for x in captured_resources): 300 raise ValueError( 301 "Attempting to lock a CriticalSection in which we are " 302 f"already running (signature={self._signature}). This is illegal " 303 "and may cause deadlocks.") 304 305 self._check_multiple_access_to_resources( 306 captured_resources, exclusive_resource_access) 307 308 r_flat = [_identity(x) for x in nest.flatten(r)] 309 310 with ops.control_dependencies(r_flat): 311 # The identity must run on the same machine as self._handle 312 with ops.colocate_with(self._handle): 313 # Do not use array_ops.identity as there are special 314 # optimizations within TensorFlow which seem to elide it 315 # even when optimizations are disabled(!). 316 ensure_lock_exists = gen_resource_variable_ops.consume_mutex_lock( 317 lock) 318 319 # Make sure that if any element of r is accessed, all of 320 # them are executed together. 321 r = nest.pack_sequence_as(r, control_flow_ops.tuple(nest.flatten(r))) 322 323 with ops.control_dependencies([ensure_lock_exists]): 324 outputs = nest.map_structure(_identity, r) 325 326 if not context.executing_eagerly(): 327 signature = _ExecutionSignature( 328 op=lock.op, 329 handle=self._handle, 330 resources=list(captured_resources), 331 exclusive_resource_access=exclusive_resource_access) 332 ops.add_to_collections( 333 CRITICAL_SECTION_EXECUTIONS, signature) 334 335 return outputs 336 337 def _add_control_dependencies_to_lock(self, created_ops, lock_op): 338 """To avoid deadlocks, all args must be executed before lock_op.""" 339 # Get all arguments (explicit and captured) of all ops created by fn(). 340 all_args = set([input_.op for op in created_ops for input_ in op.inputs]) 341 all_args.update( 342 input_op for op in created_ops for input_op in op.control_inputs) 343 # Unfortunately, we can't use sets throughout because TF seems to 344 # create new Operation objects for the same op sometimes; and we 345 # can't rely on id(op). 346 347 # pylint: disable=protected-access 348 all_args_dict = dict((op._id, op) for op in all_args) 349 350 # Remove ops created within fn, or that lock_op already has a 351 # control dependency on. Also remove a possible self-loop. 352 for op in created_ops: 353 all_args_dict.pop(op._id, None) 354 for op in lock_op.control_inputs: 355 all_args_dict.pop(op._id, None) 356 for input_ in lock_op.inputs: 357 all_args_dict.pop(input_.op._id, None) 358 all_args_dict.pop(lock_op._id, None) 359 360 all_args = all_args_dict.values() 361 362 if not all_args: 363 # No control dependencies to add; return early. 364 return 365 366 # This group is important: it ensures that any ops in all_args 367 # outside the control context of the lock_op (and this fn, which 368 # runs in the same context) are added to this context before 369 # being added to the control dependencies of lock_op. 370 all_args = control_flow_ops.group(*all_args) 371 372 lock_op._add_control_input(all_args) 373 # pylint: enable=protected-access 374 375 def _is_self_handle(self, x): 376 """Check if the tensor `x` is the same Mutex as `self._handle`.""" 377 if isinstance(x, ops.EagerTensor): 378 return x is self._handle 379 return (x.op.type == "MutexV2" 380 # blank shared_name means the op will create a unique one. 381 and x.op.get_attr("shared_name") 382 and (x.op.get_attr("shared_name") == 383 self._handle.op.get_attr("shared_name")) 384 and (x.op.device == self._handle.op.device 385 or _get_colocation(x.op) == _get_colocation(self._handle.op))) 386 387 def _check_multiple_access_to_resources( 388 self, captured_resources, exclusive_resource_access): 389 """Raise if captured_resources are accessed by another CriticalSection. 390 391 Args: 392 captured_resources: Set of tensors of type resource. 393 exclusive_resource_access: Whether this execution requires exclusive 394 resource access. 395 396 Raises: 397 ValueError: If any tensors in `captured_resources` are also accessed 398 by another `CriticalSection`, and at least one of them requires 399 exclusive resource access. 400 """ 401 # Collections and op introspection does not work in eager 402 # mode. This is generally ok; since eager mode (as of 403 # writing) executes sequentially anyway. 404 for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS): 405 if self._is_self_handle(sg.handle): 406 # Other executions in the same critical section are allowed. 407 continue 408 if not (exclusive_resource_access or sg.exclusive_resource_access): 409 # Neither execution requested exclusive access. 410 continue 411 resource_intersection = captured_resources.intersection(sg.resources) 412 if resource_intersection: 413 raise ValueError( 414 "This execution would access resources: " 415 f"{list(resource_intersection)}. Either this lock " 416 f"(CriticalSection: {self._handle}) or lock '{sg}' " 417 f"(CriticalSection: {sg.handle}) requested exclusive resource " 418 "access of this resource. Did you mean to call execute with " 419 "keyword argument exclusive_resource_access=False?") 420