1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""critical section tests.""" 16 17import itertools 18 19from absl.testing import parameterized 20 21from tensorflow.python.data.experimental.ops import prefetching_ops 22from tensorflow.python.data.ops import dataset_ops 23from tensorflow.python.eager import context 24from tensorflow.python.eager import def_function 25from tensorflow.python.framework import ops 26from tensorflow.python.framework import test_util 27from tensorflow.python.ops import array_ops 28from tensorflow.python.ops import control_flow_ops 29from tensorflow.python.ops import control_flow_v2_toggles 30from tensorflow.python.ops import critical_section_ops 31from tensorflow.python.ops import resource_variable_ops 32from tensorflow.python.platform import test 33from tensorflow.python.platform import tf_logging as logging 34# TODO(ebrevdo): Re-enable once CriticalSection is in core. 
# from tensorflow.python.training import saver as saver_lib


@test_util.with_control_flow_v2
class CriticalSectionTest(test.TestCase, parameterized.TestCase):
  """Tests for `critical_section_ops.CriticalSection`."""

  def _runCountingLoop(self, body):
    """Runs `body` in a 1000-step while_loop and checks both counters.

    Args:
      body: Loop body taking `(i, j)` and returning the next `(i, j)`; its
        `__name__` is used to label the log output.
    """
    (i_n, j_n) = control_flow_ops.while_loop(
        lambda i, _: i < 1000,
        body,
        [0, 0],
        parallel_iterations=25)
    # For consistency between eager and graph mode.
    i_n = array_ops.identity(i_n)
    logging.warn(
        "\n==============\nRunning "
        "'testRecursiveCriticalSectionAccessWithinLoopIsProtected "
        "%s'\n"
        "==============\n", body.__name__)
    self.assertEqual((1000, 1000), self.evaluate((i_n, j_n)))
    logging.warn(
        "\n==============\nSuccessfully finished running "
        "'testRecursiveCriticalSectionAccessWithinLoopIsProtected "
        "%s'\n"
        "==============\n", body.__name__)

  @test_util.run_in_graph_and_eager_modes
  def testCreateCriticalSection(self):
    """100 read-then-add executions must observe 0, 2, ..., 198."""
    cs = critical_section_ops.CriticalSection(shared_name="cs")
    v = resource_variable_ops.ResourceVariable(0.0, name="v")

    def fn(a, b):
      # Read first, then add; return the value seen before the add.
      c = v.value()
      with ops.control_dependencies([c]):
        nv = v.assign_add(a * b)
        with ops.control_dependencies([nv]):
          return array_ops.identity(c)

    num_concurrent = 100
    r = [cs.execute(lambda: fn(1.0, 2.0)) for _ in range(num_concurrent)]
    self.evaluate(v.initializer)
    r_value = self.evaluate(r)
    self.assertAllClose([2.0 * i for i in range(num_concurrent)],
                        sorted(r_value))

  # The tuple order after the name must follow the method signature
  # (outer_cond, inner_cond) so that each case label describes the
  # configuration that actually runs.
  @parameterized.named_parameters(
      ("Inner%sOuter%s" % (inner, outer), outer, inner)
      for (inner, outer) in itertools.product(*([(False, True)] * 2)))
  @test_util.run_in_graph_and_eager_modes
  @test_util.xla_allow_fallback("b/128495870")
  def testCriticalSectionWithControlFlow(self, outer_cond, inner_cond):
    """Executes a critical section inside and around `cond` ops.

    Args:
      outer_cond: Whether the cond wrapping `cs.execute` takes the execute
        branch (otherwise it just reads `v`).
      inner_cond: Whether the cond inside the critical section takes the
        branch that adds to `v`.
    """
    if (not context.executing_eagerly() and
        control_flow_v2_toggles.control_flow_v2_enabled()):
      self.skipTest("b/135070612")
    cs = critical_section_ops.CriticalSection(shared_name="cs")
    v = resource_variable_ops.ResourceVariable(0.0, name="v")
    num_concurrent = 100

    # pylint: disable=cell-var-from-loop
    def fn(a, b):
      c = v.read_value()
      def true_fn():
        with ops.control_dependencies([c]):
          nv = v.assign_add(a * b)
          with ops.control_dependencies([nv]):
            return array_ops.identity(c)
      return control_flow_ops.cond(
          array_ops.identity(inner_cond), true_fn, lambda: c)

    def execute():
      return cs.execute(lambda: fn(1.0, 2.0))

    r = [
        control_flow_ops.cond(array_ops.identity(outer_cond),
                              execute,
                              v.read_value)
        for _ in range(num_concurrent)
    ]
    # pylint: enable=cell-var-from-loop

    self.evaluate(v.initializer)
    r_value = self.evaluate(r)
    if inner_cond and outer_cond:
      self.assertAllClose([2.0 * i for i in range(num_concurrent)],
                          sorted(r_value))
    else:
      # No branch that adds to v was taken, so every read sees the initial 0.
      self.assertAllClose([0] * num_concurrent, r_value)

  @test_util.run_v1_only("b/123990562 Sees CancelledError on some calls")
  def testCriticalSectionInParallelDoesntDeadlockOnError(self):
    """A failing Assert inside one execution must surface as an op error."""
    # No eager mode execution of this test because eager does not
    # run fn() in parallel, which is where the deadlock could
    # potentially occur (in graph mode).
    cs = critical_section_ops.CriticalSection(shared_name="cs")
    v = resource_variable_ops.ResourceVariable(0.0, name="v")

    def fn(i):
      # Fails for even i, so the i == 0 execution always raises "Error".
      error = control_flow_ops.Assert((i % 2) == 1, ["Error"])
      with ops.control_dependencies([error]):
        return v.read_value()

    num_concurrent = 2

    @def_function.function(autograph=False)
    def run_concurrently():
      # cs.execute invokes the lambda while building, so `i` is the current
      # loop value rather than a late-bound one.
      return [cs.execute(lambda: fn(i)) for i in range(num_concurrent)]

    if not context.executing_eagerly():
      run_concurrently = run_concurrently()

    self.evaluate(v.initializer)
    for _ in range(100):
      with self.assertRaisesOpError("Error"):
        if context.executing_eagerly():
          run_concurrently()
        else:
          self.evaluate(run_concurrently)

  @test_util.run_in_graph_and_eager_modes
  def testCreateCriticalSectionFnReturnsOp(self):
    """Executed fn may return an op; 100 adds of 2.0 must all land in v."""
    cs = critical_section_ops.CriticalSection(shared_name="cs")
    v = resource_variable_ops.ResourceVariable(0.0, name="v")

    def fn_return_op(a, b):
      c = v.read_value()
      with ops.control_dependencies([c]):
        nv = v.assign_add(a * b)
        with ops.control_dependencies([nv]):
          return control_flow_ops.no_op()

    num_concurrent = 100
    r = [cs.execute(lambda: fn_return_op(1.0, 2.0))
         for _ in range(num_concurrent)]
    self.evaluate(v.initializer)
    self.evaluate(r)
    final_v = self.evaluate(v)
    self.assertAllClose(2.0 * num_concurrent, final_v)

  @test_util.run_v1_only("Collections don't exist in TF2")
  def testCollection(self):
    """The section and its execution are recorded in graph collections."""
    cs = critical_section_ops.CriticalSection(shared_name="cs")
    self.assertIn(
        cs, ops.get_collection(critical_section_ops.CRITICAL_SECTIONS))
    add = lambda x: x + 1
    execute = cs.execute(lambda: add(1.0), name="my_execute")
    # Locate the lock op created for the named execution.
    execute_op = [
        x for x in execute.graph.get_operations()
        if "my_execute" in x.name and "MutexLock" in x.type
    ][0]
    self.assertIn(
        execute_op,
        [signature.op for signature in
         ops.get_collection(critical_section_ops.CRITICAL_SECTION_EXECUTIONS)])

  def testRecursiveCriticalSectionAccessIsIllegal(self):
    """Calling cs.execute from inside cs.execute raises ValueError."""
    # This does not work properly in eager mode.  Eager users will
    # just hit a deadlock if they do this.  But at least it'll be easier
    # to debug.
    cs = critical_section_ops.CriticalSection()
    add = lambda y: y + 1
    def fn(x):
      return cs.execute(lambda: add(x))

    with self.assertRaisesRegex(
        ValueError, r"Attempting to lock a CriticalSection .* in which we are"):
      cs.execute(lambda: fn(1.0))

  def testRecursiveCriticalSectionAccessViaCapturedTensorIsProtected(self):
    """Capturing another execution's output must not deadlock."""
    # This one is subtle; and we're being overly cautious here.  The
    # deadlock we are ensuring we catch is:
    #
    # to_capture = CS[lambda x: x + 1](1.0)
    # deadlocked = CS[lambda x: x + to_capture](1.0)
    #
    # This would have caused a deadlock because executing `deadlocked` will
    # lock the mutex on CS; but then due to dependencies, will attempt
    # to compute `to_capture`.  This computation requires locking CS,
    # but that is not possible now because CS is already locked by
    # `deadlocked`.
    #
    # We check that CriticalSection.execute properly inserts new
    # control dependencies to its lock to ensure all captured
    # operations are finished before anything runs within the critical section.
    cs = critical_section_ops.CriticalSection(shared_name="cs")
    fn = array_ops.identity
    to_capture = cs.execute(lambda: fn(1.0))
    fn_captures = lambda x: x + to_capture
    to_capture_too = array_ops.identity(to_capture)

    ex_0 = cs.execute(lambda: fn_captures(1.0))

    with ops.control_dependencies([to_capture]):
      # This is OK because to_capture will execute before this next call
      ex_1 = cs.execute(lambda: fn_captures(1.0))

    dependency = array_ops.identity(to_capture)

    fn_captures_dependency = lambda x: x + dependency

    ex_2 = cs.execute(lambda: fn_captures_dependency(1.0))

    with ops.control_dependencies([to_capture_too]):
      ex_3 = cs.execute(lambda: fn_captures_dependency(1.0))

    # Ensure there's no actual deadlock on to_execute.
    self.assertEqual(2.0, self.evaluate(ex_0))
    self.assertEqual(2.0, self.evaluate(ex_1))
    self.assertEqual(2.0, self.evaluate(ex_2))
    self.assertEqual(2.0, self.evaluate(ex_3))

  def testRecursiveCriticalSectionAccessWithinLoopIsProtected(self):
    """Feeding cs.execute output back as a loop variable must not deadlock."""
    cs = critical_section_ops.CriticalSection(shared_name="cs")

    def body_implicit_capture(i, j):
      # This would have caused a deadlock if not for logic in execute
      # that inserts additional control dependencies onto the lock op:
      #   * Loop body argument j is captured by fn()
      #   * i is running in parallel to move forward the execution
      #   * j is not being checked by the predicate function
      #   * output of cs.execute() is returned as next j.
      fn = lambda: j + 1
      return (i + 1, cs.execute(fn))

    self._runCountingLoop(body_implicit_capture)

    def body_implicit_capture_protected(i, j):
      # This version is ok because we manually add a control
      # dependency on j, which is an argument to the while_loop body
      # and captured by fn.
      fn = lambda: j + 1
      with ops.control_dependencies([j]):
        return (i + 1, cs.execute(fn))

    self._runCountingLoop(body_implicit_capture_protected)

    def body_args_capture(i, j):
      # This version is ok because j is an argument to fn and we can
      # ensure there's a control dependency on j.
      fn = lambda x: x + 1
      return (i + 1, cs.execute(lambda: fn(j)))

    self._runCountingLoop(body_args_capture)

  def testRecursiveCriticalSectionAccessIsIllegalSameSharedName(self):
    """Nesting two sections with the same shared_name raises ValueError."""
    # This does not work properly in eager mode.  Eager users will
    # just hit a deadlock if they do this.  But at least it'll be easier
    # to debug.
    cs = critical_section_ops.CriticalSection(shared_name="cs")
    cs_same = critical_section_ops.CriticalSection(shared_name="cs")
    add = lambda x: x + 1
    def fn(x):
      return cs_same.execute(lambda: add(x))

    with self.assertRaisesRegex(
        ValueError, r"Attempting to lock a CriticalSection .* in which we are"):
      cs.execute(lambda: fn(1.0))

  @test_util.run_v1_only(
      "b/123955885 Can't identify consumed resources in eager mode")
  def testMultipleCSExecutionsRequestSameResource(self):
    """Checks exclusive_resource_access conflicts between two sections."""
    cs0 = critical_section_ops.CriticalSection()
    cs1 = critical_section_ops.CriticalSection()
    v = resource_variable_ops.ResourceVariable(0.0, name="v")
    cs0.execute(lambda: v + 1)
    # It's OK for the same CriticalSection to access this resource.
    cs0.execute(lambda: v - 1)
    # It's *not* OK for a different CriticalSection to access it by
    # default.
    with self.assertRaisesRegex(ValueError,
                                "requested exclusive resource access"):
      cs1.execute(lambda: v + 1)
    # It's not even OK if the second call doesn't request exclusive access.
    with self.assertRaisesRegex(ValueError,
                                "requested exclusive resource access"):
      cs1.execute(lambda: v + 1, exclusive_resource_access=False)

    v2 = resource_variable_ops.ResourceVariable(0.0, name="v2")
    cs0.execute(lambda: v2 + 1, exclusive_resource_access=False)
    # It's OK if neither requests exclusive resource access.
    cs1.execute(lambda: v2 + 1, exclusive_resource_access=False)

    # It's not OK if the second request requires exclusive resource
    # access.
    with self.assertRaisesRegex(ValueError,
                                "requested exclusive resource access"):
      cs1.execute(lambda: v2 + 1)

  def testControlDependencyFromOutsideWhileLoopMixedWithInsideLoop(self):
    """cs.execute in a loop body may read a variable created outside it."""
    cs = critical_section_ops.CriticalSection()
    v = resource_variable_ops.ResourceVariable(0, name="v")
    # Make sure that the control dependencies on v do not cause issues
    # in the lock_op's automatic control dependency adder.
    #
    # Note, here v must be a resource variable (or something similar),
    # otherwise it gets hoisted into the while_loop by the time we add
    # control dependencies to the lock_op.
    def body(i):
      add_j = lambda j: v + j + 1
      return cs.execute(lambda: add_j(i))
    out = control_flow_ops.while_loop(
        lambda i: i < 10, body, [0])
    self.evaluate(v.initializer)
    self.assertEqual(10, self.evaluate(out))

  @test_util.run_in_graph_and_eager_modes
  def testInsideFunction(self):
    """cs.execute works inside the function created by Dataset.map."""
    # Query once; the answer is reused for the skip, the variable placement
    # and the choice of map implementation below.
    gpu_available = test_util.is_gpu_available()
    if gpu_available:
      # The GPU branches below are unreachable while this skip is in place.
      self.skipTest(
          "b/123899495: Colocation errors for critical sections in map on GPU")
    cs = critical_section_ops.CriticalSection()
    with ops.device("/gpu:0" if gpu_available else "/cpu:0"):
      v = resource_variable_ops.ResourceVariable(1)

    def fn():
      return v.read_value()

    # map() creates a TensorFlow function.
    ds = dataset_ops.Dataset.range(1)
    if gpu_available:
      ds = (ds.apply(prefetching_ops.copy_to_device("/gpu:0"))
            .apply(prefetching_ops.map_on_gpu(lambda _: cs.execute(fn))))
    else:
      ds = ds.map(lambda _: cs.execute(fn))

    def get_first():
      if context.executing_eagerly():
        return self.evaluate(dataset_ops.make_one_shot_iterator(ds).get_next())
      itr = dataset_ops.make_initializable_iterator(ds)
      self.evaluate([v.initializer, itr.initializer])
      return self.evaluate(itr.get_next())

    self.assertEqual(1, get_first())


if __name__ == "__main__":
  test.main()