# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.kernels.functional_ops."""

import numpy as np

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import cancellation
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function as eager_def_function
from tensorflow.python.eager import executor
from tensorflow.python.eager import function as eager_function
from tensorflow.python.framework import config as framework_config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
import tensorflow.python.ops.tensor_array_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test
from tensorflow.python.util import compat


# pylint: disable=invalid-name
def simple_scoped_fn(a, x):
  """Simple function: (a, x) -> 2(x+a), but with "2" as a variable in scope."""
  with variable_scope.variable_scope("body"):
    # Dummy variable, just to check that scoping works as intended.
    two = variable_scope.get_variable(
        "two", [],
        dtype=dtypes.int32,
        initializer=init_ops.constant_initializer(2))
    return math_ops.multiply(math_ops.add(a, x), two)


@test_util.with_control_flow_v2
class FunctionalOpsTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def testFoldl_Simple(self):
    elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")

    r = functional_ops.foldl(
        lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
        elems)
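    # With no initializer, foldl seeds the accumulator with elems[0] = 1:
    # (1+2)*2=6, (6+3)*2=18, (18+4)*2=44, (44+5)*2=98, (98+6)*2=208.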
    self.assertAllEqual(208, self.evaluate(r))

    r = functional_ops.foldl(
        lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
        elems,
        initializer=10)
    self.assertAllEqual(880, self.evaluate(r))

  @test_util.run_in_graph_and_eager_modes
  def testFoldl_SingleInputMultiOutput(self):
    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    initializer = np.array([1, -1.0])
    r = functional_ops.foldl(lambda a, x: a + x, elems, initializer)
    r_value = self.evaluate(r)

    self.assertAllEqual(22, r_value[0])
    self.assertAllEqual(20, r_value[1])

  @test_util.run_in_graph_and_eager_modes
  def testFoldl_MultiInputSingleOutput(self):
    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    initializer = np.array(1.0)
    r = functional_ops.foldl(lambda a, x: a + x[0] + x[1], (elems, -elems),
                             initializer)
    self.assertAllEqual(1, self.evaluate(r))

  @test_util.run_in_graph_and_eager_modes
  def testFoldl_MultiInputDifferentDimsSingleOutput(self):
    elems = np.array([[1.0, 1.0, 1.0], [2.0, 3.0, 4.0]])
    other_elems = np.array([-1.0, 1.0])
    initializer = np.array([0.0, 0.0, 0.0])
    r = functional_ops.foldl(lambda a, x: a + x[0] * x[1],
                             (elems, other_elems), initializer)
    self.assertAllEqual([1.0, 2.0, 3.0], self.evaluate(r))

  @test_util.run_deprecated_v1
  def testFoldl_Scoped(self):
    with self.cached_session() as sess:
      with variable_scope.variable_scope("root") as varscope:
        elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")

        r = functional_ops.foldl(simple_scoped_fn, elems)
        # Check that we have the one variable we asked for here.
        self.assertEqual(len(variables.trainable_variables()), 1)
        self.assertEqual(variables.trainable_variables()[0].name,
                         "root/body/two:0")
        sess.run([variables.global_variables_initializer()])
        self.assertAllEqual(208, self.evaluate(r))

        # Now let's reuse our single variable.
        varscope.reuse_variables()
        r = functional_ops.foldl(simple_scoped_fn, elems, initializer=10)
        self.assertEqual(len(variables.trainable_variables()), 1)
        self.assertAllEqual(880, self.evaluate(r))

  @test_util.run_in_graph_and_eager_modes
  def testFoldr_Simple(self):
    elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")

    r = functional_ops.foldr(
        lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
        elems)
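    # foldr walks the elements right to left, seeding with elems[-1] = 6:
    # (6+5)*2=22, (22+4)*2=52, (52+3)*2=110, (110+2)*2=224, (224+1)*2=450.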
    self.assertAllEqual(450, self.evaluate(r))

    r = functional_ops.foldr(
        lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
        elems,
        initializer=10)
    self.assertAllEqual(1282, self.evaluate(r))

  @test_util.run_in_graph_and_eager_modes
  def testFoldr_SingleInputMultiOutput(self):
    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    initializer = np.array([1, -1.0])
    r = functional_ops.foldr(lambda a, x: a + x, elems, initializer)
    r_value = self.evaluate(r)

    self.assertAllEqual(22, r_value[0])
    self.assertAllEqual(20, r_value[1])

  @test_util.run_in_graph_and_eager_modes
  def testFoldr_MultiInputSingleOutput(self):
    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    initializer = np.array(1.0)
    r = functional_ops.foldr(lambda a, x: a + x[0] + x[1], (elems, -elems),
                             initializer)
    self.assertAllEqual(1, self.evaluate(r))

  @test_util.run_deprecated_v1
  def testFoldr_Scoped(self):
    with self.cached_session() as sess:
      with variable_scope.variable_scope("root") as varscope:
        elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")

        r = functional_ops.foldr(simple_scoped_fn, elems)
        # Check that we have the one variable we asked for here.
        self.assertEqual(len(variables.trainable_variables()), 1)
        self.assertEqual(variables.trainable_variables()[0].name,
                         "root/body/two:0")
        sess.run([variables.global_variables_initializer()])
        self.assertAllEqual(450, self.evaluate(r))

        # Now let's reuse our single variable.
        varscope.reuse_variables()
        r = functional_ops.foldr(simple_scoped_fn, elems, initializer=10)
        self.assertEqual(len(variables.trainable_variables()), 1)
        self.assertAllEqual(1282, self.evaluate(r))

  # pylint: disable=unnecessary-lambda
  @test_util.run_deprecated_v1
  def testFold_Grad(self):
    with self.cached_session():
      elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
      v = constant_op.constant(2.0, name="v")
      r = functional_ops.foldl(
          lambda a, x: math_ops.multiply(a, x), elems, initializer=v)
      r = gradients_impl.gradients(r, v)[0]
      self.assertAllEqual(720.0, self.evaluate(r))

      r = functional_ops.foldr(
          lambda a, x: math_ops.multiply(a, x), elems, initializer=v)
      r = gradients_impl.gradients(r, v)[0]
      self.assertAllEqual(720.0, self.evaluate(r))
  # pylint: enable=unnecessary-lambda

  @test_util.run_in_graph_and_eager_modes
  def testScan_Simple(self):
    elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
    v = constant_op.constant(2.0, name="v")

    # pylint: disable=unnecessary-lambda
    r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems)
    self.assertAllEqual([1., 2., 6., 24., 120., 720.], self.evaluate(r))

    r = functional_ops.scan(
        lambda a, x: math_ops.multiply(a, x), elems, initializer=v)
    self.assertAllEqual([2., 4., 12., 48., 240., 1440.], self.evaluate(r))
    # pylint: enable=unnecessary-lambda

  @test_util.run_in_graph_and_eager_modes
  def testScan_Reverse(self):
    elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
    v = constant_op.constant(2.0, name="v")

    # pylint: disable=unnecessary-lambda
    r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems,
                            reverse=True)
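    # With reverse=True the accumulation runs from the last element toward the
    # first, so each output is the product of a suffix of elems.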
    self.assertAllEqual([720., 720., 360., 120., 30., 6.], self.evaluate(r))
    r = functional_ops.scan(
        lambda a, x: math_ops.multiply(a, x), elems, initializer=v,
        reverse=True)
    self.assertAllEqual([1440., 1440., 720., 240., 60., 12.],
                        self.evaluate(r))
    # pylint: enable=unnecessary-lambda

  @test_util.run_in_graph_and_eager_modes
  def testScan_SingleInputMultiOutput(self):
    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    initializer = (np.array(1.0), np.array(-1.0))
    r = functional_ops.scan(lambda a, x: (a[0] * x, -a[1] * x), elems,
                            initializer)
    r_value = self.evaluate(r)

    self.assertAllEqual([1.0, 2.0, 6.0, 24.0, 120.0, 720.0], r_value[0])
    self.assertAllEqual([1.0, -2.0, 6.0, -24.0, 120.0, -720.0], r_value[1])

  @test_util.run_in_graph_and_eager_modes
  def testScan_MultiInputSingleOutput(self):
    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    initializer = np.array(1.0)
    # Multiply a * 1 each time
    r = functional_ops.scan(lambda a, x: a * (x[0] + x[1]),
                            (elems + 1, -elems), initializer)
    self.assertAllEqual([1.0, 1.0, 1.0, 1.0, 1.0, 1.0], self.evaluate(r))

  @test_util.run_in_graph_and_eager_modes
  def testScan_MultiInputSameTypeOutput(self):
    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    r = functional_ops.scan(lambda a, x: (a[0] + x[0], a[1] + x[1]),
                            (elems, -elems))
    r_value = self.evaluate(r)
    self.assertAllEqual(np.cumsum(elems), r_value[0])
    self.assertAllEqual(np.cumsum(-elems), r_value[1])

  @test_util.run_in_graph_and_eager_modes
  def testScan_MultiOutputMismatchedInitializer(self):
    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    initializer = np.array(1.0)
    # The fn returns a 2-tuple while the initializer is a scalar, so the
    # structures don't match and scan should raise a ValueError.
    with self.assertRaisesRegex(
        ValueError, "two structures don't have the same nested structure"):
      functional_ops.scan(lambda a, x: (a, -a), elems, initializer)

  @test_util.run_deprecated_v1
  def testScan_Scoped(self):
    with self.cached_session() as sess:
      with variable_scope.variable_scope("root") as varscope:
        elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")

        r = functional_ops.scan(simple_scoped_fn, elems)
        # Check that we have the one variable we asked for here.
        self.assertEqual(len(variables.trainable_variables()), 1)
        self.assertEqual(variables.trainable_variables()[0].name,
                         "root/body/two:0")
        sess.run([variables.global_variables_initializer()])
        results = np.array([1, 6, 18, 44, 98, 208])
        self.assertAllEqual(results, self.evaluate(r))

        # Now let's reuse our single variable.
        varscope.reuse_variables()
        r = functional_ops.scan(simple_scoped_fn, elems, initializer=2)
        self.assertEqual(len(variables.trainable_variables()), 1)
        results = np.array([6, 16, 38, 84, 178, 368])
        self.assertAllEqual(results, self.evaluate(r))

  @test_util.run_in_graph_and_eager_modes
  def testScanFoldl_Nested(self):
    elems = constant_op.constant([1.0, 2.0, 3.0, 4.0], name="data")
    inner_elems = constant_op.constant([0.5, 0.5], name="data")

    def r_inner(a, x):
      return functional_ops.foldl(
          lambda b, y: b * y * x, inner_elems, initializer=a)

    r = functional_ops.scan(r_inner, elems)

    # t == 0 (returns 1)
    # t == 1, a == 1, x == 2 (returns 1)
    #   t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1
    #   t_1 == 1, b == 1,      y == 0.5, returns b * y * x = 1
    # t == 2, a == 1, x == 3 (returns 1.5*1.5 == 2.25)
    #   t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1.5
    #   t_1 == 1, b == 1.5,    y == 0.5, returns b * y * x = 1.5*1.5
    # t == 3, a == 2.25, x == 4 (returns 9)
    #   t_0 == 0, b == a == 2.25, y == 0.5, returns b * y * x = 4.5
    #   t_1 == 1, b == 4.5,       y == 0.5, returns b * y * x = 9
    self.assertAllClose([1., 1., 2.25, 9.], self.evaluate(r))

  @test_util.run_deprecated_v1
  def testScan_Control(self):
    with self.cached_session() as sess:
      s = array_ops.placeholder(dtypes.float32, shape=[None])
      b = array_ops.placeholder(dtypes.bool)

      with ops.control_dependencies([b]):
        c = functional_ops.scan(lambda a, x: x * a, s)
      self.assertAllClose(
          np.array([1.0, 3.0, 9.0]), sess.run(c, {s: [1, 3, 3],
                                                  b: True}))

  @test_util.run_deprecated_v1
  def testScan_Grad(self):
    with self.cached_session():
      elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
      v = constant_op.constant(2.0, name="v")

      # pylint: disable=unnecessary-lambda
      r = functional_ops.scan(
          lambda a, x: math_ops.multiply(a, x), elems, initializer=v)
      # pylint: enable=unnecessary-lambda
      r = gradients_impl.gradients(r, v)[0]
      self.assertAllEqual(873.0, self.evaluate(r))

  @test_util.run_deprecated_v1
  def testScanGradientWithPartStopGradient(self):
    a = variables.Variable(0.0, name="a")
    b = variables.Variable(0.0, name="b")
    elems = array_ops.zeros(5)
    l0, l1 = functional_ops.scan(
        lambda elem_, input_: (a, b), elems, initializer=(0., 0.))
    loss = l0 + array_ops.stop_gradient(l1)
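    # stop_gradient blocks the l1 path, so only the l0 branch contributes to
    # the gradient; the test only checks that gradient construction and
    # evaluation succeed.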
    grad = gradients_impl.gradients(ys=[loss], xs=[a, b])
    with self.test_session():
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(grad)

  @test_util.run_in_graph_and_eager_modes
  def testFoldShape(self):
    x = constant_op.constant([[1, 2, 3], [4, 5, 6]])

    def fn(_, current_input):
      return current_input

    initializer = constant_op.constant([0, 0, 0])
    y = functional_ops.foldl(fn, x, initializer=initializer)
    self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)

  @test_util.run_in_graph_and_eager_modes
  def testScanShape(self):
    x = constant_op.constant([[1, 2, 3], [4, 5, 6]])

    def fn(_, current_input):
      return current_input

    initializer = constant_op.constant([0, 0, 0])
    y = functional_ops.scan(fn, x, initializer=initializer)
    self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)

  # TODO(akshayka): this test fails in eager: the iterable is of length 0, so
  # the body of the while loop never executes.
  @test_util.run_deprecated_v1
  def testScanEmptyTensor(self):
    with self.cached_session():
      x = functional_ops.scan(
          lambda x, _: x, math_ops.range(0), initializer=array_ops.ones([2, 4]))
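      # Scanning over zero elements stacks nothing, so the result keeps a
      # leading dimension of 0 on top of the initializer's [2, 4] shape.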
      self.assertAllEqual([0, 2, 4], x.get_shape())
      self.assertAllEqual(x.get_shape(), self.evaluate(x).shape)

  @test_util.run_deprecated_v1
  def testScanUnknownShape(self):
    x = array_ops.placeholder(dtypes.float32)
    initializer = array_ops.placeholder(dtypes.float32)

    def fn(_, current_input):
      return current_input

    y = functional_ops.scan(fn, x, initializer=initializer)
    self.assertIs(None, y.get_shape().dims)

  @test_util.run_deprecated_v1
  def testScanVaryingShape(self):
    with self.cached_session() as sess:
      x = array_ops.placeholder(dtype=dtypes.float32, shape=[None, 2])
      x_t = array_ops.transpose(x)
      # scan over dimension 0 (with shape None)
      result = functional_ops.scan(lambda a, x: a + x, x)
      # scanned over transposed dimension 0 (with shape 2)
      result_t = functional_ops.scan(lambda a, x: a + x, x_t, infer_shape=False)
      # ensure gradients can be calculated
      result_grad = gradients_impl.gradients(result, [x])[0]
      result_t_grad = gradients_impl.gradients(result_t, [x_t])[0]

      # smoke test to ensure they all evaluate
      sess.run([result, result_t, result_grad, result_t_grad],
               feed_dict={x: [[1.0, 2.0]]})

  @test_util.run_deprecated_v1
  def testRemoteFunction(self):
    worker_config = config_pb2.ConfigProto()
    worker_config.device_count["CPU"] = 2
    worker, _ = test_util.create_local_cluster(
        1, 1, worker_config=worker_config)

    @function.Defun(dtypes.int32, dtypes.int32)
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    with ops.device("/job:ps/task:0"):
      a = variables.Variable(2, dtype=dtypes.int32)
      b = variables.Variable(3, dtype=dtypes.int32)

    with ops.device("/job:worker/replica:0/task:0/cpu:0"):
      remote_op = functional_ops.remote_call(
          args=[a, b],
          Tout=[dtypes.int32],
          f=_remote_fn,
          target="/job:worker/replica:0/task:0/cpu:1")
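      # remote_call places the call op on cpu:0 but executes _remote_fn on the
      # target device (cpu:1 of the same worker).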

    with session.Session(worker[0].target) as sess:
      self.evaluate(variables.global_variables_initializer())
      mul = self.evaluate(remote_op)
      self.assertEqual(mul, [6])

  @test_util.run_deprecated_v1
  def testRemoteFunctionDirectSession(self):
    worker_config = config_pb2.ConfigProto()
    worker_config.device_count["CPU"] = 2

    @function.Defun(dtypes.int32, dtypes.int32)
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
      a = variables.Variable(2, dtype=dtypes.int32)
      b = variables.Variable(3, dtype=dtypes.int32)

    with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
      remote_op = functional_ops.remote_call(
          args=[a, b],
          Tout=[dtypes.int32],
          f=_remote_fn,
          target="/job:localhost/replica:0/task:0/cpu:1")

    with self.test_session(config=worker_config) as sess:
      self.evaluate(variables.global_variables_initializer())
      mul = self.evaluate(remote_op)
      self.assertEqual(mul, [6])

  @test_util.run_deprecated_v1
  def testRemoteFunctionSameDeviceDirectSession(self):

    @function.Defun(dtypes.int32, dtypes.int32)
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    with ops.device("/cpu:0"):
      a = variables.Variable(2, dtype=dtypes.int32)
      b = variables.Variable(3, dtype=dtypes.int32)

    with ops.device("/cpu:0"):
      remote_op = functional_ops.remote_call(
          args=[a, b], Tout=[dtypes.int32], f=_remote_fn, target="/cpu:0")

    with self.cached_session() as sess:
      self.evaluate(variables.global_variables_initializer())
      mul = self.evaluate(remote_op)
      self.assertEqual(mul, [6])

  @test_util.run_deprecated_v1
  def testRemoteFunctionCPUGPU(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    @function.Defun(dtypes.float32, dtypes.float32)
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
      a = variables.Variable(2, dtype=dtypes.float32)
      b = variables.Variable(3, dtype=dtypes.float32)

    with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
      remote_op = functional_ops.remote_call(
          args=[a, b],
          Tout=[dtypes.float32],
          f=_remote_fn,
          target="/job:localhost/replica:0/task:0/device:GPU:0")[0] + 3.0

    with self.cached_session() as sess:
      self.evaluate(variables.global_variables_initializer())
      mul = self.evaluate(remote_op)
      self.assertEqual(mul, 9.0)

  @test_util.run_deprecated_v1
  def testRemoteFunctionGPUCPU(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    @function.Defun(dtypes.float32, dtypes.float32)
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
      a = variables.Variable(2, dtype=dtypes.float32)
      b = variables.Variable(3, dtype=dtypes.float32)

    with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
      remote_op = functional_ops.remote_call(
          args=[a, b],
          Tout=[dtypes.float32],
          f=_remote_fn,
          target="/job:localhost/replica:0/task:0/cpu:0")[0] + 3.0

    with self.cached_session() as sess:
      self.evaluate(variables.global_variables_initializer())
      mul = self.evaluate(remote_op)
      self.assertEqual(mul, 9.0)

  @test_util.run_deprecated_v1
  def testRemoteFunctionGPUCPUStrings(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    @function.Defun(dtypes.string)
    def _remote_fn(inp):
      return array_ops.identity(inp)

    a = array_ops.constant("a")

    with ops.device("/gpu:0"):
      remote_op = functional_ops.remote_call(
          args=[a], Tout=[dtypes.string], f=_remote_fn, target="/cpu:0")

    with self.cached_session() as sess:
      ret = self.evaluate(remote_op)
      self.assertAllEqual(ret, [b"a"])

  @test_util.run_deprecated_v1
  def testRemoteFunctionCrossProcess(self):
    workers, _ = test_util.create_local_cluster(2, 1)

    @function.Defun(dtypes.float32, dtypes.float32)
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    with ops.device("/job:ps/task:0"):
      a = variables.Variable(2, dtype=dtypes.float32)
      b = variables.Variable(3, dtype=dtypes.float32)

    with ops.device("/job:worker/replica:0/task:0/cpu:0"):
      remote_op = functional_ops.remote_call(
          args=[a, b],
          Tout=[dtypes.float32],
          f=_remote_fn,
          target="/job:worker/replica:0/task:1/cpu:0")[0] + 3.0

    with session.Session(workers[0].target) as sess:
      self.evaluate(variables.global_variables_initializer())
      mul = self.evaluate(remote_op)
      self.assertEqual(mul, 9)

  @test_util.run_v2_only
  def testRemoteFunctionCancellation(self):
    context._reset_context()
    logical_devices = []
    logical_devices.append(context.LogicalDeviceConfiguration())
    logical_devices.append(context.LogicalDeviceConfiguration())
    framework_config.set_logical_device_configuration(
        framework_config.list_physical_devices("CPU")[0], logical_devices)

    @function.Defun(dtypes.float32)
    def _remote_fn(v):
      # We run two collectives here to make sure we cancel in the middle of the
      # RemoteCall. The second one should never finish.
      anchor = collective_ops.all_reduce_v2(
          v, group_size=2, group_key=1, instance_key=1)
      with ops.control_dependencies([anchor]):
        return collective_ops.all_reduce_v2(
            v, group_size=2, group_key=1, instance_key=2)

    @eager_def_function.function
    def run():
      with ops.device("/cpu:0"):
        return functional_ops.remote_call(
            args=[constant_op.constant([1.])] + _remote_fn.captured_inputs,
            Tout=[dtypes.float32],
            f=_remote_fn,
            target="/cpu:1")[0]

    async_executor = executor.new_executor(enable_async=True)
    cancel_mgr = cancellation.CancellationManager()
    with context.executor_scope(async_executor):
      # This should never finish.
      cancel_mgr.get_cancelable_function(run.get_concrete_function())()
    with ops.device("/cpu:0"):
      collective_ops.all_reduce_v2([1.],
                                   group_size=2,
                                   group_key=1,
                                   instance_key=1)
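    # The all-reduce above pairs with the first collective inside the
    # RemoteCall, so the cancellation below lands while the second collective
    # is still pending.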
    cancel_mgr.start_cancel()
    with self.assertRaises(errors.CancelledError):
      async_executor.wait()

  @test_util.run_deprecated_v1
  def testIf(self):

    @function.Defun(dtypes.float32)
    def Twice(x):
      return x * 2

    @function.Defun(dtypes.float32)
    def Thrice(x):
      return x * 3 + 1

    with self.test_session(use_gpu=False) as sess:

      x = array_ops.placeholder(dtypes.float32)
      ret = functional_ops.If(math_ops.greater(x, 0), [x], Twice, Thrice)[0]

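      # If(cond, inputs, then_fn, else_fn): Twice runs when x > 0, Thrice
      # otherwise, so 9 -> 18, -8 -> -23, and 0 -> 1.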
      self.assertAllEqual(sess.run(ret, feed_dict={x: 9.}), 18.)
      self.assertAllEqual(sess.run(ret, feed_dict={x: -8.}), -23.)
      self.assertAllEqual(sess.run(ret, feed_dict={x: 0.}), 1.)

  def testWhile(self):

    for use_gpu in (True, False):
      with ops.Graph().as_default() as g:

        @function.Defun(*[dtypes.float32] * 2)
        def Cond(n, unused_x):
          return n > 0

        @function.Defun(*[dtypes.float32] * 2)
        def Body(n, x):
          return n - 1, x + n

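        # While returns the final loop state [n, x]; x accumulates
        # n + (n-1) + ... + 1 = n*(n+1)/2, hence 210 for n=20 and 5050 for
        # n=100.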
        def Run(sess, n):
          return sess.run(functional_ops.While([n, 0.], Cond, Body))[1]

        with self.session(graph=g, use_gpu=use_gpu) as sess:
          self.assertAllEqual(Run(sess, 20.), 210.)
          self.assertAllEqual(Run(sess, 100.), 5050.)

  def testToBool(self):
    # For 0D tensors, the truthiness depends on whether the value is "zero".
    self.assertAllEqual(gen_functional_ops.to_bool(0), False)
    self.assertAllEqual(gen_functional_ops.to_bool(1), True)
    self.assertAllEqual(gen_functional_ops.to_bool(42), True)
    self.assertAllEqual(gen_functional_ops.to_bool(0.), False)
    self.assertAllEqual(gen_functional_ops.to_bool(1.), True)
    self.assertAllEqual(gen_functional_ops.to_bool(42.), True)
    self.assertAllEqual(gen_functional_ops.to_bool(False), False)
    self.assertAllEqual(gen_functional_ops.to_bool(True), True)
    # For strings, "zero" is the empty string.
    self.assertAllEqual(gen_functional_ops.to_bool(""), False)
    self.assertAllEqual(gen_functional_ops.to_bool("a"), True)

    # For >0D tensors, the truthiness only depends on whether there are
    # elements or not.
    self.assertAllEqual(gen_functional_ops.to_bool([]), False)
    self.assertAllEqual(gen_functional_ops.to_bool([[]]), False)
    self.assertAllEqual(gen_functional_ops.to_bool([[[]]]), False)
    self.assertAllEqual(gen_functional_ops.to_bool([0]), True)
    self.assertAllEqual(gen_functional_ops.to_bool([1]), True)
    self.assertAllEqual(gen_functional_ops.to_bool([[0]]), True)
    self.assertAllEqual(gen_functional_ops.to_bool([False]), True)
    self.assertAllEqual(gen_functional_ops.to_bool([True]), True)

  # Like above, but using int32 in order to ensure that int32 tensors don't get
  # copied to the GPU during the application of the while.
  def testWhileInt32(self):
    with ops.Graph().as_default() as g:

      @function.Defun(*[dtypes.int32] * 2)
      def Cond(n, unused_x):
        return n > 0

      @function.Defun(*[dtypes.int32] * 2)
      def Body(n, x):
        return n - 1, x + n

      def Run(sess, n):
        return sess.run(functional_ops.While([n, 0], Cond, Body))[1]

      with self.session(graph=g, use_gpu=True) as sess:
        self.assertAllEqual(Run(sess, 20), 210)
        self.assertAllEqual(Run(sess, 100), 5050)

  @test_util.run_deprecated_v1
  def testWhileLowering(self):

    def Run(n, fetch_by_name):
      for use_gpu in (True, False):
        with ops.Graph().as_default() as g:

          @function.Defun(*[dtypes.float32] * 2)
          def Cond(n, unused_x):
            return n > 0

          @function.Defun(*[dtypes.float32] * 2)
          def Body(n, x):
            return n - 1, x + n

          # outputs: [0, n*(n+1)/2]
          outputs = functional_ops.While([n, 0.], Cond, Body, name="my_while")

          # `outputs` is the list of output tensors of the While op. We
          # arbitrarily choose the 0th tensor to get the While op and set the
          # lowering attribute on it.
          outputs[0].op._set_attr("_lower_using_switch_merge",
                                  attr_value_pb2.AttrValue(b=True))
          if not fetch_by_name:
            fetch = outputs[1]
          else:
            fetch = "my_while:1"
        with self.session(graph=g, use_gpu=use_gpu) as sess:
          return self.evaluate(fetch)

    self.assertAllEqual(Run(20., False), 210.)
    self.assertAllEqual(Run(20., True), 210.)
    self.assertAllEqual(Run(100., False), 5050.)
    self.assertAllEqual(Run(100., True), 5050.)

  @test_util.run_v1_only("b/120545219")
  @test_util.disable_xla("b/123337890")  # Different error message
  def testWhileError(self):
    for use_gpu in (True, False):
      with ops.Graph().as_default() as g:

        @function.Defun(*[dtypes.float32] * 2)
        def Cond(n, unused_x):
          return n > 0

        @function.Defun(*[dtypes.float32] * 2)
        def CondReturnsTooManyArgs(n, x):
          return n > 0, x

        @function.Defun(*[dtypes.float32] * 2)
        def Body(n, x):
          return n - 1, x + n

        @function.Defun(*[dtypes.float32] * 2)
        def BodyReturnsTooManyArgs(n, x):
          return n - 1, x + n, x

        with self.session(graph=g, use_gpu=use_gpu):
          with self.assertRaisesRegex(
              errors.InvalidArgumentError,
              "Expected a single scalar.*got 2 tensors."):
            functional_ops.While([5., 0.], CondReturnsTooManyArgs,
                                 Body)[0].eval()
          with self.assertRaisesRegex(
              errors.InvalidArgumentError,
              "While loop body returned 3 arguments. Expected: 2"):
            functional_ops.While([5., 0.], Cond,
                                 BodyReturnsTooManyArgs)[0].eval()

  def testWhileInMultipleSubgraphs(self):

    for use_gpu in (True, False):
      with ops.Graph().as_default() as g:

        @function.Defun(*[dtypes.float32] * 2)
        def Cond(n, x):  # pylint: disable=unused-argument
          return n > 0

        @function.Defun(*[dtypes.float32] * 2)
        def Body(n, x):
          return n - 1, x + n

        with self.session(graph=g, use_gpu=use_gpu) as sess:
          n = array_ops.placeholder(dtypes.float32)
          _, result = functional_ops.While([n, 0.], Cond, Body)
          c = constant_op.constant(37.)

          self.assertAllEqual(210., sess.run(result, feed_dict={n: 20.}))
          self.assertAllEqual(5050., sess.run(result, feed_dict={n: 100.}))
          # Test that the result is the same when we run a different subgraph.
          self.assertAllEqual(5050.,
                              sess.run([result, c], feed_dict={n: 100.})[0])

  # pylint: disable=cell-var-from-loop
  def testWhileCapturedInputs(self):
    for use_gpu in (True, False):
      with ops.Graph().as_default() as g:
        v = variables.Variable(1.0)

        def TestCond(n, *args):
          del args
          return n < 10

        @function.Defun(*[dtypes.float32] * 2)
        def TestUnary(n, x):
          return math_ops.add(n, 1), x + n + v

        @function.Defun(*[dtypes.float32] * 3)
        def TestBinary(n, x, x2):
          return math_ops.add(n, 1), x + n + v, x2 + v

        with self.session(graph=g, use_gpu=use_gpu) as sess:
          result_unary = functional_ops.While(
              [1.0, 0.],
              function.Defun(*[dtypes.float32] * 2)(TestCond), TestUnary)
          result_binary = functional_ops.While(
              [1.0, 0., 0.],
              function.Defun(*[dtypes.float32] * 3)(TestCond), TestBinary)
          self.evaluate(variables.global_variables_initializer())
          assert len(result_unary) == 2
          self.assertEqual([10.0, 54.0], self.evaluate(result_unary))
          assert len(result_binary) == 3
          self.assertEqual([10.0, 54.0, 9.0], self.evaluate(result_binary))

          def TestCondCapture(n, *args):
            del args
            return math_ops.cast(n, dtypes.float32) + v < 10

          with self.assertRaises(ValueError):
            _ = functional_ops.While(
                [1],
                function.Defun(dtypes.int32)(TestCondCapture),
                function.Defun(dtypes.int32, dtypes.float32)(TestUnary))

  # pylint: enable=cell-var-from-loop

  def _tfSum(self, use_gpu, rewrite_with_while):
    with ops.Graph().as_default() as g:
      with self.session(graph=g, use_gpu=use_gpu) as sess:

        @function.Defun(dtypes.int32, dtypes.float32)
        def Body(n, x):
          return x + math_ops.cast(n, dtypes.float32)

        xs = [
            # 1 + 2 + ... + 20
            functional_ops.For(
                1, 21, 1, [0.], Body, rewrite_with_while=rewrite_with_while)[0],
            # 100 + 99 + ... + 1
            functional_ops.For(
                100, 0, -1, [0.], Body, rewrite_with_while=rewrite_with_while)
            [0],
        ]
        xvals = self.evaluate(xs)
      self.assertAllEqual(210, xvals[0])
      self.assertAllEqual(5050, xvals[1])

  def testFor(self):
    for use_gpu in (True, False):
      self._tfSum(use_gpu, False)

  def testForWithWhile(self):
    for use_gpu in (True, False):
      self._tfSum(use_gpu, True)

  def testForWithWhileNaming(self):
    g = ops.Graph()
    with g.as_default():

      @function.Defun(dtypes.int32, dtypes.float32, func_name="TestBody")
      def TestBody(n, x):
        return x + math_ops.cast(n, dtypes.float32)

      _ = functional_ops.For(
          1, 21, 1, [0.], TestBody, rewrite_with_while=True)[0]

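    # Rewriting the For with a while loop should add "<func_name>_Cond" and
    # "<func_name>_Body" helper functions to the graph's function library,
    # which the assertions below check.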
    names = []
    for func in g.as_graph_def().library.function:
      names.append(func.signature.name)
    self.assertTrue("TestBody" in names)
    self.assertTrue("TestBody_Cond" in names)
    self.assertTrue("TestBody_Body" in names)

  @test_util.run_deprecated_v1
  def testForCapturedInputs(self):
    v = variables.Variable(1.0)

    @function.Defun(dtypes.int32)
    def TestNullary(n):
      v + math_ops.cast(n, dtypes.float32)  # pylint: disable=expression-not-assigned

    @function.Defun(dtypes.int32, dtypes.float32)
    def TestUnary(n, x):
      return x + math_ops.cast(n, dtypes.float32) + v

    @function.Defun(dtypes.int32, dtypes.float32, dtypes.float32)
    def TestBinary(n, x, x2):
      return x + math_ops.cast(n, dtypes.float32) + v, x2 + v

    for rewrite_with_while in (True, False):
      use_gpu = not rewrite_with_while
      with self.test_session(use_gpu=use_gpu) as sess:
        result_nullary = functional_ops.For(
            1, 10, 1, [], TestNullary,
            rewrite_with_while=rewrite_with_while)
        result_unary = functional_ops.For(
            1, 10, 1, [0.], TestUnary,
            rewrite_with_while=rewrite_with_while)
        result_binary = functional_ops.For(
            1, 10, 1, [0., 0.], TestBinary,
            rewrite_with_while=rewrite_with_while)
        self.evaluate(variables.global_variables_initializer())
        assert not result_nullary
        # The nullary variant doesn't return anything so we can't easily run it.
        # As a total hack, fetch the operation by name and run it.
        sess.run(ops.get_default_graph().get_operation_by_name(
            "While" if rewrite_with_while else "For"))
        assert len(result_unary) == 1
        self.assertEqual([54.0], self.evaluate(result_unary))
        assert len(result_binary) == 2
        self.assertEqual([54.0, 9.0], self.evaluate(result_binary))

  def _tfMLP(self, xval, wsval, bsval, rewrite_with_while):
    # On GPU, don't rewrite using a while loop.
    use_gpu = not rewrite_with_while
    with self.test_session(use_gpu=use_gpu):

      @function.Defun(dtypes.int32, *[dtypes.float64] * 3)
      def MLP(i, a, ws, bs):
        a = math_ops.tanh(math_ops.matmul(a, ws[i, :]) + bs[i, :])
        return a, ws, bs

      ret = functional_ops.For(
          0,
          wsval.shape[0],
          1, [xval, wsval, bsval],
          MLP,
          rewrite_with_while=rewrite_with_while)[0]
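      # The loop state is (activation, weights, biases); only the final
      # activation (slot 0) is returned.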

      return self.evaluate(ret)

  def _npMLP(self, xval, wsval, bsval):
    for i in range(wsval.shape[0]):
      xval = np.tanh(np.dot(xval, wsval[i, :]) + bsval[i, :])
    return xval

  def _testForMLP(self, rewrite_with_while):
    # We construct a 5-layer multi-layer perceptron here. Each layer has the
    # same number of hidden units (3), and the activation function is tanh().
    # We feed the input (xval) with batch size 2.
    xval = np.random.normal(size=(2, 3))
    wsval = np.random.normal(size=(5, 3, 3))
    bsval = np.random.normal(size=(5, 3))
    np_ans = self._npMLP(xval, wsval, bsval)
    tf_for_ans = self._tfMLP(xval, wsval, bsval, rewrite_with_while)
    self.assertAllClose(np_ans, tf_for_ans)

  @test_util.run_deprecated_v1
  def testForMLP(self):
    self._testForMLP(False)

  @test_util.run_deprecated_v1
  @test_util.disable_xla(
      "Test uses strided slice without compile time constant values")
  def testForMLPWhile(self):
    self._testForMLP(True)

  @test_util.run_v1_only("b/120545219")
  def testForError(self):

    @function.Defun(dtypes.int32, dtypes.float32)
    def Foo(i, v):
      return math_ops.cast(i, dtypes.float32) + v

    @function.Defun(dtypes.int32, dtypes.float32)
    def ReturnsTooManyArgs(unused_i, v):
      return v, v

    with self.test_session():
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  "must be a scalar"):
        functional_ops.For([0], 10, 1, [0.0], Foo)[0].eval()
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  "Invalid start/limit/delta"):
        functional_ops.For(0, 10, -1, [0.0], Foo)[0].eval()
      with self.assertRaisesRegex(
          errors.InvalidArgumentError,
          "For loop body returned 2 arguments. Expected: 1"):
        functional_ops.For(0, 10, 1, [0.0], ReturnsTooManyArgs)[0].eval()

  @test_util.run_deprecated_v1
  def testGradient(self):

    @function.Defun(dtypes.float32)
    def Poly(x):
      # y = 2x^3+3x^2+4x+8
      return 2 * x * x * x + 3 * x * x + 4 * x + 8

    @function.Defun(dtypes.float32)
    def Grad(x):
      # dy/dx = dy/dy * dy/dx = 1.0 * (6x^2+6x+4)
      return functional_ops.Gradient([x, 1.0], Poly)[0]

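    # Gradient([x, dy], Poly) computes dy * dPoly/dx, so Grad(0) == 4 and
    # Grad(1) == 16 as asserted below.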
    with self.test_session(use_gpu=False) as sess:
      a = constant_op.constant(0.)
      avals = [Poly(a), Grad(a)]
      b = constant_op.constant(1.)
      bvals = [Poly(b), Grad(b)]
      self.assertAllEqual(self.evaluate(avals), [8., 4.])
      self.assertAllEqual(self.evaluate(bvals), [17., 16.])

  @test_util.run_v2_only
  def testCollective(self):
    context._reset_context()
    logical_devices = []
    logical_devices.append(context.LogicalDeviceConfiguration())
    logical_devices.append(context.LogicalDeviceConfiguration())
    framework_config.set_logical_device_configuration(
        framework_config.list_physical_devices("CPU")[0], logical_devices)

    @function.Defun(dtypes.float32)
    def collective_fn(t):
      # Run a dummy collective of group size 1 to test the setup.
      return collective_ops.all_reduce_v2(
          t, group_size=1, group_key=1, instance_key=1)

    @eager_def_function.function
    def run():
      with ops.device("/cpu:0"):
        return functional_ops.remote_call(
            args=[constant_op.constant([1.])] + collective_fn.captured_inputs,
            Tout=[dtypes.float32],
            f=collective_fn,
            target="/cpu:1")

    self.assertAllEqual(run(), [[1.]])


# TODO(akshayka): Replace `function.Defun` with `tf.contrib.eager.defun` in
# the test cases below.
class PartitionedCallTest(test.TestCase):

  @test_util.run_deprecated_v1
  def testRemoteDeviceInPartitionedCallOp(self):
    workers, _ = test_util.create_local_cluster(2, 0)

    worker0_device = "/job:worker/replica:0/task:0/cpu:0"
    worker1_device = "/job:worker/replica:0/task:1/cpu:0"

    @eager_def_function.function
    def f(a, b):
      return a + b

    with session.Session(workers[0].target) as sess:
      with ops.device(worker0_device):
        a = variable_scope.get_variable(
            "a", initializer=constant_op.constant(1.), use_resource=True)
      with ops.device(worker1_device):
        b = variable_scope.get_variable(
            "b", initializer=constant_op.constant(1.), use_resource=True)

      sess.run(variables.global_variables_initializer())

    config = config_pb2.ConfigProto()
    config.share_cluster_devices_in_session = True

    with session.Session(workers[0].target, config=config) as sess:
      res = sess.run(f(a, b))

    self.assertEqual(res, 2)

  @test_util.run_deprecated_v1
  def testBasicSingleDevice(self):

    @function.Defun(*[dtypes.float32] * 2)
    def Body(x, y):
      with ops.device("/cpu:0"):
        a = x + x
        b = y + y
        return a + b

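    # With inputs (1., 2.): a = 2, b = 4, so the call returns 6.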
    output, = self.evaluate(
        functional_ops.partitioned_call(
            args=[constant_op.constant(1.),
                  constant_op.constant(2.)], f=Body))
    self.assertEqual(output, 6.)

  @test_util.run_deprecated_v1
  def testBasicMultiDevice(self):
    config = config_pb2.ConfigProto(device_count={"CPU": 3})

    @function.Defun(*[dtypes.float32] * 2)
    def Body(x, y):
      # if x = 1, y = 2, ...
      with ops.device("/cpu:0"):
        # a:= 1 + 1 = 2
        a = x + x
      with ops.device("/cpu:1"):
        # b:= 2 + 2 = 4
        b = a + y
      with ops.device("/cpu:2"):
        # c:= 2 + 4 = 6
        c = a + b
      # a + b + c = 2 + 4 + 6 = 12
      return a + b + c

    with self.test_session(config=config):
      output, = functional_ops.partitioned_call(
          args=[constant_op.constant(1.),
                constant_op.constant(2.)], f=Body)
      self.assertEqual(self.evaluate(output), 12.)

  @test_util.run_deprecated_v1
  def testBasicMultiDeviceGPU(self):
    if not test_util.is_gpu_available():
      return

    @function.Defun(*[dtypes.float32] * 2)
    def Body(x, y):
      with ops.device("/gpu:0"):
        a = x + x
        b = y + y
      with ops.device("/cpu:0"):
        c = a + b
        return c

    output, = self.evaluate(
        functional_ops.partitioned_call(
            args=[constant_op.constant(1.),
                  constant_op.constant(2.)], f=Body))
    self.assertEqual(output, 6.)

  @test_util.run_deprecated_v1
  def testBasicNoDeviceAnnotations(self):

    @function.Defun(*[dtypes.float32] * 2)
    def Body(x, y):
      a = x + x
      b = y + y
      return a + b

    output, = self.evaluate(
        functional_ops.partitioned_call(
            args=[constant_op.constant(1.),
                  constant_op.constant(2.)], f=Body))
    self.assertEqual(output, 6.)

  @test_util.run_deprecated_v1
  def testShardsRunOnRequestedDevices(self):
    config = config_pb2.ConfigProto(device_count={"CPU": 4})

    @function.Defun()
    def Body():
      # Serialize DT_RESOURCE handles as DT_STRINGs, which encode the device on
      # which the resource was created, so that we can verify that ops were
      # actually run on the requested devices.
      #
      # TODO(akshayka): Provide a cleaner, more idiomatic API for obtaining the
      # name of the device on which a resource lives / for determining the
      # device on which an op ran.
      with ops.device("/cpu:0"):
        s1 = iterator_ops.Iterator.from_structure(
            (dtypes.float32,)).string_handle()
      with ops.device("/cpu:1"):
        s2 = iterator_ops.Iterator.from_structure(
            (dtypes.float32,)).string_handle()
      with ops.device("/cpu:2"):
        s3 = iterator_ops.Iterator.from_structure(
            (dtypes.float32,)).string_handle()
      return s1, s2, s3

    with self.test_session(config=config, use_gpu=True) as sess:
      outputs = sess.run(functional_ops.partitioned_call(args=[], f=Body))
    self.assertIn(compat.as_bytes("CPU:0"), outputs[0])
    self.assertIn(compat.as_bytes("CPU:1"), outputs[1])
    self.assertIn(compat.as_bytes("CPU:2"), outputs[2])

  @test_util.run_deprecated_v1
  def testAssignAddResourceVariable(self):

    v = resource_variable_ops.ResourceVariable(1.0)

    @function.Defun()
    def AssignAdd():
      v.assign_add(1.0)

    op = functional_ops.partitioned_call(
        args=AssignAdd.captured_inputs, f=AssignAdd)
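    # The Defun captures v's resource handle; the captured inputs have to be
    # passed explicitly as args to partitioned_call.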
    _ = self.evaluate(variables.global_variables_initializer())
    _ = self.evaluate(op)
    value = self.evaluate(v.read_value())
    self.assertEqual(value, 2.0)

  @test_util.run_deprecated_v1
  def testFunctionWithResourcesOnDifferentDevices(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPUs available.")

    with ops.device("/cpu:0"):
      v_cpu_zero = resource_variable_ops.ResourceVariable(
          [0.0, 1.0, 2.0], name="v_cpu_zero")

    with ops.device("/cpu:1"):
      v_cpu_one = resource_variable_ops.ResourceVariable(
          [0.0, 1.0, 2.0], name="v_cpu_one")

    with ops.device("/gpu:0"):
      v_gpu = resource_variable_ops.ResourceVariable(
          [0.0, 1.0, 2.0], name="v_gpu")

    def sum_gather():
      cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu_zero, [1, 2]))
      also_cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu_one, [1, 2]))
      gpu_result = math_ops.reduce_sum(array_ops.gather(v_gpu, [1, 2]))
      return cpu_result, also_cpu_result, gpu_result

    defined = function.Defun()(sum_gather)
    with self.test_session(
        config=config_pb2.ConfigProto(
            allow_soft_placement=False,
            log_device_placement=True,
            device_count={"CPU": 2})) as sess:
      self.evaluate(variables.global_variables_initializer())
      expected = self.evaluate(sum_gather())
      result = sess.run(
          functional_ops.partitioned_call(
              args=defined.captured_inputs, f=defined))
      self.assertAllEqual(expected, result)

  # Use an invalid executor name to test the plumbing of the executor_type attr.
  @test_util.run_v1_only("b/120545219")
  def testExecutorTypeAttrExecutorNotFound(self):
    @function.Defun(dtypes.int32)
    def AddFive(x):
      return x + 5

    op = functional_ops.partitioned_call(
        args=[constant_op.constant([1, 2, 3], dtype=dtypes.int32)],
        f=AddFive,
        executor_type="NON_EXISTENT_EXECUTOR")
    with self.assertRaisesRegex(errors.NotFoundError, "NON_EXISTENT_EXECUTOR"):
      self.evaluate(op)


@test_util.run_all_in_graph_and_eager_modes
@test_util.with_control_flow_v2
class FunctionalOpsCaseTest(test.TestCase):

  def testCase(self):
    @eager_function.defun
    def two(x):
      return x * 2

    @eager_function.defun
    def three(x):
      return x * 3

    @eager_function.defun
    def four(x):
      return x * 4

    def f(branch, x):
      tmpl = array_ops.zeros_like(x)
      return array_ops.identity(gen_functional_ops.case(
          branch, input=[x], Tout=[dtypes.float32],
          branches=[f.get_concrete_function(tmpl)
                    for f in (two, three, four)])[0])
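    # Case runs the last branch for out-of-range branch indices, which the two
    # "default" assertions below rely on.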
    one = array_ops.ones([])
    self.assertAllEqual(np.float32(2), self.evaluate(f(0, one)))
    self.assertAllEqual(np.float32(3), self.evaluate(f(1, one)))
    self.assertAllEqual(np.float32(4), self.evaluate(f(2, one)))
    self.assertAllEqual(np.float32(4), self.evaluate(f(-1, one)))  # <0 default
    self.assertAllEqual(np.float32(4), self.evaluate(f(6, one)))  # >=N default

  @test_util.run_deprecated_v1
  @test_util.disable_xla("Don't lower for XLA")
  def testSkipEagerCaseLoweringPreservesNameForFetch(self):
    for use_gpu in (True, False):
      def Run(branch, x, fetch_by_name, use_gpu=use_gpu):
        with ops.Graph().as_default() as g:
          @function.Defun(dtypes.float32)
          def two(x):
            return -1, x * 2

          @function.Defun(dtypes.float32)
          def three(x):
            return 0, x * 3

          @function.Defun(dtypes.float32)
          def four(x):
            return 1, x * 4

          outputs = gen_functional_ops.case(branch, input=[x],
                                            Tout=[dtypes.int32, dtypes.float32],
                                            branches=[two, three, four],
                                            name="my_case")

          # `outputs` is the list of output tensors of the Case op. We
          # arbitrarily choose the 0th tensor to get the Case op and set the
          # lowering attribute on it.
          outputs[0].op._set_attr("_lower_using_switch_merge",
                                  attr_value_pb2.AttrValue(b=True))
          outputs = array_ops.identity_n(outputs)
        with self.session(graph=g, use_gpu=use_gpu) as sess:
          return sess.run("my_case:1" if fetch_by_name else outputs[1])

      self.assertAllEqual(2 * 1., Run(0, 1., False))
      self.assertAllEqual(2 * 1., Run(0, 1., True))
      self.assertAllEqual(3 * 7., Run(1, 7., False))
      self.assertAllEqual(3 * 7., Run(1, 7., True))
      self.assertAllEqual(4 * -3., Run(2, -3., False))
      self.assertAllEqual(4 * -3., Run(2, -3., True))
      self.assertAllEqual(4 * -4., Run(7, -4., False))  # >= N default
      self.assertAllEqual(4 * -4., Run(7, -4., True))  # >= N default
      self.assertAllEqual(4 * -5., Run(-1, -5., False))  # <0 default
      self.assertAllEqual(4 * -5., Run(-1, -5., True))  # <0 default

  @test_util.disable_xla("Don't lower for XLA")
  def testCaseLowering(self):
    for use_gpu in (True, False):
      @eager_function.defun
      def Run(branch, x):
        @function.Defun(dtypes.float32)
        def two(x):
          return -1, x * 2

        @function.Defun(dtypes.float32)
        def three(x):
          return 0, x * 3

        @function.Defun(dtypes.float32)
        def four(x):
          return 1, x * 4

        outputs = gen_functional_ops.case(branch, input=[x],
                                          Tout=[dtypes.int32, dtypes.float32],
                                          branches=[two, three, four])

        # `outputs` is the list of output tensors of the Case op. We
        # arbitrarily choose the 0th tensor to get the Case op and set the
        # lowering attribute on it.
        outputs[0].op._set_attr("_lower_using_switch_merge",
                                attr_value_pb2.AttrValue(b=True))
        outputs = array_ops.identity_n(outputs)
        return outputs[1]

      with ops.device(test.gpu_device_name() if use_gpu else "CPU:0"):
        self.assertAllEqual(2 * 1., self.evaluate(Run(0, 1.)))
        self.assertAllEqual(3 * 7., self.evaluate(Run(1, 7.)))
        self.assertAllEqual(4 * -3., self.evaluate(Run(2, -3.)))
        self.assertAllEqual(4 * -4., self.evaluate(Run(7, -4.)))  # >=N default
        self.assertAllEqual(4 * -5., self.evaluate(Run(-1, -5.)))  # <0 default

if __name__ == "__main__":
  test.main()

# pylint: enable=invalid-name
1358