xref: /aosp_15_r20/external/tensorflow/tensorflow/python/kernel_tests/critical_section_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""critical section tests."""
16
17import itertools
18
19from absl.testing import parameterized
20
21from tensorflow.python.data.experimental.ops import prefetching_ops
22from tensorflow.python.data.ops import dataset_ops
23from tensorflow.python.eager import context
24from tensorflow.python.eager import def_function
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import test_util
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import control_flow_ops
29from tensorflow.python.ops import control_flow_v2_toggles
30from tensorflow.python.ops import critical_section_ops
31from tensorflow.python.ops import resource_variable_ops
32from tensorflow.python.platform import test
33from tensorflow.python.platform import tf_logging as logging
34# TODO(ebrevdo): Re-enable once CriticalSection is in core.
35# from tensorflow.python.training import saver as saver_lib
36
37
@test_util.with_control_flow_v2
class CriticalSectionTest(test.TestCase, parameterized.TestCase):
  """Tests for `critical_section_ops.CriticalSection`.

  Covers concurrent `execute()` calls, `execute()` combined with cond,
  while_loop and tf.data functions, detection of recursive locking while the
  graph is being built, and resource-sharing checks between distinct
  critical sections.
  """

  @test_util.run_in_graph_and_eager_modes
  def testCreateCriticalSection(self):
    """Concurrent executions each observe a distinct pre-update value."""
    cs = critical_section_ops.CriticalSection(shared_name="cs")
    v = resource_variable_ops.ResourceVariable(0.0, name="v")

    def fn(a, b):
      # Read v, add a * b to it, and return the value read before the update.
      c = v.value()
      with ops.control_dependencies([c]):
        nv = v.assign_add(a * b)
        with ops.control_dependencies([nv]):
          return array_ops.identity(c)

    num_concurrent = 100
    r = [cs.execute(lambda: fn(1.0, 2.0)) for _ in range(num_concurrent)]
    self.evaluate(v.initializer)
    r_value = self.evaluate(r)
    # With the executions serialized, each one adds 2.0 after its read, so the
    # reads are exactly 0, 2, 4, ... in some order.
    self.assertAllClose([2.0 * i for i in range(num_concurrent)],
                        sorted(r_value))

  @parameterized.named_parameters(
      # The test method takes (outer_cond, inner_cond), so the flags are
      # passed in that order; each generated name then matches the values
      # actually used by that case.
      ("Inner%sOuter%s" % (inner, outer), outer, inner)
      for (inner, outer) in itertools.product((False, True), repeat=2))
  @test_util.run_in_graph_and_eager_modes
  @test_util.xla_allow_fallback("b/128495870")
  def testCriticalSectionWithControlFlow(self, outer_cond, inner_cond):
    """Runs execute() inside an outer cond, with a cond inside fn.

    `v` is only incremented when both `outer_cond` and `inner_cond` are True;
    in every other case each result is the initial value 0.
    """
    if (not context.executing_eagerly() and
        control_flow_v2_toggles.control_flow_v2_enabled()):
      self.skipTest("b/135070612")
    cs = critical_section_ops.CriticalSection(shared_name="cs")
    v = resource_variable_ops.ResourceVariable(0.0, name="v")
    num_concurrent = 100

    # pylint: disable=cell-var-from-loop
    def fn(a, b):
      c = v.read_value()
      def true_fn():
        with ops.control_dependencies([c]):
          nv = v.assign_add(a * b)
          with ops.control_dependencies([nv]):
            return array_ops.identity(c)
      return control_flow_ops.cond(
          array_ops.identity(inner_cond), true_fn, lambda: c)

    def execute():
      return cs.execute(lambda: fn(1.0, 2.0))

    r = [
        control_flow_ops.cond(array_ops.identity(outer_cond),
                              execute,
                              v.read_value)
        for _ in range(num_concurrent)
    ]
    # pylint: enable=cell-var-from-loop

    self.evaluate(v.initializer)
    r_value = self.evaluate(r)
    if inner_cond and outer_cond:
      self.assertAllClose([2.0 * i for i in range(num_concurrent)],
                          sorted(r_value))
    else:
      self.assertAllClose([0] * num_concurrent, r_value)

  @test_util.run_v1_only("b/123990562 Sees CancelledError on some calls")
  def testCriticalSectionInParallelDoesntDeadlockOnError(self):
    """A failing Assert in one concurrent execution surfaces as an error."""
    # No eager mode execution of this test because eager does not
    # run fn() in parallel, which is where the deadlock could
    # potentially occur (in graph mode).
    cs = critical_section_ops.CriticalSection(shared_name="cs")
    v = resource_variable_ops.ResourceVariable(0.0, name="v")

    def fn(i):
      # The Assert fails for even i, so one of the executions errors out.
      error = control_flow_ops.Assert((i % 2) == 1, ["Error"])
      with ops.control_dependencies([error]):
        return v.read_value()

    num_concurrent = 2

    @def_function.function(autograph=False)
    def run_concurrently():
      return [cs.execute(lambda: fn(i)) for i in range(num_concurrent)]

    if context.executing_eagerly():
      run_once = run_concurrently
    else:
      # In graph mode, build the ops once and re-run them on every iteration.
      outputs = run_concurrently()

      def run_once():
        return self.evaluate(outputs)

    self.evaluate(v.initializer)
    for _ in range(100):
      with self.assertRaisesOpError("Error"):
        run_once()

  @test_util.run_in_graph_and_eager_modes
  def testCreateCriticalSectionFnReturnsOp(self):
    """execute() accepts a fn returning an op; its side effects still run."""
    cs = critical_section_ops.CriticalSection(shared_name="cs")
    v = resource_variable_ops.ResourceVariable(0.0, name="v")

    def fn_return_op(a, b):
      c = v.read_value()
      with ops.control_dependencies([c]):
        nv = v.assign_add(a * b)
        with ops.control_dependencies([nv]):
          return control_flow_ops.no_op()

    num_concurrent = 100
    r = [cs.execute(lambda: fn_return_op(1.0, 2.0))
         for _ in range(num_concurrent)]
    self.evaluate(v.initializer)
    self.evaluate(r)
    final_v = self.evaluate(v)
    self.assertAllClose(2.0 * num_concurrent, final_v)

  @test_util.run_v1_only("Collections don't exist in TF2")
  def testCollection(self):
    """Critical sections and their executions are added to collections."""
    cs = critical_section_ops.CriticalSection(shared_name="cs")
    self.assertIn(
        cs, ops.get_collection(critical_section_ops.CRITICAL_SECTIONS))
    add = lambda x: x + 1
    execute = cs.execute(lambda: add(1.0), name="my_execute")
    execute_op = [
        x for x in execute.graph.get_operations()
        if "my_execute" in x.name and "MutexLock" in x.type
    ][0]
    self.assertIn(
        execute_op,
        [signature.op for signature in
         ops.get_collection(critical_section_ops.CRITICAL_SECTION_EXECUTIONS)])

  def testRecursiveCriticalSectionAccessIsIllegal(self):
    """A nested execute() on the same CriticalSection raises ValueError."""
    # This does not work properly in eager mode.  Eager users will
    # just hit a deadlock if they do this.  But at least it'll be easier
    # to debug.
    cs = critical_section_ops.CriticalSection()
    add = lambda y: y + 1
    def fn(x):
      return cs.execute(lambda: add(x))

    with self.assertRaisesRegex(
        ValueError, r"Attempting to lock a CriticalSection .* in which we are"):
      cs.execute(lambda: fn(1.0))

  def testRecursiveCriticalSectionAccessViaCapturedTensorIsProtected(self):
    """Capturing an earlier execution's output of the same CS is safe."""
    # This one is subtle; and we're being overly cautious here.  The
    # deadlock we are ensuring we catch is:
    #
    # to_capture = CS[lambda x: x + 1](1.0)
    # deadlocked = CS[lambda x: x + to_capture](1.0)
    #
    # This would have caused a deadlock because executing `deadlocked` will
    # lock the mutex on CS; but then due to dependencies, will attempt
    # to compute `to_capture`.  This computation requires locking CS,
    # but that is not possible now because CS is already locked by
    # `deadlocked`.
    #
    # We check that CriticalSection.execute properly inserts new
    # control dependencies to its lock to ensure all captured
    # operations are finished before anything runs within the critical section.
    cs = critical_section_ops.CriticalSection(shared_name="cs")
    fn = array_ops.identity
    to_capture = cs.execute(lambda: fn(1.0))
    fn_captures = lambda x: x + to_capture
    to_capture_too = array_ops.identity(to_capture)

    ex_0 = cs.execute(lambda: fn_captures(1.0))

    with ops.control_dependencies([to_capture]):
      # This is OK because to_capture will execute before this next call
      ex_1 = cs.execute(lambda: fn_captures(1.0))

    dependency = array_ops.identity(to_capture)

    fn_captures_dependency = lambda x: x + dependency

    ex_2 = cs.execute(lambda: fn_captures_dependency(1.0))

    with ops.control_dependencies([to_capture_too]):
      ex_3 = cs.execute(lambda: fn_captures_dependency(1.0))

    # Ensure none of the executions actually deadlocks when evaluated.
    self.assertEqual(2.0, self.evaluate(ex_0))
    self.assertEqual(2.0, self.evaluate(ex_1))
    self.assertEqual(2.0, self.evaluate(ex_2))
    self.assertEqual(2.0, self.evaluate(ex_3))

  def _runCountingLoop(self, body):
    """Runs `body` in a 1000-iteration while_loop and checks the result.

    Args:
      body: A while_loop body taking and returning the pair `(i, j)`. It must
        increment both `i` and `j` by one per iteration, so both end at 1000.
    """
    (i_n, j_n) = control_flow_ops.while_loop(
        lambda i, _: i < 1000,
        body,
        [0, 0],
        parallel_iterations=25)
    # For consistency between eager and graph mode.
    i_n = array_ops.identity(i_n)
    # Log before and after evaluation so that, if a body deadlocks, the last
    # message identifies which one.
    label = "%s %s" % (self.id(), body.__name__)
    logging.warning(
        "\n==============\nRunning '%s'\n==============\n", label)
    self.assertEqual((1000, 1000), self.evaluate((i_n, j_n)))
    logging.warning(
        "\n==============\nSuccessfully finished running '%s'\n"
        "==============\n", label)

  def testRecursiveCriticalSectionAccessWithinLoopIsProtected(self):
    """execute() inside a parallel while_loop body does not deadlock."""
    cs = critical_section_ops.CriticalSection(shared_name="cs")

    def body_implicit_capture(i, j):
      # This would have caused a deadlock if not for logic in execute
      # that inserts additional control dependencies onto the lock op:
      #   * Loop body argument j is captured by fn()
      #   * i is running in parallel to move forward the execution
      #   * j is not being checked by the predicate function
      #   * output of cs.execute() is returned as next j.
      fn = lambda: j + 1
      return (i + 1, cs.execute(fn))

    def body_implicit_capture_protected(i, j):
      # This version is ok because we manually add a control
      # dependency on j, which is an argument to the while_loop body
      # and captured by fn.
      fn = lambda: j + 1
      with ops.control_dependencies([j]):
        return (i + 1, cs.execute(fn))

    def body_args_capture(i, j):
      # This version is ok because j is an argument to fn and we can
      # ensure there's a control dependency on j.
      fn = lambda x: x + 1
      return (i + 1, cs.execute(lambda: fn(j)))

    self._runCountingLoop(body_implicit_capture)
    self._runCountingLoop(body_implicit_capture_protected)
    self._runCountingLoop(body_args_capture)

  def testRecursiveCriticalSectionAccessIsIllegalSameSharedName(self):
    """Sections sharing a shared_name count as one lock for recursion."""
    # This does not work properly in eager mode.  Eager users will
    # just hit a deadlock if they do this.  But at least it'll be easier
    # to debug.
    cs = critical_section_ops.CriticalSection(shared_name="cs")
    cs_same = critical_section_ops.CriticalSection(shared_name="cs")
    add = lambda x: x + 1
    def fn(x):
      return cs_same.execute(lambda: add(x))

    with self.assertRaisesRegex(
        ValueError, r"Attempting to lock a CriticalSection .* in which we are"):
      cs.execute(lambda: fn(1.0))

  @test_util.run_v1_only(
      "b/123955885 Can't identify consumed resources in eager mode")
  def testMultipleCSExecutionsRequestSameResource(self):
    """Distinct sections may share a resource only if neither is exclusive."""
    cs0 = critical_section_ops.CriticalSection()
    cs1 = critical_section_ops.CriticalSection()
    v = resource_variable_ops.ResourceVariable(0.0, name="v")
    cs0.execute(lambda: v + 1)
    # It's OK for the same CriticalSection to access this resource.
    cs0.execute(lambda: v - 1)
    # It's *not* OK for a different CriticalSection to access it by
    # default.
    with self.assertRaisesRegex(ValueError,
                                "requested exclusive resource access"):
      cs1.execute(lambda: v + 1)
    # It's not even OK if the second call doesn't request exclusive access.
    with self.assertRaisesRegex(ValueError,
                                "requested exclusive resource access"):
      cs1.execute(lambda: v + 1, exclusive_resource_access=False)

    v2 = resource_variable_ops.ResourceVariable(0.0, name="v2")
    cs0.execute(lambda: v2 + 1, exclusive_resource_access=False)
    # It's OK if neither requests exclusive resource access.
    cs1.execute(lambda: v2 + 1, exclusive_resource_access=False)

    # It's not OK if the second request requires exclusive resource
    # access.
    with self.assertRaisesRegex(ValueError,
                                "requested exclusive resource access"):
      cs1.execute(lambda: v2 + 1)

  def testControlDependencyFromOutsideWhileLoopMixedWithInsideLoop(self):
    """execute() in a loop body reading a variable created outside the loop."""
    cs = critical_section_ops.CriticalSection()
    v = resource_variable_ops.ResourceVariable(0, name="v")
    # Make sure that the control dependencies on v do not cause issues
    # in the lock_op's automatic control dependency adder.
    #
    # Note, here v must be a resource variable (or something similar),
    # otherwise it gets hoisted into the while_loop by the time we add
    # control dependencies to the lock_op.
    def body(i):
      add_j = lambda j: v + j + 1
      return cs.execute(lambda: add_j(i))
    out = control_flow_ops.while_loop(
        lambda i: i < 10, body, [0])
    self.evaluate(v.initializer)
    self.assertEqual(10, self.evaluate(out))

  @test_util.run_in_graph_and_eager_modes
  def testInsideFunction(self):
    """execute() can be called from inside a tf.data map function."""
    gpu_available = test_util.is_gpu_available()
    if gpu_available:
      self.skipTest(
          "b/123899495: Colocation errors for critical sections in map on GPU")
    # NOTE: because of the skip above, the GPU branches below are currently
    # unreachable; presumably they are kept for when b/123899495 is fixed.
    cs = critical_section_ops.CriticalSection()
    with ops.device("/gpu:0" if gpu_available else "/cpu:0"):
      v = resource_variable_ops.ResourceVariable(1)
    def fn():
      return v.read_value()

    # map() creates a TensorFlow function.
    ds = dataset_ops.Dataset.range(1)
    if gpu_available:
      ds = (ds.apply(prefetching_ops.copy_to_device("/gpu:0"))
            .apply(prefetching_ops.map_on_gpu(lambda _: cs.execute(fn))))
    else:
      ds = ds.map(lambda _: cs.execute(fn))

    def get_first():
      if context.executing_eagerly():
        return self.evaluate(dataset_ops.make_one_shot_iterator(ds).get_next())
      itr = dataset_ops.make_initializable_iterator(ds)
      self.evaluate([v.initializer, itr.initializer])
      return self.evaluate(itr.get_next())

    self.assertEqual(1, get_first())
397
398
if __name__ == "__main__":
  # Hand off to TensorFlow's test runner for the test cases in this module.
  test.main()
401