xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/parallel_for/xla_control_flow_ops_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""XLA tests for pfor."""
16# pylint: disable=g-direct-tensorflow-import
17
18from tensorflow.compiler.tf2xla.python import xla as xla_ops
19from tensorflow.python.compiler.xla import jit
20from tensorflow.python.compiler.xla import xla
21from tensorflow.python.eager import context
22from tensorflow.python.eager import def_function
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import test_util
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import control_flow_ops
27from tensorflow.python.ops import control_flow_v2_toggles
28from tensorflow.python.ops import math_ops
29from tensorflow.python.ops import random_ops
30from tensorflow.python.ops import resource_variable_ops
31from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops
32from tensorflow.python.ops.parallel_for.test_util import PForTestCase
33from tensorflow.python.platform import test
34
35
@test_util.run_all_in_graph_and_eager_modes
class PForTest(PForTestCase):
  """Exercises pfor and vectorized_map together with XLA compilation."""

  def __init__(self, method_name="runTest"):
    super().__init__(method_name)
    # Enable XLA devices on the global context when the test case is built.
    context.context().enable_xla_devices()

  def test_xla_einsum(self):
    """Runs xla einsum on loop-variant and loop-invariant operands."""
    iters = 10
    lhs_stack = random_ops.random_uniform([iters, 9, 9])
    rhs_stack = random_ops.random_uniform([iters, 9, 1])
    matmul_equation = "ab,bc->ac"

    def loop_fn(i):
      # Slices at a fixed index do not depend on the loop index (invariant).
      lhs_fixed = array_ops.gather(lhs_stack, 0)
      rhs_fixed = array_ops.gather(rhs_stack, 0)
      # Slices at `i` differ per iteration.
      lhs_i = array_ops.gather(lhs_stack, i)
      rhs_i = array_ops.gather(rhs_stack, i)
      return [
          xla_ops.einsum(lhs_i, rhs_fixed, matmul_equation),
          xla_ops.einsum(lhs_fixed, rhs_i, matmul_equation),
          xla_ops.einsum(lhs_fixed, rhs_fixed, matmul_equation),
          xla_ops.einsum(lhs_i, rhs_i, matmul_equation),
          # This equation includes a transpose.
          xla_ops.einsum(rhs_i, lhs_i, "cd,ce->de"),
      ]

    self._test_loop_fn(loop_fn, iters)

  def test_xla(self):
    """Runs vectorized_map through xla.compile."""

    def mean_keepdims(x):
      return math_ops.reduce_mean(x, axis=0, keepdims=True)

    def vectorized_compute(x):
      return pfor_control_flow_ops.vectorized_map(mean_keepdims, x)

    all_ones = array_ops.ones((10, 5, 3))
    compiled_out = xla.compile(vectorized_compute, inputs=[all_ones])
    self.run_and_assert_equal(compiled_out, array_ops.ones((10, 1, 3)))

  def test_function_jit_compile(self):
    """Runs vectorized_map inside a function traced with jit_compile=True."""

    def mean_keepdims(x):
      return math_ops.reduce_mean(x, axis=0, keepdims=True)

    @def_function.function(jit_compile=True)
    def vectorized_compute(x):
      return pfor_control_flow_ops.vectorized_map(mean_keepdims, x)

    all_ones = array_ops.ones((10, 5, 3))
    jit_out = vectorized_compute(all_ones)
    self.run_and_assert_equal(jit_out, array_ops.ones((10, 1, 3)))

  def test_xla_while_loop(self):
    """Calls vectorized_map from the body of an XLA-compiled while loop."""

    def mean_keepdims(x):
      return math_ops.reduce_mean(x, axis=0, keepdims=True)

    def vectorized_compute(x, i):
      block = array_ops.gather(x, i)
      vectorized = pfor_control_flow_ops.vectorized_map(mean_keepdims, block)
      vectorized.set_shape([5, 1])
      return vectorized

    def while_compute(x):

      def keep_going(i, _):
        return i < 10

      def accumulate(i, running_sum):
        return (i + 1, running_sum + vectorized_compute(x, i))

      initial_state = (0, array_ops.zeros([5, 1]))
      final_state = control_flow_ops.while_loop_v2(
          keep_going, accumulate, initial_state)
      # Only the accumulated sum is of interest, not the loop counter.
      return final_state[1]

    result = xla.compile(while_compute, inputs=[array_ops.ones((10, 5, 3))])
    # Ten iterations each add a [5, 1] block of ones.
    expected = array_ops.ones([5, 1]) * 10
    self.run_and_assert_equal(expected, result)

  def test_reduce_mean(self):
    """Checks pfor_config.reduce_mean inside a jit-compiled function."""
    data = random_ops.random_uniform([8, 3])

    @def_function.function(jit_compile=True)
    def centered_rows():

      def loop_fn(i, pfor_config):
        row = array_ops.gather(data, i)
        row_mean = pfor_config.reduce_mean(row)
        return row - row_mean

      return pfor_control_flow_ops.pfor(loop_fn, 8)

    actual = centered_rows()
    reference = data - math_ops.reduce_mean(data, axis=0)
    # Evaluate both in one call so they are computed together.
    actual_val, reference_val = self.evaluate([actual, reference])
    self.assertAllClose(reference_val, actual_val)
124
125
def _make_unstacked(cond, body, pfor_config):
  """Wraps a while-loop `cond`/`body` pair behind a leading "not done" flag.

  The returned functions expect the loop variables to start with a boolean
  flag, followed by the loop variables that `cond` and `body` take.

  Args:
    cond: Loop condition, called with the loop variables minus the flag.
    body: Loop body, called with the loop variables minus the flag.
    pfor_config: Object providing `reduce_concat`, used to combine the flag.

  Returns:
    A `(cond, body)` tuple of the wrapped functions.
  """

  def _cond(still_running, *unused_state):
    # Keep looping while the combined flag is true anywhere.
    combined_flags = pfor_config.reduce_concat(still_running)
    return math_ops.reduce_any(combined_flags)

  def _body(still_running, *state):
    still_running = math_ops.logical_and(still_running, cond(*state))
    proposed = body(*state)
    # Take the new value only where the flag is still set; otherwise keep the
    # previous value.
    selected = [
        array_ops.where_v2(still_running, new, old)
        for new, old in zip(proposed, state)
    ]
    return (still_running,) + tuple(selected)

  return _cond, _body
140
141
@test_util.run_all_in_graph_and_eager_modes
class WhileV2Test(PForTestCase):
  """pfor tests over while loops, run with control flow v2 switched on."""

  def setUp(self):
    # Record the toggle state first so tearDown can restore it.
    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
    control_flow_v2_toggles.enable_control_flow_v2()
    super().setUp()

  def tearDown(self):
    was_enabled_before = self._enabled
    if not was_enabled_before:
      control_flow_v2_toggles.disable_control_flow_v2()
    super().tearDown()

  def _test_loop_fn(self, loop_fn, iters, force_xla=False):
    """Compares plain pfor output against jit-scoped and XLA-compiled runs.

    Args:
      loop_fn: Loop function handed to pfor.
      iters: Number of pfor iterations.
      force_xla: If True, additionally checks `jit_compile=True` and
        `xla.compile` versions of the same computation.
    """

    def run_pfor():
      return pfor_control_flow_ops.pfor(loop_fn, iters)

    @def_function.function
    def run_pfor_in_jit_scope():
      with jit.experimental_jit_scope():
        return run_pfor()

    baseline = run_pfor()
    jit_scope_out = run_pfor_in_jit_scope()
    self.run_and_assert_equal(baseline, jit_scope_out)

    # TODO(agarwal): The following may complain about uncompilable nodes. Hence
    # these are currently not enabled for all tests.
    if not force_xla:
      return

    jit_compiled = def_function.function(jit_compile=True)(run_pfor)
    jit_compiled_out = jit_compiled()
    self.run_and_assert_equal(baseline, jit_compiled_out)
    xla_compiled_out = xla.compile(run_pfor, inputs=[])
    self.run_and_assert_equal(baseline, xla_compiled_out)

  def test_stateless_while(self):
    """While loop whose trip count depends on the pfor iteration."""
    values = random_ops.random_uniform([3, 5])
    trip_counts = constant_op.constant([4, 0, 2])

    def loop_fn(i):
      row = array_ops.gather(values, i)
      limit = array_ops.gather(trip_counts, i)

      def below_limit(j, _):
        return j < limit

      def add_next(j, total):
        return (j + 1, total + array_ops.gather(row, j))

      return control_flow_ops.while_loop(below_limit, add_next, [0, 0.])

    self._test_loop_fn(loop_fn, 3)

  def test_while_with_variable(self):
    """While loop whose body reads a resource variable."""
    if not context.executing_eagerly():
      self.skipTest("Flaky with tf.Session")

    increment = resource_variable_ops.ResourceVariable(5.)

    def loop_fn(_):

      def below_four(j, unused_total):
        return j < 4

      def add_variable(j, total):
        return (j + 1, total + increment)

      final_state = control_flow_ops.while_loop(
          below_four, add_variable, [0, 0.])
      return final_state[1]

    self._test_loop_fn(loop_fn, 3)

  def test_while_unstacked_condition(self):
    """While loop whose condition does not involve the pfor index."""

    def loop_fn(i):

      def below_four(j, unused_total):
        return j < 4

      def add_index(j, total):
        return (j + 1, total + i)

      return control_flow_ops.while_loop(below_four, add_index, [0, 0])

    self._test_loop_fn(loop_fn, 3, force_xla=True)

  def test_while_force_unstacked_condition(self):
    """While loop with a loop-variant condition wrapped to be invariant."""
    # This mirrors test_stateless_while, whose while_loop condition is loop
    # variant. Here the cond and body are wrapped so that the while_loop
    # condition becomes pfor loop invariant. That lets xla compilation work,
    # because the vectorized code no longer has to dynamically partition the
    # inputs.
    values = random_ops.random_uniform([3, 5])
    trip_counts = constant_op.constant([4, 0, 2])

    def loop_fn(i, pfor_config):
      row = array_ops.gather(values, i)
      limit = array_ops.gather(trip_counts, i)

      def below_limit(j, _):
        return j < limit

      def add_next(j, total):
        return (j + 1, total + array_ops.gather(row, j))

      wrapped_cond, wrapped_body = _make_unstacked(
          below_limit, add_next, pfor_config)
      # The extra leading True is the "not done" flag the wrappers expect.
      initial_state = [True, 0, 0.]
      return control_flow_ops.while_loop(
          wrapped_cond, wrapped_body, initial_state)

    self._test_loop_fn(loop_fn, 3, force_xla=True)
241
242
if __name__ == "__main__":
  # Executed as a script: run the tests through the TensorFlow test entry
  # point.
  test.main()
245