1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for tensorflow.kernels.functional_ops.""" 16 17import numpy as np 18 19from tensorflow.core.framework import attr_value_pb2 20from tensorflow.core.protobuf import config_pb2 21from tensorflow.python.client import session 22from tensorflow.python.data.ops import iterator_ops 23from tensorflow.python.eager import cancellation 24from tensorflow.python.eager import context 25from tensorflow.python.eager import def_function as eager_def_function 26from tensorflow.python.eager import executor 27from tensorflow.python.eager import function as eager_function 28from tensorflow.python.framework import config as framework_config 29from tensorflow.python.framework import constant_op 30from tensorflow.python.framework import dtypes 31from tensorflow.python.framework import errors 32from tensorflow.python.framework import function 33from tensorflow.python.framework import ops 34from tensorflow.python.framework import test_util 35from tensorflow.python.ops import array_ops 36from tensorflow.python.ops import collective_ops 37from tensorflow.python.ops import functional_ops 38from tensorflow.python.ops import gen_functional_ops 39from tensorflow.python.ops import gradients_impl 40from tensorflow.python.ops import init_ops 41from tensorflow.python.ops import math_ops 42from tensorflow.python.ops import resource_variable_ops 43from tensorflow.python.ops import variable_scope 44from tensorflow.python.ops import variables 45import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import 46from tensorflow.python.platform import test 47from tensorflow.python.util import compat 48 49 50# pylint: disable=invalid-name 51def simple_scoped_fn(a, x): 52 """Simple function: (a, x) -> 2(x+a), but with "2" as a variable in scope.""" 53 with variable_scope.variable_scope("body"): 54 # Dummy variable, just to check that scoping works as intended. 55 two = variable_scope.get_variable( 56 "two", [], 57 dtype=dtypes.int32, 58 initializer=init_ops.constant_initializer(2)) 59 return math_ops.multiply(math_ops.add(a, x), two) 60 61 62@test_util.with_control_flow_v2 63class FunctionalOpsTest(test.TestCase): 64 65 @test_util.run_in_graph_and_eager_modes 66 def testFoldl_Simple(self): 67 elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") 68 69 r = functional_ops.foldl( 70 lambda a, x: math_ops.multiply(math_ops.add(a, x), 2), 71 elems) 72 self.assertAllEqual(208, self.evaluate(r)) 73 74 r = functional_ops.foldl( 75 lambda a, x: math_ops.multiply(math_ops.add(a, x), 2), 76 elems, 77 initializer=10) 78 self.assertAllEqual(880, self.evaluate(r)) 79 80 @test_util.run_in_graph_and_eager_modes 81 def testFoldl_SingleInputMultiOutput(self): 82 elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) 83 initializer = np.array([1, -1.0]) 84 r = functional_ops.foldl(lambda a, x: a + x, elems, initializer) 85 r_value = self.evaluate(r) 86 87 self.assertAllEqual(22, r_value[0]) 88 self.assertAllEqual(20, r_value[1]) 89 90 @test_util.run_in_graph_and_eager_modes 91 def testFoldl_MultiInputSingleOutput(self): 92 elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) 93 initializer = np.array(1.0) 94 r = functional_ops.foldl(lambda a, x: a + x[0] + x[1], (elems, -elems), 95 initializer) 96 self.assertAllEqual(1, self.evaluate(r)) 97 98 @test_util.run_in_graph_and_eager_modes 99 def testFoldl_MultiInputDifferentDimsSingleOutput(self): 100 elems = np.array([[1.0, 1.0, 1.0], [2.0, 3.0, 4.0]]) 101 other_elems = np.array([-1.0, 1.0]) 102 initializer = np.array([0.0, 0.0, 0.0]) 103 r = functional_ops.foldl(lambda a, x: a + x[0] * x[1], 104 (elems, other_elems), initializer) 105 self.assertAllEqual([1.0, 2.0, 3.0], self.evaluate(r)) 106 107 @test_util.run_deprecated_v1 108 def testFoldl_Scoped(self): 109 with self.cached_session() as sess: 110 with variable_scope.variable_scope("root") as varscope: 111 elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") 112 113 r = functional_ops.foldl(simple_scoped_fn, elems) 114 # Check that we have the one variable we asked for here. 115 self.assertEqual(len(variables.trainable_variables()), 1) 116 self.assertEqual(variables.trainable_variables()[0].name, 117 "root/body/two:0") 118 sess.run([variables.global_variables_initializer()]) 119 self.assertAllEqual(208, self.evaluate(r)) 120 121 # Now let's reuse our single variable. 122 varscope.reuse_variables() 123 r = functional_ops.foldl(simple_scoped_fn, elems, initializer=10) 124 self.assertEqual(len(variables.trainable_variables()), 1) 125 self.assertAllEqual(880, self.evaluate(r)) 126 127 @test_util.run_in_graph_and_eager_modes 128 def testFoldr_Simple(self): 129 elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") 130 131 r = functional_ops.foldr( 132 lambda a, x: math_ops.multiply(math_ops.add(a, x), 2), 133 elems) 134 self.assertAllEqual(450, self.evaluate(r)) 135 136 r = functional_ops.foldr( 137 lambda a, x: math_ops.multiply(math_ops.add(a, x), 2), 138 elems, 139 initializer=10) 140 self.assertAllEqual(1282, self.evaluate(r)) 141 142 @test_util.run_in_graph_and_eager_modes 143 def testFoldr_SingleInputMultiOutput(self): 144 elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) 145 initializer = np.array([1, -1.0]) 146 r = functional_ops.foldr(lambda a, x: a + x, elems, initializer) 147 r_value = self.evaluate(r) 148 149 self.assertAllEqual(22, r_value[0]) 150 self.assertAllEqual(20, r_value[1]) 151 152 @test_util.run_in_graph_and_eager_modes 153 def testFoldr_MultiInputSingleOutput(self): 154 elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) 155 initializer = np.array(1.0) 156 r = functional_ops.foldr(lambda a, x: a + x[0] + x[1], (elems, -elems), 157 initializer) 158 self.assertAllEqual(1, self.evaluate(r)) 159 160 @test_util.run_deprecated_v1 161 def testFoldr_Scoped(self): 162 with self.cached_session() as sess: 163 with variable_scope.variable_scope("root") as varscope: 164 elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") 165 166 r = functional_ops.foldr(simple_scoped_fn, elems) 167 # Check that we have the one variable we asked for here. 168 self.assertEqual(len(variables.trainable_variables()), 1) 169 self.assertEqual(variables.trainable_variables()[0].name, 170 "root/body/two:0") 171 sess.run([variables.global_variables_initializer()]) 172 self.assertAllEqual(450, self.evaluate(r)) 173 174 # Now let's reuse our single variable. 175 varscope.reuse_variables() 176 r = functional_ops.foldr(simple_scoped_fn, elems, initializer=10) 177 self.assertEqual(len(variables.trainable_variables()), 1) 178 self.assertAllEqual(1282, self.evaluate(r)) 179 180 # pylint: disable=unnecessary-lambda 181 @test_util.run_deprecated_v1 182 def testFold_Grad(self): 183 with self.cached_session(): 184 elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data") 185 v = constant_op.constant(2.0, name="v") 186 r = functional_ops.foldl( 187 lambda a, x: math_ops.multiply(a, x), elems, initializer=v) 188 r = gradients_impl.gradients(r, v)[0] 189 self.assertAllEqual(720.0, self.evaluate(r)) 190 191 r = functional_ops.foldr( 192 lambda a, x: math_ops.multiply(a, x), elems, initializer=v) 193 r = gradients_impl.gradients(r, v)[0] 194 self.assertAllEqual(720.0, self.evaluate(r)) 195 # pylint: enable=unnecessary-lambda 196 197 @test_util.run_in_graph_and_eager_modes 198 def testScan_Simple(self): 199 elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data") 200 v = constant_op.constant(2.0, name="v") 201 202 # pylint: disable=unnecessary-lambda 203 r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems) 204 self.assertAllEqual([1., 2., 6., 24., 120., 720.], self.evaluate(r)) 205 206 r = functional_ops.scan( 207 lambda a, x: math_ops.multiply(a, x), elems, initializer=v) 208 self.assertAllEqual([2., 4., 12., 48., 240., 1440.], self.evaluate(r)) 209 # pylint: enable=unnecessary-lambda 210 211 @test_util.run_in_graph_and_eager_modes 212 def testScan_Reverse(self): 213 elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data") 214 v = constant_op.constant(2.0, name="v") 215 216 # pylint: disable=unnecessary-lambda 217 r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems, 218 reverse=True) 219 self.assertAllEqual([720., 720., 360., 120., 30., 6.], self.evaluate(r)) 220 r = functional_ops.scan( 221 lambda a, x: math_ops.multiply(a, x), elems, initializer=v, 222 reverse=True) 223 self.assertAllEqual([1440., 1440., 720., 240., 60., 12.], 224 self.evaluate(r)) 225 # pylint: enable=unnecessary-lambda 226 227 @test_util.run_in_graph_and_eager_modes 228 def testScan_SingleInputMultiOutput(self): 229 elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) 230 initializer = (np.array(1.0), np.array(-1.0)) 231 r = functional_ops.scan(lambda a, x: (a[0] * x, -a[1] * x), elems, 232 initializer) 233 r_value = self.evaluate(r) 234 235 self.assertAllEqual([1.0, 2.0, 6.0, 24.0, 120.0, 720.0], r_value[0]) 236 self.assertAllEqual([1.0, -2.0, 6.0, -24.0, 120.0, -720.0], r_value[1]) 237 238 @test_util.run_in_graph_and_eager_modes 239 def testScan_MultiInputSingleOutput(self): 240 elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) 241 initializer = np.array(1.0) 242 # Multiply a * 1 each time 243 r = functional_ops.scan(lambda a, x: a * (x[0] + x[1]), 244 (elems + 1, -elems), initializer) 245 self.assertAllEqual([1.0, 1.0, 1.0, 1.0, 1.0, 1.0], self.evaluate(r)) 246 247 @test_util.run_in_graph_and_eager_modes 248 def testScan_MultiInputSameTypeOutput(self): 249 elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) 250 r = functional_ops.scan(lambda a, x: (a[0] + x[0], a[1] + x[1]), 251 (elems, -elems)) 252 r_value = self.evaluate(r) 253 self.assertAllEqual(np.cumsum(elems), r_value[0]) 254 self.assertAllEqual(np.cumsum(-elems), r_value[1]) 255 256 @test_util.run_in_graph_and_eager_modes 257 def testScan_MultiOutputMismatchedInitializer(self): 258 elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) 259 initializer = np.array(1.0) 260 # Multiply a * 1 each time 261 with self.assertRaisesRegex( 262 ValueError, "two structures don't have the same nested structure"): 263 functional_ops.scan(lambda a, x: (a, -a), elems, initializer) 264 265 @test_util.run_deprecated_v1 266 def testScan_Scoped(self): 267 with self.cached_session() as sess: 268 with variable_scope.variable_scope("root") as varscope: 269 elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") 270 271 r = functional_ops.scan(simple_scoped_fn, elems) 272 # Check that we have the one variable we asked for here. 273 self.assertEqual(len(variables.trainable_variables()), 1) 274 self.assertEqual(variables.trainable_variables()[0].name, 275 "root/body/two:0") 276 sess.run([variables.global_variables_initializer()]) 277 results = np.array([1, 6, 18, 44, 98, 208]) 278 self.assertAllEqual(results, self.evaluate(r)) 279 280 # Now let's reuse our single variable. 281 varscope.reuse_variables() 282 r = functional_ops.scan(simple_scoped_fn, elems, initializer=2) 283 self.assertEqual(len(variables.trainable_variables()), 1) 284 results = np.array([6, 16, 38, 84, 178, 368]) 285 self.assertAllEqual(results, self.evaluate(r)) 286 287 @test_util.run_in_graph_and_eager_modes 288 def testScanFoldl_Nested(self): 289 elems = constant_op.constant([1.0, 2.0, 3.0, 4.0], name="data") 290 inner_elems = constant_op.constant([0.5, 0.5], name="data") 291 292 def r_inner(a, x): 293 return functional_ops.foldl( 294 lambda b, y: b * y * x, inner_elems, initializer=a) 295 296 r = functional_ops.scan(r_inner, elems) 297 298 # t == 0 (returns 1) 299 # t == 1, a == 1, x == 2 (returns 1) 300 # t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1 301 # t_1 == 1, b == 1, y == 0.5, returns b * y * x = 1 302 # t == 2, a == 1, x == 3 (returns 1.5*1.5 == 2.25) 303 # t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1.5 304 # t_1 == 1, b == 1.5, y == 0.5, returns b * y * x = 1.5*1.5 305 # t == 3, a == 2.25, x == 4 (returns 9) 306 # t_0 == 0, b == a == 2.25, y == 0.5, returns b * y * x = 4.5 307 # t_1 == 1, b == 4.5, y == 0.5, returns b * y * x = 9 308 self.assertAllClose([1., 1., 2.25, 9.], self.evaluate(r)) 309 310 @test_util.run_deprecated_v1 311 def testScan_Control(self): 312 with self.cached_session() as sess: 313 s = array_ops.placeholder(dtypes.float32, shape=[None]) 314 b = array_ops.placeholder(dtypes.bool) 315 316 with ops.control_dependencies([b]): 317 c = functional_ops.scan(lambda a, x: x * a, s) 318 self.assertAllClose( 319 np.array([1.0, 3.0, 9.0]), sess.run(c, {s: [1, 3, 3], 320 b: True})) 321 322 @test_util.run_deprecated_v1 323 def testScan_Grad(self): 324 with self.cached_session(): 325 elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data") 326 v = constant_op.constant(2.0, name="v") 327 328 # pylint: disable=unnecessary-lambda 329 r = functional_ops.scan( 330 lambda a, x: math_ops.multiply(a, x), elems, initializer=v) 331 # pylint: enable=unnecessary-lambda 332 r = gradients_impl.gradients(r, v)[0] 333 self.assertAllEqual(873.0, self.evaluate(r)) 334 335 @test_util.run_deprecated_v1 336 def testScanGradientWithPartStopGradient(self): 337 a = variables.Variable(0.0, name="a") 338 b = variables.Variable(0.0, name="b") 339 elems = array_ops.zeros(5) 340 l0, l1 = functional_ops.scan( 341 lambda elem_, input_: (a, b), elems, initializer=(0., 0.)) 342 loss = l0 + array_ops.stop_gradient(l1) 343 grad = gradients_impl.gradients(ys=[loss], xs=[a, b]) 344 with self.test_session(): 345 self.evaluate(variables.global_variables_initializer()) 346 self.evaluate(grad) 347 348 @test_util.run_in_graph_and_eager_modes 349 def testFoldShape(self): 350 x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) 351 352 def fn(_, current_input): 353 return current_input 354 355 initializer = constant_op.constant([0, 0, 0]) 356 y = functional_ops.foldl(fn, x, initializer=initializer) 357 self.assertAllEqual(y.get_shape(), self.evaluate(y).shape) 358 359 @test_util.run_in_graph_and_eager_modes 360 def testScanShape(self): 361 x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) 362 363 def fn(_, current_input): 364 return current_input 365 366 initializer = constant_op.constant([0, 0, 0]) 367 y = functional_ops.scan(fn, x, initializer=initializer) 368 self.assertAllEqual(y.get_shape(), self.evaluate(y).shape) 369 370 # TODO(akshayka): this test fails in eager: the iterable is of length 0 so 371 # so the body of the while loop never executes 372 @test_util.run_deprecated_v1 373 def testScanEmptyTensor(self): 374 with self.cached_session(): 375 x = functional_ops.scan( 376 lambda x, _: x, math_ops.range(0), initializer=array_ops.ones([2, 4])) 377 self.assertAllEqual([0, 2, 4], x.get_shape()) 378 self.assertAllEqual(x.get_shape(), self.evaluate(x).shape) 379 380 @test_util.run_deprecated_v1 381 def testScanUnknownShape(self): 382 x = array_ops.placeholder(dtypes.float32) 383 initializer = array_ops.placeholder(dtypes.float32) 384 385 def fn(_, current_input): 386 return current_input 387 388 y = functional_ops.scan(fn, x, initializer=initializer) 389 self.assertIs(None, y.get_shape().dims) 390 391 @test_util.run_deprecated_v1 392 def testScanVaryingShape(self): 393 with self.cached_session() as sess: 394 x = array_ops.placeholder(dtype=dtypes.float32, shape=[None, 2]) 395 x_t = array_ops.transpose(x) 396 # scan over dimension 0 (with shape None) 397 result = functional_ops.scan(lambda a, x: a + x, x) 398 # scanned over transposed dimension 0 (with shape 2) 399 result_t = functional_ops.scan(lambda a, x: a + x, x_t, infer_shape=False) 400 # ensure gradients can be calculated 401 result_grad = gradients_impl.gradients(result, [x])[0] 402 result_t_grad = gradients_impl.gradients(result_t, [x_t])[0] 403 404 # smoke test to ensure they all evaluate 405 sess.run([result, result_t, result_grad, result_t_grad], 406 feed_dict={x: [[1.0, 2.0]]}) 407 408 @test_util.run_deprecated_v1 409 def testRemoteFunction(self): 410 worker_config = config_pb2.ConfigProto() 411 worker_config.device_count["CPU"] = 2 412 worker, _ = test_util.create_local_cluster( 413 1, 1, worker_config=worker_config) 414 415 @function.Defun(dtypes.int32, dtypes.int32) 416 def _remote_fn(a, b): 417 return math_ops.multiply(a, b) 418 419 with ops.device("/job:ps/task:0"): 420 a = variables.Variable(2, dtype=dtypes.int32) 421 b = variables.Variable(3, dtype=dtypes.int32) 422 423 with ops.device("/job:worker/replica:0/task:0/cpu:0"): 424 remote_op = functional_ops.remote_call( 425 args=[a, b], 426 Tout=[dtypes.int32], 427 f=_remote_fn, 428 target="/job:worker/replica:0/task:0/cpu:1") 429 430 with session.Session(worker[0].target) as sess: 431 self.evaluate(variables.global_variables_initializer()) 432 mul = self.evaluate(remote_op) 433 self.assertEqual(mul, [6]) 434 435 @test_util.run_deprecated_v1 436 def testRemoteFunctionDirectSession(self): 437 worker_config = config_pb2.ConfigProto() 438 worker_config.device_count["CPU"] = 2 439 440 @function.Defun(dtypes.int32, dtypes.int32) 441 def _remote_fn(a, b): 442 return math_ops.multiply(a, b) 443 444 with ops.device("/job:localhost/replica:0/task:0/cpu:0"): 445 a = variables.Variable(2, dtype=dtypes.int32) 446 b = variables.Variable(3, dtype=dtypes.int32) 447 448 with ops.device("/job:localhost/replica:0/task:0/cpu:0"): 449 remote_op = functional_ops.remote_call( 450 args=[a, b], 451 Tout=[dtypes.int32], 452 f=_remote_fn, 453 target="/job:localhost/replica:0/task:0/cpu:1") 454 455 with self.test_session(config=worker_config) as sess: 456 self.evaluate(variables.global_variables_initializer()) 457 mul = self.evaluate(remote_op) 458 self.assertEqual(mul, [6]) 459 460 @test_util.run_deprecated_v1 461 def testRemoteFunctionSameDeviceDirectSession(self): 462 463 @function.Defun(dtypes.int32, dtypes.int32) 464 def _remote_fn(a, b): 465 return math_ops.multiply(a, b) 466 467 with ops.device("/cpu:0"): 468 a = variables.Variable(2, dtype=dtypes.int32) 469 b = variables.Variable(3, dtype=dtypes.int32) 470 471 with ops.device("/cpu:0"): 472 remote_op = functional_ops.remote_call( 473 args=[a, b], Tout=[dtypes.int32], f=_remote_fn, target="/cpu:0") 474 475 with self.cached_session() as sess: 476 self.evaluate(variables.global_variables_initializer()) 477 mul = self.evaluate(remote_op) 478 self.assertEqual(mul, [6]) 479 480 @test_util.run_deprecated_v1 481 def testRemoteFunctionCPUGPU(self): 482 if not test_util.is_gpu_available(): 483 self.skipTest("No GPU available") 484 485 @function.Defun(dtypes.float32, dtypes.float32) 486 def _remote_fn(a, b): 487 return math_ops.multiply(a, b) 488 489 with ops.device("/job:localhost/replica:0/task:0/cpu:0"): 490 a = variables.Variable(2, dtype=dtypes.float32) 491 b = variables.Variable(3, dtype=dtypes.float32) 492 493 with ops.device("/job:localhost/replica:0/task:0/cpu:0"): 494 remote_op = functional_ops.remote_call( 495 args=[a, b], 496 Tout=[dtypes.float32], 497 f=_remote_fn, 498 target="/job:localhost/replica:0/task:0/device:GPU:0")[0] + 3.0 499 500 with self.cached_session() as sess: 501 self.evaluate(variables.global_variables_initializer()) 502 mul = self.evaluate(remote_op) 503 self.assertEqual(mul, 9.0) 504 505 @test_util.run_deprecated_v1 506 def testRemoteFunctionGPUCPU(self): 507 if not test_util.is_gpu_available(): 508 self.skipTest("No GPU available") 509 510 @function.Defun(dtypes.float32, dtypes.float32) 511 def _remote_fn(a, b): 512 return math_ops.multiply(a, b) 513 514 with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"): 515 a = variables.Variable(2, dtype=dtypes.float32) 516 b = variables.Variable(3, dtype=dtypes.float32) 517 518 with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"): 519 remote_op = functional_ops.remote_call( 520 args=[a, b], 521 Tout=[dtypes.float32], 522 f=_remote_fn, 523 target="/job:localhost/replica:0/task:0/cpu:0")[0] + 3.0 524 525 with self.cached_session() as sess: 526 self.evaluate(variables.global_variables_initializer()) 527 mul = self.evaluate(remote_op) 528 self.assertEqual(mul, 9.0) 529 530 @test_util.run_deprecated_v1 531 def testRemoteFunctionGPUCPUStrings(self): 532 if not test_util.is_gpu_available(): 533 self.skipTest("No GPU available") 534 535 @function.Defun(dtypes.string) 536 def _remote_fn(inp): 537 return array_ops.identity(inp) 538 539 a = array_ops.constant("a") 540 541 with ops.device("/gpu:0"): 542 remote_op = functional_ops.remote_call( 543 args=[a], Tout=[dtypes.string], f=_remote_fn, target="/cpu:0") 544 545 with self.cached_session() as sess: 546 ret = self.evaluate(remote_op) 547 self.assertAllEqual(ret, [b"a"]) 548 549 @test_util.run_deprecated_v1 550 def testRemoteFunctionCrossProcess(self): 551 workers, _ = test_util.create_local_cluster(2, 1) 552 553 @function.Defun(dtypes.float32, dtypes.float32) 554 def _remote_fn(a, b): 555 return math_ops.multiply(a, b) 556 557 with ops.device("/job:ps/task:0"): 558 a = variables.Variable(2, dtype=dtypes.float32) 559 b = variables.Variable(3, dtype=dtypes.float32) 560 561 with ops.device("/job:worker/replica:0/task:0/cpu:0"): 562 remote_op = functional_ops.remote_call( 563 args=[a, b], 564 Tout=[dtypes.float32], 565 f=_remote_fn, 566 target="/job:worker/replica:0/task:1/cpu:0")[0] + 3.0 567 568 with session.Session(workers[0].target) as sess: 569 self.evaluate(variables.global_variables_initializer()) 570 mul = self.evaluate(remote_op) 571 self.assertEqual(mul, 9) 572 573 @test_util.run_v2_only 574 def testRemoteFunctionCancellation(self): 575 context._reset_context() 576 logical_devices = [] 577 logical_devices.append(context.LogicalDeviceConfiguration()) 578 logical_devices.append(context.LogicalDeviceConfiguration()) 579 framework_config.set_logical_device_configuration( 580 framework_config.list_physical_devices("CPU")[0], logical_devices) 581 582 @function.Defun(dtypes.float32) 583 def _remote_fn(v): 584 # We run two collectives here to make sure we cancel in the middle of the 585 # RemoteCall. The second one should never finish. 586 anchor = collective_ops.all_reduce_v2( 587 v, group_size=2, group_key=1, instance_key=1) 588 with ops.control_dependencies([anchor]): 589 return collective_ops.all_reduce_v2( 590 v, group_size=2, group_key=1, instance_key=2) 591 592 @eager_def_function.function 593 def run(): 594 with ops.device("/cpu:0"): 595 return functional_ops.remote_call( 596 args=[constant_op.constant([1.])] + _remote_fn.captured_inputs, 597 Tout=[dtypes.float32], 598 f=_remote_fn, 599 target="/cpu:1")[0] 600 601 async_executor = executor.new_executor(enable_async=True) 602 cancel_mgr = cancellation.CancellationManager() 603 with context.executor_scope(async_executor): 604 # This should never finish. 605 cancel_mgr.get_cancelable_function(run.get_concrete_function())() 606 with ops.device("/cpu:0"): 607 collective_ops.all_reduce_v2([1.], 608 group_size=2, 609 group_key=1, 610 instance_key=1) 611 cancel_mgr.start_cancel() 612 with self.assertRaises(errors.CancelledError): 613 async_executor.wait() 614 615 @test_util.run_deprecated_v1 616 def testIf(self): 617 618 @function.Defun(dtypes.float32) 619 def Twice(x): 620 return x * 2 621 622 @function.Defun(dtypes.float32) 623 def Thrice(x): 624 return x * 3 + 1 625 626 with self.test_session(use_gpu=False) as sess: 627 628 x = array_ops.placeholder(dtypes.float32) 629 ret = functional_ops.If(math_ops.greater(x, 0), [x], Twice, Thrice)[0] 630 631 self.assertAllEqual(sess.run(ret, feed_dict={x: 9.}), 18.) 632 self.assertAllEqual(sess.run(ret, feed_dict={x: -8.}), -23.) 633 self.assertAllEqual(sess.run(ret, feed_dict={x: 0.}), 1.) 634 635 def testWhile(self): 636 637 for use_gpu in (True, False): 638 with ops.Graph().as_default() as g: 639 640 @function.Defun(*[dtypes.float32] * 2) 641 def Cond(n, unused_x): 642 return n > 0 643 644 @function.Defun(*[dtypes.float32] * 2) 645 def Body(n, x): 646 return n - 1, x + n 647 648 def Run(sess, n): 649 return sess.run(functional_ops.While([n, 0.], Cond, Body))[1] 650 651 with self.session(graph=g, use_gpu=use_gpu) as sess: 652 self.assertAllEqual(Run(sess, 20.), 210.) 653 self.assertAllEqual(Run(sess, 100.), 5050.) 654 655 def testToBool(self): 656 # For 0D tensors, the truthiness depends on whether the value is "zero". 657 self.assertAllEqual(gen_functional_ops.to_bool(0), False) 658 self.assertAllEqual(gen_functional_ops.to_bool(1), True) 659 self.assertAllEqual(gen_functional_ops.to_bool(42), True) 660 self.assertAllEqual(gen_functional_ops.to_bool(0.), False) 661 self.assertAllEqual(gen_functional_ops.to_bool(1.), True) 662 self.assertAllEqual(gen_functional_ops.to_bool(42.), True) 663 self.assertAllEqual(gen_functional_ops.to_bool(False), False) 664 self.assertAllEqual(gen_functional_ops.to_bool(True), True) 665 # For strings, "zero" is the empty string. 666 self.assertAllEqual(gen_functional_ops.to_bool(""), False) 667 self.assertAllEqual(gen_functional_ops.to_bool("a"), True) 668 669 # For >0D tensors, the truthiness only depends on whether there are 670 # elements or not. 671 self.assertAllEqual(gen_functional_ops.to_bool([]), False) 672 self.assertAllEqual(gen_functional_ops.to_bool([[]]), False) 673 self.assertAllEqual(gen_functional_ops.to_bool([[[]]]), False) 674 self.assertAllEqual(gen_functional_ops.to_bool([0]), True) 675 self.assertAllEqual(gen_functional_ops.to_bool([1]), True) 676 self.assertAllEqual(gen_functional_ops.to_bool([[0]]), True) 677 self.assertAllEqual(gen_functional_ops.to_bool([False]), True) 678 self.assertAllEqual(gen_functional_ops.to_bool([True]), True) 679 680 # Like above, but using int32 in order to ensure that int32 tensors don't get 681 # copied to the GPU during the application of the while. 682 def testWhileInt32(self): 683 with ops.Graph().as_default() as g: 684 685 @function.Defun(*[dtypes.int32] * 2) 686 def Cond(n, unused_x): 687 return n > 0 688 689 @function.Defun(*[dtypes.int32] * 2) 690 def Body(n, x): 691 return n - 1, x + n 692 693 def Run(sess, n): 694 return sess.run(functional_ops.While([n, 0], Cond, Body))[1] 695 696 with self.session(graph=g, use_gpu=True) as sess: 697 self.assertAllEqual(Run(sess, 20), 210) 698 self.assertAllEqual(Run(sess, 100), 5050) 699 700 @test_util.run_deprecated_v1 701 def testWhileLowering(self): 702 703 def Run(n, fetch_by_name): 704 for use_gpu in (True, False): 705 with ops.Graph().as_default() as g: 706 707 @function.Defun(*[dtypes.float32] * 2) 708 def Cond(n, unused_x): 709 return n > 0 710 711 @function.Defun(*[dtypes.float32] * 2) 712 def Body(n, x): 713 return n - 1, x + n 714 715 # outputs: [0, n*(n+1)/2] 716 outputs = functional_ops.While([n, 0.], Cond, Body, name="my_while") 717 718 # `outputs` is the list of output tensors of the While op. We 719 # arbitrarily choose the 0th tensor to get the While op and set the 720 # lowering attribute on it. 721 outputs[0].op._set_attr("_lower_using_switch_merge", 722 attr_value_pb2.AttrValue(b=True)) 723 if not fetch_by_name: 724 fetch = outputs[1] 725 else: 726 fetch = "my_while:1" 727 with self.session(graph=g, use_gpu=use_gpu) as sess: 728 return self.evaluate(fetch) 729 730 self.assertAllEqual(Run(20., False), 210.) 731 self.assertAllEqual(Run(20., True), 210.) 732 self.assertAllEqual(Run(100., False), 5050.) 733 self.assertAllEqual(Run(100., True), 5050.) 734 735 @test_util.run_v1_only("b/120545219") 736 @test_util.disable_xla("b/123337890") # Different error message 737 def testWhileError(self): 738 for use_gpu in (True, False): 739 with ops.Graph().as_default() as g: 740 741 @function.Defun(*[dtypes.float32] * 2) 742 def Cond(n, unused_x): 743 return n > 0 744 745 @function.Defun(*[dtypes.float32] * 2) 746 def CondReturnsTooManyArgs(n, x): 747 return n > 0, x 748 749 @function.Defun(*[dtypes.float32] * 2) 750 def Body(n, x): 751 return n - 1, x + n 752 753 @function.Defun(*[dtypes.float32] * 2) 754 def BodyReturnsTooManyArgs(n, x): 755 return n - 1, x + n, x 756 757 with self.session(graph=g, use_gpu=use_gpu): 758 with self.assertRaisesRegex( 759 errors.InvalidArgumentError, 760 "Expected a single scalar.*got 2 tensors."): 761 functional_ops.While([5., 0.], CondReturnsTooManyArgs, 762 Body)[0].eval() 763 with self.assertRaisesRegex( 764 errors.InvalidArgumentError, 765 "While loop body returned 3 arguments. Expected: 2"): 766 functional_ops.While([5., 0.], Cond, 767 BodyReturnsTooManyArgs)[0].eval() 768 769 def testWhileInMultipleSubgraphs(self): 770 771 for use_gpu in (True, False): 772 with ops.Graph().as_default() as g: 773 774 @function.Defun(*[dtypes.float32] * 2) 775 def Cond(n, x): # pylint: disable=unused-argument 776 return n > 0 777 778 @function.Defun(*[dtypes.float32] * 2) 779 def Body(n, x): 780 return n - 1, x + n 781 782 with self.session(graph=g, use_gpu=use_gpu) as sess: 783 n = array_ops.placeholder(dtypes.float32) 784 _, result = functional_ops.While([n, 0.], Cond, Body) 785 c = constant_op.constant(37.) 786 787 self.assertAllEqual(210., sess.run(result, feed_dict={n: 20.})) 788 self.assertAllEqual(5050., sess.run(result, feed_dict={n: 100.})) 789 # Test that the result is the same when we run a different subgraph. 790 self.assertAllEqual(5050., 791 sess.run([result, c], feed_dict={n: 100.})[0]) 792 793 # pylint: disable=cell-var-from-loop 794 def testWhileCapturedInputs(self): 795 for use_gpu in (True, False): 796 with ops.Graph().as_default() as g: 797 v = variables.Variable(1.0) 798 799 def TestCond(n, *args): 800 del args 801 return n < 10 802 803 @function.Defun(*[dtypes.float32] * 2) 804 def TestUnary(n, x): 805 return math_ops.add(n, 1), x + n + v 806 807 @function.Defun(*[dtypes.float32] * 3) 808 def TestBinary(n, x, x2): 809 return math_ops.add(n, 1), x + n + v, x2 + v 810 811 with self.session(graph=g, use_gpu=use_gpu) as sess: 812 result_unary = functional_ops.While( 813 [1.0, 0.], 814 function.Defun(*[dtypes.float32] * 2)(TestCond), TestUnary) 815 result_binary = functional_ops.While( 816 [1.0, 0., 0.], 817 function.Defun(*[dtypes.float32] * 3)(TestCond), TestBinary) 818 self.evaluate(variables.global_variables_initializer()) 819 assert len(result_unary) == 2 820 self.assertEqual([10.0, 54.0], self.evaluate(result_unary)) 821 assert len(result_binary) == 3 822 self.assertEqual([10.0, 54.0, 9.0], self.evaluate(result_binary)) 823 824 def TestCondCapture(n, *args): 825 del args 826 return math_ops.cast(n, dtypes.float32) + v < 10 827 828 with self.assertRaises(ValueError): 829 _ = functional_ops.While( 830 [1], 831 function.Defun(dtypes.int32)(TestCondCapture), 832 function.Defun(dtypes.int32, dtypes.float32)(TestUnary)) 833 834 # pylint: enable=cell-var-from-loop 835 836 def _tfSum(self, use_gpu, rewrite_with_while): 837 with ops.Graph().as_default() as g: 838 with self.session(graph=g, use_gpu=use_gpu) as sess: 839 840 @function.Defun(dtypes.int32, dtypes.float32) 841 def Body(n, x): 842 return x + math_ops.cast(n, dtypes.float32) 843 844 xs = [ 845 # 1 + 2 + ... + 20 846 functional_ops.For( 847 1, 21, 1, [0.], Body, rewrite_with_while=rewrite_with_while)[0], 848 # 100 + 99 + ... + 1 849 functional_ops.For( 850 100, 0, -1, [0.], Body, rewrite_with_while=rewrite_with_while) 851 [0], 852 ] 853 xvals = self.evaluate(xs) 854 self.assertAllEqual(210, xvals[0]) 855 self.assertAllEqual(5050, xvals[1]) 856 857 def testFor(self): 858 for use_gpu in (True, False): 859 self._tfSum(use_gpu, False) 860 861 def testForWithWhile(self): 862 for use_gpu in (True, False): 863 self._tfSum(use_gpu, True) 864 865 def testForWithWhileNaming(self): 866 g = ops.Graph() 867 with g.as_default(): 868 869 @function.Defun(dtypes.int32, dtypes.float32, func_name="TestBody") 870 def TestBody(n, x): 871 return x + math_ops.cast(n, dtypes.float32) 872 873 _ = functional_ops.For( 874 1, 21, 1, [0.], TestBody, rewrite_with_while=True)[0] 875 876 names = [] 877 for func in g.as_graph_def().library.function: 878 names.append(func.signature.name) 879 self.assertTrue("TestBody" in names) 880 self.assertTrue("TestBody_Cond" in names) 881 self.assertTrue("TestBody_Body" in names) 882 883 @test_util.run_deprecated_v1 884 def testForCapturedInputs(self): 885 v = variables.Variable(1.0) 886 887 @function.Defun(dtypes.int32) 888 def TestNullary(n): 889 v + math_ops.cast(n, dtypes.float32) # pylint: disable=expression-not-assigned 890 891 @function.Defun(dtypes.int32, dtypes.float32) 892 def TestUnary(n, x): 893 return x + math_ops.cast(n, dtypes.float32) + v 894 895 @function.Defun(dtypes.int32, dtypes.float32, dtypes.float32) 896 def TestBinary(n, x, x2): 897 return x + math_ops.cast(n, dtypes.float32) + v, x2 + v 898 899 for rewrite_with_while in (True, False): 900 use_gpu = not rewrite_with_while 901 with self.test_session(use_gpu=use_gpu) as sess: 902 result_nullary = functional_ops.For( 903 1, 10, 1, [], TestNullary, 904 rewrite_with_while=rewrite_with_while) 905 result_unary = functional_ops.For( 906 1, 10, 1, [0.], TestUnary, 907 rewrite_with_while=rewrite_with_while) 908 result_binary = functional_ops.For( 909 1, 10, 1, [0., 0.], TestBinary, 910 rewrite_with_while=rewrite_with_while) 911 self.evaluate(variables.global_variables_initializer()) 912 assert not result_nullary 913 # The nullary variant doesn't return anything so we can't easily run it. 914 # As a total hack, fetch the operation by name and run it. 915 sess.run(ops.get_default_graph().get_operation_by_name( 916 "While" if rewrite_with_while else "For")) 917 assert len(result_unary) == 1 918 self.assertEqual([54.0], self.evaluate(result_unary)) 919 assert len(result_binary) == 2 920 self.assertEqual([54.0, 9.0], self.evaluate(result_binary)) 921 922 def _tfMLP(self, xval, wsval, bsval, rewrite_with_while): 923 # On GPU, don't rewrite using a while loop. 924 use_gpu = not rewrite_with_while 925 with self.test_session(use_gpu=use_gpu): 926 927 @function.Defun(dtypes.int32, *[dtypes.float64] * 3) 928 def MLP(i, a, ws, bs): 929 a = math_ops.tanh(math_ops.matmul(a, ws[i, :]) + bs[i, :]) 930 return a, ws, bs 931 932 ret = functional_ops.For( 933 0, 934 wsval.shape[0], 935 1, [xval, wsval, bsval], 936 MLP, 937 rewrite_with_while=rewrite_with_while)[0] 938 939 return self.evaluate(ret) 940 941 def _npMLP(self, xval, wsval, bsval): 942 for i in range(wsval.shape[0]): 943 xval = np.tanh(np.dot(xval, wsval[i, :]) + bsval[i, :]) 944 return xval 945 946 def _testForMLP(self, rewrite_with_while): 947 # We construct a 5-layer Multi-Layer Perceptron network here. 948 # Each layer have the same number of hidden unites (3), and the 949 # activation function is tanh(). We feed the input (xval) with 950 # batch size 2. 951 xval = np.random.normal(size=(2, 3)) 952 wsval = np.random.normal(size=(5, 3, 3)) 953 bsval = np.random.normal(size=(5, 3)) 954 np_ans = self._npMLP(xval, wsval, bsval) 955 tf_for_ans = self._tfMLP(xval, wsval, bsval, rewrite_with_while) 956 self.assertAllClose(np_ans, tf_for_ans) 957 958 @test_util.run_deprecated_v1 959 def testForMLP(self): 960 self._testForMLP(False) 961 962 @test_util.run_deprecated_v1 963 @test_util.disable_xla( 964 "Test uses strided slice without compile time constant values") 965 def testForMLPWhile(self): 966 self._testForMLP(True) 967 968 @test_util.run_v1_only("b/120545219") 969 def testForError(self): 970 971 @function.Defun(dtypes.int32, dtypes.float32) 972 def Foo(i, v): 973 return math_ops.cast(i, dtypes.float32) + v 974 975 @function.Defun(dtypes.int32, dtypes.float32) 976 def ReturnsTooManyArgs(unused_i, v): 977 return v, v 978 979 with self.test_session(): 980 with self.assertRaisesRegex(errors.InvalidArgumentError, 981 "must be a scalar"): 982 functional_ops.For([0], 10, 1, [0.0], Foo)[0].eval() 983 with self.assertRaisesRegex(errors.InvalidArgumentError, 984 "Invalid start/limit/delta"): 985 functional_ops.For(0, 10, -1, [0.0], Foo)[0].eval() 986 with self.assertRaisesRegex( 987 errors.InvalidArgumentError, 988 "For loop body returned 2 arguments. Expected: 1"): 989 functional_ops.For(0, 10, 1, [0.0], ReturnsTooManyArgs)[0].eval() 990 991 @test_util.run_deprecated_v1 992 def testGradient(self): 993 994 @function.Defun(dtypes.float32) 995 def Poly(x): 996 # y = 2x^3+3x^2+4x+8 997 return 2 * x * x * x + 3 * x * x + 4 * x + 8 998 999 @function.Defun(dtypes.float32) 1000 def Grad(x): 1001 # dy/dx = dy/dy * dy/dx = 1.0 * (6x^2+6x+4) 1002 return functional_ops.Gradient([x, 1.0], Poly)[0] 1003 1004 with self.test_session(use_gpu=False) as sess: 1005 a = constant_op.constant(0.) 1006 avals = [Poly(a), Grad(a)] 1007 b = constant_op.constant(1.) 1008 bvals = [Poly(b), Grad(b)] 1009 self.assertAllEqual(self.evaluate(avals), [8., 4.]) 1010 self.assertAllEqual(self.evaluate(bvals), [17., 16.]) 1011 1012 @test_util.run_v2_only 1013 def testCollective(self): 1014 context._reset_context() 1015 logical_devices = [] 1016 logical_devices.append(context.LogicalDeviceConfiguration()) 1017 logical_devices.append(context.LogicalDeviceConfiguration()) 1018 framework_config.set_logical_device_configuration( 1019 framework_config.list_physical_devices("CPU")[0], logical_devices) 1020 1021 @function.Defun(dtypes.float32) 1022 def collective_fn(t): 1023 # Run a dummy collective of group size 1 to test the setup. 1024 return collective_ops.all_reduce_v2( 1025 t, group_size=1, group_key=1, instance_key=1) 1026 1027 @eager_def_function.function 1028 def run(): 1029 with ops.device("/cpu:0"): 1030 return functional_ops.remote_call( 1031 args=[constant_op.constant([1.])] + collective_fn.captured_inputs, 1032 Tout=[dtypes.float32], 1033 f=collective_fn, 1034 target="/cpu:1") 1035 1036 self.assertAllEqual(run(), [[1.]]) 1037 1038 1039# TODO(akshayka): Replace `function.Defun` with tf.contrib.eager.defun` in the 1040# below test cases. 1041class PartitionedCallTest(test.TestCase): 1042 1043 @test_util.run_deprecated_v1 1044 def testRemoteDeviceInPartitionedCallOp(self): 1045 workers, _ = test_util.create_local_cluster(2, 0) 1046 1047 worker0_device = "/job:worker/replica:0/task:0/cpu:0" 1048 worker1_device = "/job:worker/replica:0/task:1/cpu:0" 1049 1050 @eager_def_function.function 1051 def f(a, b): 1052 return a + b 1053 1054 with session.Session(workers[0].target) as sess: 1055 with ops.device(worker0_device): 1056 a = variable_scope.get_variable( 1057 "a", initializer=constant_op.constant(1.), use_resource=True) 1058 with ops.device(worker1_device): 1059 b = variable_scope.get_variable( 1060 "b", initializer=constant_op.constant(1.), use_resource=True) 1061 1062 sess.run(variables.global_variables_initializer()) 1063 1064 config = config_pb2.ConfigProto() 1065 config.share_cluster_devices_in_session = True 1066 1067 with session.Session(workers[0].target, config=config) as sess: 1068 res = sess.run(f(a, b)) 1069 1070 self.assertEqual(res, 2) 1071 1072 @test_util.run_deprecated_v1 1073 def testBasicSingleDevice(self): 1074 1075 @function.Defun(*[dtypes.float32] * 2) 1076 def Body(x, y): 1077 with ops.device("/cpu:0"): 1078 a = x + x 1079 b = y + y 1080 return a + b 1081 1082 output, = self.evaluate( 1083 functional_ops.partitioned_call( 1084 args=[constant_op.constant(1.), 1085 constant_op.constant(2.)], f=Body)) 1086 self.assertEqual(output, 6.) 1087 1088 @test_util.run_deprecated_v1 1089 def testBasicMultiDevice(self): 1090 config = config_pb2.ConfigProto(device_count={"CPU": 3}) 1091 1092 @function.Defun(*[dtypes.float32] * 2) 1093 def Body(x, y): 1094 # if x = 1, y = 2, ... 1095 with ops.device("/cpu:0"): 1096 # a:= 1 + 1 = 2 1097 a = x + x 1098 with ops.device("/cpu:1"): 1099 # b:= 2 + 2 = 4 1100 b = a + y 1101 with ops.device("/cpu:2"): 1102 # c:= 2 + 4 = 6 1103 c = a + b 1104 # a + b + c = 2 + 4 + 6 = 12 1105 return a + b + c 1106 1107 with self.test_session(config=config): 1108 output, = functional_ops.partitioned_call( 1109 args=[constant_op.constant(1.), 1110 constant_op.constant(2.)], f=Body) 1111 self.assertEqual(self.evaluate(output), 12.) 1112 1113 @test_util.run_deprecated_v1 1114 def testBasicMultiDeviceGPU(self): 1115 if not test_util.is_gpu_available(): 1116 return 1117 1118 @function.Defun(*[dtypes.float32] * 2) 1119 def Body(x, y): 1120 with ops.device("/gpu:0"): 1121 a = x + x 1122 b = y + y 1123 with ops.device("/cpu:0"): 1124 c = a + b 1125 return c 1126 1127 output, = self.evaluate( 1128 functional_ops.partitioned_call( 1129 args=[constant_op.constant(1.), 1130 constant_op.constant(2.)], f=Body)) 1131 self.assertEqual(output, 6.) 1132 1133 @test_util.run_deprecated_v1 1134 def testBasicNoDeviceAnnotations(self): 1135 1136 @function.Defun(*[dtypes.float32] * 2) 1137 def Body(x, y): 1138 a = x + x 1139 b = y + y 1140 return a + b 1141 1142 output, = self.evaluate( 1143 functional_ops.partitioned_call( 1144 args=[constant_op.constant(1.), 1145 constant_op.constant(2.)], f=Body)) 1146 self.assertEqual(output, 6.) 1147 1148 @test_util.run_deprecated_v1 1149 def testShardsRunOnRequestedDevices(self): 1150 config = config_pb2.ConfigProto(device_count={"CPU": 4}) 1151 1152 @function.Defun() 1153 def Body(): 1154 # Serialize DT_RESOURCE handles as DT_STRINGs, which encode the device on 1155 # which the resource was created, so that we can verify that ops were 1156 # actually run on the requested devices. 1157 # 1158 # TODO(akshayka): Provide a cleaner, more idiomatic API for obtaining the 1159 # name of the device on which a resource lives / for determining the 1160 # device on which an op ran. 1161 with ops.device("/cpu:0"): 1162 s1 = iterator_ops.Iterator.from_structure( 1163 (dtypes.float32,)).string_handle() 1164 with ops.device("/cpu:1"): 1165 s2 = iterator_ops.Iterator.from_structure( 1166 (dtypes.float32,)).string_handle() 1167 with ops.device("/cpu:2"): 1168 s3 = iterator_ops.Iterator.from_structure( 1169 (dtypes.float32,)).string_handle() 1170 return s1, s2, s3 1171 1172 with self.test_session(config=config, use_gpu=True) as sess: 1173 outputs = sess.run(functional_ops.partitioned_call(args=[], f=Body)) 1174 self.assertIn(compat.as_bytes("CPU:0"), outputs[0]) 1175 self.assertIn(compat.as_bytes("CPU:1"), outputs[1]) 1176 self.assertIn(compat.as_bytes("CPU:2"), outputs[2]) 1177 1178 @test_util.run_deprecated_v1 1179 def testAssignAddResourceVariable(self): 1180 1181 v = resource_variable_ops.ResourceVariable(1.0) 1182 1183 @function.Defun() 1184 def AssignAdd(): 1185 v.assign_add(1.0) 1186 1187 op = functional_ops.partitioned_call( 1188 args=AssignAdd.captured_inputs, f=AssignAdd) 1189 _ = self.evaluate(variables.global_variables_initializer()) 1190 _ = self.evaluate(op) 1191 value = self.evaluate(v.read_value()) 1192 self.assertEqual(value, 2.0) 1193 1194 @test_util.run_deprecated_v1 1195 def testFunctionWithResourcesOnDifferentDevices(self): 1196 if not test_util.is_gpu_available(): 1197 self.skipTest("No GPUs available.") 1198 1199 with ops.device("/cpu:0"): 1200 v_cpu_zero = resource_variable_ops.ResourceVariable( 1201 [0.0, 1.0, 2.0], name="v_cpu_zero") 1202 1203 with ops.device("/cpu:1"): 1204 v_cpu_one = resource_variable_ops.ResourceVariable( 1205 [0.0, 1.0, 2.0], name="v_cpu_one") 1206 1207 with ops.device("/gpu:0"): 1208 v_gpu = resource_variable_ops.ResourceVariable( 1209 [0.0, 1.0, 2.0], name="v_gpu") 1210 1211 def sum_gather(): 1212 cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu_zero, [1, 2])) 1213 also_cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu_one, [1, 2])) 1214 gpu_result = math_ops.reduce_sum(array_ops.gather(v_gpu, [1, 2])) 1215 return cpu_result, also_cpu_result, gpu_result 1216 1217 defined = function.Defun()(sum_gather) 1218 with self.test_session( 1219 config=config_pb2.ConfigProto( 1220 allow_soft_placement=False, 1221 log_device_placement=True, 1222 device_count={"CPU": 2})) as sess: 1223 self.evaluate(variables.global_variables_initializer()) 1224 expected = self.evaluate(sum_gather()) 1225 result = sess.run( 1226 functional_ops.partitioned_call( 1227 args=defined.captured_inputs, f=defined)) 1228 self.assertAllEqual(expected, result) 1229 1230 # Use an invalid executor name to test the plumbing of the executor_type attr. 1231 @test_util.run_v1_only("b/120545219") 1232 def testExecutorTypeAttrExecutorNotFound(self): 1233 @function.Defun(dtypes.int32) 1234 def AddFive(x): 1235 return x + 5 1236 1237 op = functional_ops.partitioned_call( 1238 args=[constant_op.constant([1, 2, 3], dtype=dtypes.int32)], 1239 f=AddFive, 1240 executor_type="NON_EXISTENT_EXECUTOR") 1241 with self.assertRaisesRegex(errors.NotFoundError, "NON_EXISTENT_EXECUTOR"): 1242 self.evaluate(op) 1243 1244 1245@test_util.run_all_in_graph_and_eager_modes 1246@test_util.with_control_flow_v2 1247class FunctionalOpsCaseTest(test.TestCase): 1248 1249 def testCase(self): 1250 @eager_function.defun 1251 def two(x): 1252 return x * 2 1253 1254 @eager_function.defun 1255 def three(x): 1256 return x * 3 1257 1258 @eager_function.defun 1259 def four(x): 1260 return x * 4 1261 1262 def f(branch, x): 1263 tmpl = array_ops.zeros_like(x) 1264 return array_ops.identity(gen_functional_ops.case( 1265 branch, input=[x], Tout=[dtypes.float32], 1266 branches=[f.get_concrete_function(tmpl) 1267 for f in (two, three, four)])[0]) 1268 one = array_ops.ones([]) 1269 self.assertAllEqual(np.float32(2), self.evaluate(f(0, one))) 1270 self.assertAllEqual(np.float32(3), self.evaluate(f(1, one))) 1271 self.assertAllEqual(np.float32(4), self.evaluate(f(2, one))) 1272 self.assertAllEqual(np.float32(4), self.evaluate(f(-1, one))) # <0 default 1273 self.assertAllEqual(np.float32(4), self.evaluate(f(6, one))) # >=N default 1274 1275 @test_util.run_deprecated_v1 1276 @test_util.disable_xla("Don't lower for XLA") 1277 def testSkipEagerCaseLoweringPreservesNameForFetch(self): 1278 for use_gpu in (True, False): 1279 def Run(branch, x, fetch_by_name, use_gpu=use_gpu): 1280 with ops.Graph().as_default() as g: 1281 @function.Defun(dtypes.float32) 1282 def two(x): 1283 return -1, x * 2 1284 1285 @function.Defun(dtypes.float32) 1286 def three(x): 1287 return 0, x * 3 1288 1289 @function.Defun(dtypes.float32) 1290 def four(x): 1291 return 1, x * 4 1292 1293 outputs = gen_functional_ops.case(branch, input=[x], 1294 Tout=[dtypes.int32, dtypes.float32], 1295 branches=[two, three, four], 1296 name="my_case") 1297 1298 # `outputs` is the list of output tensors of the Case op. We 1299 # arbitrarily choose the 0th tensor to get the Case op and set the 1300 # lowering attribute on it. 1301 outputs[0].op._set_attr("_lower_using_switch_merge", 1302 attr_value_pb2.AttrValue(b=True)) 1303 outputs = array_ops.identity_n(outputs) 1304 with self.session(graph=g, use_gpu=use_gpu) as sess: 1305 return sess.run("my_case:1" if fetch_by_name else outputs[1]) 1306 1307 self.assertAllEqual(2 * 1., Run(0, 1., False)) 1308 self.assertAllEqual(2 * 1., Run(0, 1., True)) 1309 self.assertAllEqual(3 * 7., Run(1, 7., False)) 1310 self.assertAllEqual(3 * 7., Run(1, 7., True)) 1311 self.assertAllEqual(4 * -3., Run(2, -3., False)) 1312 self.assertAllEqual(4 * -3., Run(2, -3., True)) 1313 self.assertAllEqual(4 * -4., Run(7, -4., False)) # >= N default 1314 self.assertAllEqual(4 * -4., Run(7, -4., True)) # >= N default 1315 self.assertAllEqual(4 * -5., Run(-1, -5., False)) # <0 default 1316 self.assertAllEqual(4 * -5., Run(-1, -5., True)) # <0 default 1317 1318 @test_util.disable_xla("Don't lower for XLA") 1319 def testCaseLowering(self): 1320 for use_gpu in (True, False): 1321 @eager_function.defun 1322 def Run(branch, x): 1323 @function.Defun(dtypes.float32) 1324 def two(x): 1325 return -1, x * 2 1326 1327 @function.Defun(dtypes.float32) 1328 def three(x): 1329 return 0, x * 3 1330 1331 @function.Defun(dtypes.float32) 1332 def four(x): 1333 return 1, x * 4 1334 1335 outputs = gen_functional_ops.case(branch, input=[x], 1336 Tout=[dtypes.int32, dtypes.float32], 1337 branches=[two, three, four]) 1338 1339 # `outputs` is the list of output tensors of the Case op. We 1340 # arbitrarily choose the 0th tensor to get the Case op and set the 1341 # lowering attribute on it. 1342 outputs[0].op._set_attr("_lower_using_switch_merge", 1343 attr_value_pb2.AttrValue(b=True)) 1344 outputs = array_ops.identity_n(outputs) 1345 return outputs[1] 1346 1347 with ops.device(test.gpu_device_name() if use_gpu else "CPU:0"): 1348 self.assertAllEqual(2 * 1., self.evaluate(Run(0, 1.))) 1349 self.assertAllEqual(3 * 7., self.evaluate(Run(1, 7.))) 1350 self.assertAllEqual(4 * -3., self.evaluate(Run(2, -3.))) 1351 self.assertAllEqual(4 * -4., self.evaluate(Run(7, -4.))) # >=N default 1352 self.assertAllEqual(4 * -5., self.evaluate(Run(-1, -5.))) # <0 default 1353 1354if __name__ == "__main__": 1355 test.main() 1356 1357# pylint: enable=invalid-name 1358