xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/critical_section_ops.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Critical Section object and execution logic."""
16
17import collections
18import contextlib
19import threading
20
21from tensorflow.python.eager import context
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import control_flow_ops
26from tensorflow.python.ops import gen_resource_variable_ops
27from tensorflow.python.ops import tensor_array_ops
28from tensorflow.python.util import nest
29from tensorflow.python.util import object_identity
30from tensorflow.python.util.tf_export import tf_export
31
32
__all__ = ["CriticalSection"]


# Graph Keys
# Collection to which every `CriticalSection` adds itself when it is
# constructed outside of eager mode.
CRITICAL_SECTIONS = "critical_sections"
# Collection to which `CriticalSection.execute` adds one
# `_ExecutionSignature` per call made outside of eager mode.
CRITICAL_SECTION_EXECUTIONS = "critical_section_executions"
39
40
41class _ExecutionSignature(
42    collections.namedtuple("_ExecutionSignature",
43                           ("op", "handle",
44                            "resources", "exclusive_resource_access"))):
45  """A class storing an `ExecuteInCriticalResource` op and associated attrs."""
46  pass
47
48
def _identity(x):
  """Return an identity of `x`, dispatching on what kind of object it is.

  * A `TensorArray` is passed through its own `identity()` method.
  * An `Operation` is wrapped with `control_flow_ops.group`.
  * `None` is returned unchanged, but only when executing eagerly.
  * Anything else goes through `array_ops.identity`.
  """
  if isinstance(x, tensor_array_ops.TensorArray):
    return x.identity()
  if isinstance(x, ops.Operation):
    return control_flow_ops.group(x)
  # `None` is only passed through unchanged in eager mode; in graph mode it
  # is handed to `array_ops.identity` like any other value.
  passthrough_none = context.executing_eagerly() and x is None
  return None if passthrough_none else array_ops.identity(x)
59
60
61def _get_device_or_colocation(op):
62  return op.device or _get_colocation(op)
63
64
65def _get_colocation(op):
66  """Get colocation symbol from op, if any."""
67  try:
68    return op.get_attr("_class")
69  except (ValueError, AttributeError):
70    return None
71
72
73_CRITICAL_SECTION_STACK = threading.local()
74
75
76def _get_critical_section_stack():
77  try:
78    return _CRITICAL_SECTION_STACK.value
79  except AttributeError:
80    _CRITICAL_SECTION_STACK.value = []
81    return _CRITICAL_SECTION_STACK.value
82
83
@contextlib.contextmanager
def _push_critical_section_stack(signature):
  """Context manager tracking `signature` on the thread-local stack.

  The signature is pushed on entry and popped on exit.  Entering while the
  same signature is already on the stack means we would be executing inside
  a CriticalSection that is already locked, which would deadlock, so that is
  rejected up front.

  Args:
    signature: Tuple of the type `CriticalSection._signature`.  Uniquely
      identifies a CriticalSection by its `shared_name`, `container`,
      and device.

  Yields:
    Nothing.  The body of the `with` block runs while the signature is on
    the stack.

  Raises:
    ValueError: If the signature is already on the stack.
    RuntimeError: If the entry popped on exit is not `signature`, i.e. the
      stack was modified by someone else during the yield.
  """
  active_signatures = _get_critical_section_stack()
  if signature in active_signatures:
    raise ValueError(
        f"Attempting to lock a CriticalSection (signature={signature}) in which"
        " we are already running. This is illegal and may cause deadlocks.")

  active_signatures.append(signature)
  try:
    yield
  finally:
    # Always pop, even if the body raised, then verify we removed our own
    # entry rather than one left behind by somebody else.
    popped = active_signatures.pop()
    if popped != signature:
      raise RuntimeError(
          "CriticalSection stack inconsistency: expected signature "
          f"{signature} but received {popped}")
119
120
@tf_export("CriticalSection")
class CriticalSection:
  """Critical section.

  A `CriticalSection` object is a resource in the graph which executes subgraphs
  in **serial** order.  A common example of a subgraph one may wish to run
  exclusively is the one given by the following function:

  ```python
  v = resource_variable_ops.ResourceVariable(0.0, name="v")

  def count():
    value = v.read_value()
    with tf.control_dependencies([value]):
      with tf.control_dependencies([v.assign_add(1)]):
        return tf.identity(value)
  ```

  Here, a snapshot of `v` is captured in `value`; and then `v` is updated.
  The snapshot value is returned.

  If multiple workers or threads all execute `count` in parallel, there is no
  guarantee that access to the variable `v` is atomic at any point within
  any thread's calculation of `count`.  In fact, even implementing an atomic
  counter that guarantees that the user will see each value `0, 1, ...,` is
  currently impossible.

  The solution is to ensure any access to the underlying resource `v` is
  only processed through a critical section:

  ```python
  cs = CriticalSection()
  f1 = cs.execute(count)
  f2 = cs.execute(count)
  output = f1 + f2
  session.run(output)
  ```
  The functions `f1` and `f2` will be executed serially, and updates to `v`
  will be atomic.

  **NOTES**

  All resource objects, including the critical section and any captured
  variables of functions executed on that critical section, will be
  colocated to the same device (host and cpu/gpu).

  When using multiple critical sections on the same resources, there is no
  guarantee of exclusive access to those resources.  This behavior is disallowed
  by default (but see the kwarg `exclusive_resource_access`).

  For example, running the same function in two separate critical sections
  will not ensure serial execution:

  ```python
  v = tf.compat.v1.get_variable("v", initializer=0.0, use_resource=True)
  def accumulate(up):
    x = v.read_value()
    with tf.control_dependencies([x]):
      with tf.control_dependencies([v.assign_add(up)]):
        return tf.identity(x)
  ex1 = CriticalSection().execute(
    lambda: accumulate(1.0), exclusive_resource_access=False)
  ex2 = CriticalSection().execute(
    lambda: accumulate(1.0), exclusive_resource_access=False)
  bad_sum = ex1 + ex2
  sess.run(v.initializer)
  sess.run(bad_sum)  # May return 0.0
  ```
  """

  def __init__(self, name=None, shared_name=None,
               critical_section_def=None, import_scope=None):
    """Creates a critical section.

    Args:
      name: Optional name scope / op name for the underlying mutex op.
        Defaults to "CriticalSection".
      shared_name: Optional shared name passed to the mutex op.  When `None`,
        the (scoped) `name` is used instead.
      critical_section_def: Not supported.  Passing any truthy value raises
        `ValueError`.
      import_scope: Not used by this constructor.

    Raises:
      ValueError: If `critical_section_def` is given (on its own or together
        with `name`).
    """
    context.ensure_initialized()
    if critical_section_def and name is not None:
      # The check is on `name`, so that is the argument reported here.
      raise ValueError(f"Arguments critical_section_def={critical_section_def} "
                       f"and name={name} are mutually exclusive. "
                       "Please only specify one of them.")
    if critical_section_def:
      raise ValueError("Argument `critical_section_def` is not supported.")
    self._init_from_args(name, shared_name)

  def _init_from_args(self, name, shared_name):  # pylint: disable=invalid-name
    """Initialize the CriticalSection from constructor arguments.

    Creates the mutex handle and computes `self._signature`; outside of eager
    mode the instance is also added to the `CRITICAL_SECTIONS` collection.

    Args:
      name: Name passed to `ops.name_scope` (default "CriticalSection"); the
        resulting scoped name is used as the mutex op name.
      shared_name: Shared name for the mutex op, or `None` to reuse the
        scoped `name`.
    """
    with ops.name_scope(name, "CriticalSection", []) as name:
      with ops.init_scope():
        # pylint: disable=protected-access
        container = ops.get_default_graph()._container
        # pylint: enable=protected-access
        if shared_name is None:
          shared_name = name
        if container is None:
          container = ""
        self._handle = gen_resource_variable_ops.mutex_v2(
            shared_name=shared_name, container=container, name=name)
        # Get a uniquely identifying signature for the handle.
        self._signature = (
            container,
            # If shared_name is empty, a unique CriticalSection is created.
            shared_name or id(self._handle),
            _get_device_or_colocation(self._handle))

    if not context.executing_eagerly():
      ops.add_to_collections(CRITICAL_SECTIONS, self)

  @property
  def name(self):
    """Name of the op that produces this critical section's handle."""
    return self._handle.op.name

  def execute(self, fn, exclusive_resource_access=True, name=None):
    """Execute function `fn()` inside the critical section.

    `fn` should not accept any arguments.  To pass extra arguments when
    calling `fn` in the critical section, create a lambda:

    ```python
    critical_section.execute(lambda: fn(*my_args, **my_kwargs))
    ```

    Args:
      fn: The function to execute.  Must return at least one tensor.
      exclusive_resource_access: Whether the resources required by
        `fn` should be exclusive to this `CriticalSection`.  Default: `True`.
        You may want to set this to `False` if you will be accessing a
        resource in read-only mode in two different CriticalSections.
      name: The name to use when creating the execute operation.

    Returns:
      The tensors returned from `fn()`.

    Raises:
      ValueError: If `fn` attempts to lock this `CriticalSection` in any nested
        or lazy way that may cause a deadlock.
      ValueError: If `exclusive_resource_access == True` and
        another `CriticalSection` has an execution requesting the same
        resources as `fn`.  Note, even if `exclusive_resource_access` is
        `True`, if another execution in another `CriticalSection` was created
        without `exclusive_resource_access=True`, a `ValueError` will be raised.
    """
    with ops.name_scope(name, "critical_section_execute", []):
      # Ensure that mutex locking only happens *after* all args and
      # kwargs have been executed.  This avoids certain types of deadlocks.
      with _push_critical_section_stack(self._signature):
        lock = gen_resource_variable_ops.mutex_lock(self._handle)

        if not context.executing_eagerly():
          # NOTE(ebrevdo): This is to ensure we don't pick up spurious
          # Operations created by other threads.
          with ops.get_default_graph()._lock:  # pylint: disable=protected-access
            existing_ops = ops.get_default_graph().get_operations()
            with ops.control_dependencies([lock]):
              r = fn()
            # TODO(ebrevdo): If creating critical sections in a python loop,
            # this makes graph creation time quadratic.  Revisit if this
            # becomes a problem.
            created_ops = (set(ops.get_default_graph().get_operations())
                           .difference(existing_ops))
        else:
          with ops.control_dependencies([lock]):
            r = fn()

      if not context.executing_eagerly():
        self._add_control_dependencies_to_lock(created_ops, lock.op)

        # captured_resources is a list of resources that are directly
        # accessed only by ops created during fn(), not by any
        # ancestors of those ops in the graph.
        captured_resources = object_identity.ObjectIdentitySet([
            input_ for op in created_ops
            for input_ in op.inputs
            if input_.dtype == dtypes.resource
        ])

        # NOTE(ebrevdo): The only time self._is_self_handle() is True
        # in this call is if one of the recently created ops, within
        # the execute(), themselves attempt to access the
        # CriticalSection.  This will cause a deadlock.
        if any(self._is_self_handle(x) for x in captured_resources):
          raise ValueError(
              "Attempting to lock a CriticalSection in which we are "
              f"already running (signature={self._signature}). This is illegal "
              "and may cause deadlocks.")

        self._check_multiple_access_to_resources(
            captured_resources, exclusive_resource_access)

      r_flat = [_identity(x) for x in nest.flatten(r)]

      with ops.control_dependencies(r_flat):
        # The identity must run on the same machine as self._handle
        with ops.colocate_with(self._handle):
          # Do not use array_ops.identity as there are special
          # optimizations within TensorFlow which seem to elide it
          # even when optimizations are disabled(!).
          ensure_lock_exists = gen_resource_variable_ops.consume_mutex_lock(
              lock)

        # Make sure that if any element of r is accessed, all of
        # them are executed together.
        r = nest.pack_sequence_as(r, control_flow_ops.tuple(nest.flatten(r)))

      with ops.control_dependencies([ensure_lock_exists]):
        outputs = nest.map_structure(_identity, r)

      if not context.executing_eagerly():
        # Record this execution so that later executions (possibly in other
        # critical sections) can be checked for conflicting resource access.
        signature = _ExecutionSignature(
            op=lock.op,
            handle=self._handle,
            resources=list(captured_resources),
            exclusive_resource_access=exclusive_resource_access)
        ops.add_to_collections(
            CRITICAL_SECTION_EXECUTIONS, signature)

      return outputs

  def _add_control_dependencies_to_lock(self, created_ops, lock_op):
    """To avoid deadlocks, all args must be executed before lock_op.

    Args:
      created_ops: The ops that were created while running `fn()`.
      lock_op: The mutex lock op; it receives a control input on a group of
        every external argument of `created_ops`, if there are any.
    """
    # Get all arguments (explicit and captured) of all ops created by fn().
    all_args = {input_.op for op in created_ops for input_ in op.inputs}
    all_args.update(
        input_op for op in created_ops for input_op in op.control_inputs)
    # Unfortunately, we can't use sets throughout because TF seems to
    # create new Operation objects for the same op sometimes; and we
    # can't rely on id(op).

    # pylint: disable=protected-access
    all_args_dict = {op._id: op for op in all_args}

    # Remove ops created within fn, or that lock_op already has a
    # control dependency on.  Also remove a possible self-loop.
    for op in created_ops:
      all_args_dict.pop(op._id, None)
    for op in lock_op.control_inputs:
      all_args_dict.pop(op._id, None)
    for input_ in lock_op.inputs:
      all_args_dict.pop(input_.op._id, None)
    all_args_dict.pop(lock_op._id, None)

    external_args = all_args_dict.values()

    if not external_args:
      # No control dependencies to add; return early.
      return

    # This group is important: it ensures that any ops in external_args
    # outside the control context of the lock_op (and this fn, which
    # runs in the same context) are added to this context before
    # being added to the control dependencies of lock_op.
    grouped_args = control_flow_ops.group(*external_args)

    lock_op._add_control_input(grouped_args)
    # pylint: enable=protected-access

  def _is_self_handle(self, x):
    """Check if the tensor `x` is the same Mutex as `self._handle`.

    Eager tensors are compared by identity.  Otherwise `x` matches when it is
    produced by a `MutexV2` op with the same non-blank `shared_name` as this
    critical section's handle and the same device or colocation.

    Args:
      x: The tensor to compare against `self._handle`.

    Returns:
      A truthy value if `x` refers to the same mutex, a falsy value otherwise.
    """
    if isinstance(x, ops.EagerTensor):
      return x is self._handle
    return (x.op.type == "MutexV2"
            # blank shared_name means the op will create a unique one.
            and x.op.get_attr("shared_name")
            and (x.op.get_attr("shared_name") ==
                 self._handle.op.get_attr("shared_name"))
            and (x.op.device == self._handle.op.device
                 or _get_colocation(x.op) == _get_colocation(self._handle.op)))

  def _check_multiple_access_to_resources(
      self, captured_resources, exclusive_resource_access):
    """Raise if captured_resources are accessed by another CriticalSection.

    Args:
      captured_resources: Set of tensors of type resource.
      exclusive_resource_access: Whether this execution requires exclusive
        resource access.

    Raises:
      ValueError: If any tensors in `captured_resources` are also accessed
        by another `CriticalSection`, and at least one of them requires
        exclusive resource access.
    """
    # Collections and op introspection does not work in eager
    # mode.  This is generally ok; since eager mode (as of
    # writing) executes sequentially anyway.
    for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
      if self._is_self_handle(sg.handle):
        # Other executions in the same critical section are allowed.
        continue
      if not (exclusive_resource_access or sg.exclusive_resource_access):
        # Neither execution requested exclusive access.
        continue
      resource_intersection = captured_resources.intersection(sg.resources)
      if resource_intersection:
        raise ValueError(
            "This execution would access resources: "
            f"{list(resource_intersection)}. Either this lock "
            f"(CriticalSection: {self._handle}) or lock '{sg}' "
            f"(CriticalSection: {sg.handle}) requested exclusive resource "
            "access of this resource. Did you mean to call execute with "
            "keyword argument exclusive_resource_access=False?")
420